aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2025-08-03 13:13:24 +0000
committerMatthias Springer <me@m-sp.org>2025-08-03 13:13:24 +0000
commit9f4a45d332b6b839c7d660222156131c957188ec (patch)
treeb1a03be830832f2e22ad4ea263c173d9e78fa025
parent3b7dd9f48fe6127b1ec41d02f1deafac6d0b5efc (diff)
downloadllvm-users/matthias-springer/dialect_conversion_cse.zip
llvm-users/matthias-springer/dialect_conversion_cse.tar.gz
llvm-users/matthias-springer/dialect_conversion_cse.tar.bz2
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp44
1 files changed, 44 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c48043b..8008958 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3180,6 +3180,49 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
return failure();
}
+static SmallVector<UnrealizedConversionCastOp>
+cseUnrealizedCasts(SmallVectorImpl<UnrealizedConversionCastOp> &castOps) {
+ SmallVector<UnrealizedConversionCastOp> result;
+ DominanceInfo domInfo;
+ DenseMap<unsigned, SmallVector<UnrealizedConversionCastOp>> hashedOps;
+ for (UnrealizedConversionCastOp castOp : castOps) {
+ unsigned hash = 0;
+ for (Type type : castOp.getResultTypes())
+ hash ^= hash_value(type);
+ for (Value value : castOp.getInputs())
+ hash ^= hash_value(value);
+ hashedOps[hash].push_back(castOp);
+ }
+ // TODO: This should run to a fixed point.
+ DenseSet<UnrealizedConversionCastOp> erasedOps;
+ for (auto &it : hashedOps) {
+ SmallVector<UnrealizedConversionCastOp> &ops = it.second;
+ if (ops.size() == 1)
+ continue;
+ UnrealizedConversionCastOp top = ops.front();
+ for (UnrealizedConversionCastOp castOp : llvm::drop_begin(ops)) {
+ if (castOp.getInputs() != top.getInputs())
+ continue;
+ if (castOp.getResultTypes() != top.getResultTypes())
+ continue;
+ if (domInfo.dominates(castOp, top)) {
+ std::swap(top, castOp);
+ }
+ if (domInfo.properlyDominates(top, castOp)) {
+ castOp.replaceAllUsesWith(top);
+ castOp.erase();
+ erasedOps.insert(castOp);
+ continue;
+ }
+ }
+ }
+
+ for (UnrealizedConversionCastOp castOp : castOps)
+ if (!erasedOps.contains(castOp))
+ result.push_back(castOp);
+ return result;
+}
+
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();
@@ -3233,6 +3276,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// patterns.)
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+ remainingCastOps = cseUnrealizedCasts(remainingCastOps);
// Drop markers.
for (UnrealizedConversionCastOp castOp : remainingCastOps)