diff options
author | Matthias Springer <me@m-sp.org> | 2025-08-03 13:13:24 +0000 |
---|---|---|
committer | Matthias Springer <me@m-sp.org> | 2025-08-03 13:13:24 +0000 |
commit | 9f4a45d332b6b839c7d660222156131c957188ec (patch) | |
tree | b1a03be830832f2e22ad4ea263c173d9e78fa025 | |
parent | 3b7dd9f48fe6127b1ec41d02f1deafac6d0b5efc (diff) | |
download | llvm-users/matthias-springer/dialect_conversion_cse.zip llvm-users/matthias-springer/dialect_conversion_cse.tar.gz llvm-users/matthias-springer/dialect_conversion_cse.tar.bz2 |
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c48043b..8008958 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3180,6 +3180,49 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, return failure(); } +static SmallVector<UnrealizedConversionCastOp> +cseUnrealizedCasts(SmallVectorImpl<UnrealizedConversionCastOp> &castOps) { + SmallVector<UnrealizedConversionCastOp> result; + DominanceInfo domInfo; + DenseMap<unsigned, SmallVector<UnrealizedConversionCastOp>> hashedOps; + for (UnrealizedConversionCastOp castOp : castOps) { + unsigned hash = 0; + for (Type type : castOp.getResultTypes()) + hash ^= hash_value(type); + for (Value value : castOp.getInputs()) + hash ^= hash_value(value); + hashedOps[hash].push_back(castOp); + } + // TODO: This should run to a fixed point. + DenseSet<UnrealizedConversionCastOp> erasedOps; + for (auto &it : hashedOps) { + SmallVector<UnrealizedConversionCastOp> &ops = it.second; + if (ops.size() == 1) + continue; + UnrealizedConversionCastOp top = ops.front(); + for (UnrealizedConversionCastOp castOp : llvm::drop_begin(ops)) { + if (castOp.getInputs() != top.getInputs()) + continue; + if (castOp.getResultTypes() != top.getResultTypes()) + continue; + if (domInfo.dominates(castOp, top)) { + std::swap(top, castOp); + } + if (domInfo.properlyDominates(top, castOp)) { + castOp.replaceAllUsesWith(top); + castOp.erase(); + erasedOps.insert(castOp); + continue; + } + } + } + + for (UnrealizedConversionCastOp castOp : castOps) + if (!erasedOps.contains(castOp)) + result.push_back(castOp); + return result; +} + LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { assert(!ops.empty() && "expected at least one operation"); const ConversionTarget &target = opLegalizer.getTarget(); @@ -3233,6 +3276,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { // patterns.) SmallVector<UnrealizedConversionCastOp> remainingCastOps; reconcileUnrealizedCasts(allCastOps, &remainingCastOps); + remainingCastOps = cseUnrealizedCasts(remainingCastOps); // Drop markers. for (UnrealizedConversionCastOp castOp : remainingCastOps) |