diff options
author | Matthias Springer <me@m-sp.org> | 2024-09-13 20:16:05 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-13 20:16:05 +0200 |
commit | d588e49a324b3d6039c19f3108d722a8b9fcd96e (patch) | |
tree | 4a99003b3ba7fa1c7f9dbd5d96fe0cb44e0449cd | |
parent | d0e7714de73b8b657dca1706e676027d42bbb775 (diff) | |
download | llvm-d588e49a324b3d6039c19f3108d722a8b9fcd96e.zip llvm-d588e49a324b3d6039c19f3108d722a8b9fcd96e.tar.gz llvm-d588e49a324b3d6039c19f3108d722a8b9fcd96e.tar.bz2 |
[mlir][Transforms][NFC] Dialect conversion: Cache `UnresolvedMaterializationRewrite` (#108359)
The dialect conversion maintains a set of unresolved materializations
(`UnrealizedConversionCastOp`). Turn that set into a `DenseMap` that
maps from ops to `UnresolvedMaterializationRewrite *`. This improves
efficiency a bit, because an iteration over
`ConversionPatternRewriterImpl::rewrites` can be avoided.
Also delete some dead code.
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 72 |
1 files changed, 27 insertions, 45 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index b58a95c..caea9e1 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -688,9 +688,7 @@ public: UnresolvedMaterializationRewrite( ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, - MaterializationKind kind = MaterializationKind::Target) - : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), - converterAndKind(converter, kind) {} + MaterializationKind kind = MaterializationKind::Target); static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -730,26 +728,6 @@ static bool hasRewrite(R &&rewrites, Operation *op) { }); } -/// Find the single rewrite object of the specified type and block among the -/// given rewrites. In debug mode, asserts that there is mo more than one such -/// object. Return "nullptr" if no object was found. -template <typename RewriteTy, typename R> -static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) { - RewriteTy *result = nullptr; - for (auto &rewrite : rewrites) { - auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); - if (rewriteTy && rewriteTy->getBlock() == block) { -#ifndef NDEBUG - assert(!result && "expected single matching rewrite"); - result = rewriteTy; -#else - return rewriteTy; -#endif // NDEBUG - } - } - return result; -} - //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -892,10 +870,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasErased(void *ptr) const { return erased.contains(ptr); } - bool wasErased(OperationRewrite *rewrite) const { - return wasErased(rewrite->getOperation()); - } - void notifyOperationErased(Operation *op) override { erased.insert(op); } void notifyBlockErased(Block *block) override { erased.insert(block); } @@ -935,8 +909,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// to modify/access them is invalid rewriter API usage. SetVector<Operation *> replacedOps; - /// A set of all unresolved materializations. - DenseSet<Operation *> unresolvedMaterializations; + /// A mapping of all unresolved materializations (UnrealizedConversionCastOp) + /// to the corresponding rewrite objects. + DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> + unresolvedMaterializations; /// The current type converter, or nullptr if no type converter is currently /// active. @@ -1058,12 +1034,20 @@ void CreateOperationRewrite::rollback() { op->erase(); } +UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( + ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, + const TypeConverter *converter, MaterializationKind kind) + : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), + converterAndKind(converter, kind) { + rewriterImpl.unresolvedMaterializations[op] = this; +} + void UnresolvedMaterializationRewrite::rollback() { if (getMaterializationKind() == MaterializationKind::Target) { for (Value input : op->getOperands()) rewriterImpl.mapping.erase(input); } - rewriterImpl.unresolvedMaterializations.erase(op); + rewriterImpl.unresolvedMaterializations.erase(getOperation()); op->erase(); } @@ -1345,7 +1329,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); - unresolvedMaterializations.insert(convertOp); appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind); return convertOp.getResult(0); } @@ -1382,10 +1365,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) { if (!newValue) { // This result was dropped and no replacement value was provided. - if (unresolvedMaterializations.contains(op)) { - // Do not create another materializations if we are erasing a - // materialization. - continue; + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + if (unresolvedMaterializations.contains(castOp)) { + // Do not create another materializations if we are erasing a + // materialization. + continue; + } } // Materialize a replacement value "out of thin air". @@ -2499,15 +2484,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { // Gather all unresolved materializations. SmallVector<UnrealizedConversionCastOp> allCastOps; - DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap; - for (std::unique_ptr<IRRewrite> &rewrite : rewriterImpl.rewrites) { - auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get()); - if (!mat) - continue; - if (rewriterImpl.eraseRewriter.wasErased(mat)) + const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> + &materializations = rewriterImpl.unresolvedMaterializations; + for (auto it : materializations) { + if (rewriterImpl.eraseRewriter.wasErased(it.first)) continue; - allCastOps.push_back(mat->getOperation()); - rewriteMap[mat->getOperation()] = mat; + allCastOps.push_back(it.first); } // Reconcile all UnrealizedConversionCastOps that were inserted by the @@ -2520,8 +2502,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { if (config.buildMaterializations) { IRRewriter rewriter(rewriterImpl.context, config.listener); for (UnrealizedConversionCastOp castOp : remainingCastOps) { - auto it = rewriteMap.find(castOp.getOperation()); - assert(it != rewriteMap.end() && "inconsistent state"); + auto it = materializations.find(castOp); + assert(it != materializations.end() && "inconsistent state"); if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) return failure(); } |