diff options
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r-- | mlir/lib/Transforms/RemoveDeadValues.cpp | 1 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 221 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/Inliner.cpp | 33 |
3 files changed, 183 insertions, 72 deletions
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 5650de2..4ccb83f 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -60,7 +60,6 @@ #include <vector> #define DEBUG_TYPE "remove-dead-values" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") namespace mlir { #define GEN_PASS_DEF_REMOVEDEADVALUES diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 466718b..08803e0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -131,11 +132,6 @@ struct ConversionValueMapping { /// value. ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; - /// Lookup the given value within the map, or return an empty vector if the - /// value is not mapped. If it is mapped, this follows the same behavior - /// as `lookupOrDefault`. - ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; - template <typename T> struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; @@ -238,15 +234,6 @@ ConversionValueMapping::lookupOrDefault(Value from, return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); } -ValueVector ConversionValueMapping::lookupOrNull(Value from, - TypeRange desiredTypes) const { - ValueVector result = lookupOrDefault(from, desiredTypes); - if (result == ValueVector{from} || - (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) - return {}; - return result; -} - //===----------------------------------------------------------------------===// // Rewriter and Translation State //===----------------------------------------------------------------------===// @@ -522,9 +509,11 @@ private: class MoveBlockRewrite : public BlockRewrite { public: MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, - Region *region, Block *insertBeforeBlock) - : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), - insertBeforeBlock(insertBeforeBlock) {} + Region *previousRegion, Region::iterator previousIt) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), + region(previousRegion), + insertBeforeBlock(previousIt == previousRegion->end() ? nullptr + : &*previousIt) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveBlock; @@ -631,9 +620,12 @@ protected: class MoveOperationRewrite : public OperationRewrite { public: MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Operation *op, Block *block, Operation *insertBeforeOp) - : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), - insertBeforeOp(insertBeforeOp) {} + Operation *op, OpBuilder::InsertPoint previous) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), + block(previous.getBlock()), + insertBeforeOp(previous.getPoint() == previous.getBlock()->end() + ? nullptr + : &*previous.getPoint()) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveOperation; @@ -927,6 +919,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Return "true" if the given operation was replaced or erased. bool wasOpReplaced(Operation *op) const; + /// Lookup the most recently mapped values with the desired types in the + /// mapping. + /// + /// Special cases: + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given + /// value. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; + //===--------------------------------------------------------------------===// // IR Rewrites / Type Conversion //===--------------------------------------------------------------------===// @@ -1249,6 +1258,22 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management //===----------------------------------------------------------------------===// +ValueVector +ConversionPatternRewriterImpl::lookupOrDefault(Value from, + TypeRange desiredTypes) const { + return mapping.lookupOrDefault(from, desiredTypes); +} + +ValueVector +ConversionPatternRewriterImpl::lookupOrNull(Value from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + if (result == ValueVector{from} || + (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) + return {}; + return result; +} + RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); } @@ -1296,7 +1321,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped values. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back(lookupOrDefault(operand)); continue; } @@ -1315,7 +1340,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( continue; } - ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); + ValueVector repl = lookupOrDefault(operand, legalTypes); if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { // Mapped values have the correct type or there is an existing // materialization. Or the operand is not mapped at all and has the @@ -1325,7 +1350,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = mapping.lookupOrDefault(operand); + repl = lookupOrDefault(operand); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1520,7 +1545,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. - ValueVector repl = mapping.lookupOrNull(value, value.getType()); + ValueVector repl = lookupOrNull(value, value.getType()); if (!repl.empty()) return repl.front(); @@ -1536,7 +1561,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // No replacement value was found. Get the latest replacement value // (regardless of the type) and build a source materialization to the // original type. - repl = mapping.lookupOrNull(value); + repl = lookupOrNull(value); if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, @@ -1569,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { + // If no previous insertion point is provided, the op used to be detached. + bool wasDetached = !previous.isSet(); LLVM_DEBUG({ - logger.startLine() << "** Insert : '" << op->getName() << "'(" << op - << ")\n"; + logger.startLine() << "** Insert : '" << op->getName() << "' (" << op + << ")"; + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); assert(!wasOpReplaced(op->getParentOp()) && "attempting to insert into a block within a replaced/erased op"); - if (!previous.isSet()) { - // This is a newly created op. + if (wasDetached) { + // If the op was detached, it is most likely a newly created op. + // TODO: If the same op is inserted multiple times from a detached state, + // the rollback mechanism may erase the same op multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateOperationRewrite>(op); patternNewOps.insert(op); return; } - Operation *prevOp = previous.getPoint() == previous.getBlock()->end() - ? nullptr - : &*previous.getPoint(); - appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); + + // The op was moved from one place to another. + appendRewrite<MoveOperationRewrite>(op, previous); } void ConversionPatternRewriterImpl::replaceOp( @@ -1650,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) { void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { - assert(!wasOpReplaced(block->getParentOp()) && - "attempting to insert into a region within a replaced/erased op"); + // If no previous insertion point is provided, the block used to be detached. + bool wasDetached = !previous; + Operation *newParentOp = block->getParentOp(); LLVM_DEBUG( { - Operation *parent = block->getParentOp(); + Operation *parent = newParentOp; if (parent) { logger.startLine() << "** Insert Block into : '" << parent->getName() - << "'(" << parent << ")\n"; + << "' (" << parent << ")"; } else { logger.startLine() - << "** Insert Block into detached Region (nullptr parent op)'\n"; + << "** Insert Block into detached Region (nullptr parent op)"; } + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); + assert(!wasOpReplaced(newParentOp) && + "attempting to insert into a region within a replaced/erased op"); + (void)newParentOp; patternInsertedBlocks.insert(block); - if (!previous) { - // This is a newly created block. + if (wasDetached) { + // If the block was detached, it is most likely a newly created block. + // TODO: If the same block is inserted multiple times from a detached state, + // the rollback mechanism may erase the same block multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateBlockRewrite>(block); return; } - Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; - appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); + + // The block was moved from one place to another. + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1717,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> { return v ? SmallVector<Value>{v} : SmallVector<Value>(); @@ -1732,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1740,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1846,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); } @@ -1977,6 +2043,7 @@ private: /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks); @@ -2174,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); }); - (void)rewriterImpl; + + // Clear pattern state, so that the next pattern application starts with a + // clean slate. (The op/block sets are populated by listener notifications.) + auto cleanup = llvm::make_scope_exit([&]() { + rewriterImpl.patternNewOps.clear(); + rewriterImpl.patternModifiedOps.clear(); + rewriterImpl.patternInsertedBlocks.clear(); + }); + + // Upon failure, undo all changes made by the folder. + RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. StringRef opName = op->getName().getStringRef(); SmallVector<Value, 2> replacementValues; SmallVector<Operation *, 2> newOps; rewriter.setInsertionPoint(op); + rewriter.startOpModification(op); if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); + rewriter.cancelOpModification(op); return failure(); } + rewriter.finalizeOpModification(op); // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) return legalize(op, rewriter); + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + // Recursively legalize any new constant operations. for (Operation *newOp : newOps) { if (failed(legalize(newOp, rewriter))) { @@ -2203,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op, "op '" + opName + "' folder rollback of IR modifications requested"); } - // Legalization failed: erase all materialized constants. - for (Operation *op : newOps) - rewriter.eraseOp(op); + rewriterImpl.resetState( + curState, std::string(op->getName().getStringRef()) + " folder"); return failure(); } } - // Insert a replacement for 'op' with the folded replacement values. - rewriter.replaceOp(op, replacementValues); - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); return success(); } @@ -2222,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + Operation *checkOp; + std::optional<OperationFingerPrint> topLevelFingerPrint; + if (!rewriterImpl.config.allowPatternRollback) { + // The op may be getting erased, so we have to check the parent op. + // (In rare cases, a pattern may even erase the parent op, which will cause + // a crash here. Expensive checks are "best effort".) Skip the check if the + // op does not have a parent op. + if ((checkOp = op->getParentOp())) { + if (!op->getContext()->isMultithreadingEnabled()) { + topLevelFingerPrint = OperationFingerPrint(checkOp); + } else { + // Another thread may be modifying a sibling operation. Therefore, the + // fingerprinting mechanism of the parent op works only in + // single-threaded mode. + LLVM_DEBUG({ + rewriterImpl.logger.startLine() + << "WARNING: Multi-threadeding is enabled. Some dialect " + "conversion expensive checks are skipped in multithreading " + "mode!\n"; + }); + } + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { bool canApply = canApplyPattern(op, pattern, rewriter); @@ -2234,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (!rewriterImpl.config.allowPatternRollback) { + // Returning "failure" after modifying IR is not allowed. + if (checkOp) { + OperationFingerPrint fingerPrintAfterPattern(checkOp); + if (fingerPrintAfterPattern != *topLevelFingerPrint) + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' returned failure but IR did change"); + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2262,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, moveAndReset(rewriterImpl.patternModifiedOps); SetVector<Block *> insertedBlocks = moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, rewriter, newOps, + auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps, modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { @@ -2305,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - const SetVector<Operation *> &newOps, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks) { auto &impl = rewriter.getImpl(); @@ -2321,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult( return hasRewrite<ModifyOperationRewrite>(newRewrites, op); }; if (!replacedRoot() && !updatedRootInPlace()) - llvm::report_fatal_error("expected pattern to replace the root operation"); + llvm::report_fatal_error( + "expected pattern to replace the root operation or modify it in place"); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index b639e87f..26c965c 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -21,7 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "inlining" @@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, // InlinerInterfaceImpl //===----------------------------------------------------------------------===// -#ifndef NDEBUG static std::string getNodeName(CallOpInterface op) { if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) return debugString(op); return "_unnamed_callee_"; } -#endif /// Return true if the specified `inlineHistoryID` indicates an inline history /// that already includes `node`. @@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); LLVM_DEBUG({ - llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; + LDBG() << "* Inliner: Initial calls in SCC are: {"; for (unsigned i = 0, e = calls.size(); i < e; ++i) - llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; - llvm::dbgs() << "}\n"; + LDBG() << " " << i << ". " << calls[i].call << ","; + LDBG() << "}"; }); // Try to inline each of the call operations. Don't cache the end iterator @@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, CallOpInterface call = it.call; LLVM_DEBUG({ if (doInline) - llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Inlining call: " << i << ". " << call; else - llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Not inlining call: " << i << ". " << call; }); if (!doInline) continue; @@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, cast<CallableOpInterface>(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { - LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); + LDBG() << "** Failed to inline"; continue; } inlinedAnyCalls = true; @@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, auto historyToString = [](InlineHistoryT h) { return h.has_value() ? std::to_string(*h) : "root"; }; - (void)historyToString; - LLVM_DEBUG(llvm::dbgs() - << "* new inlineHistory entry: " << newInlineHistoryID << ". [" - << getNodeName(call) << ", " << historyToString(inlineHistoryID) - << "]\n"); + LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". [" + << getNodeName(call) << ", " << historyToString(inlineHistoryID) + << "]"; for (unsigned k = prevSize; k != calls.size(); ++k) { callHistory.push_back(newInlineHistoryID); - LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call - << "}\n with historyID = " << newInlineHistoryID - << ", added due to inlining of\n call {" << call - << "}\n with historyID = " - << historyToString(inlineHistoryID) << "\n"); + LDBG() << "* new call " << k << " {" << calls[k].call + << "}\n with historyID = " << newInlineHistoryID + << ", added due to inlining of\n call {" << call + << "}\n with historyID = " << historyToString(inlineHistoryID); } // If the inlining was successful, Merge the new uses into the source node. |