diff options
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r-- | mlir/lib/Transforms/CSE.cpp | 12 | ||||
-rw-r--r-- | mlir/lib/Transforms/InlinerPass.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/Transforms/Mem2Reg.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Transforms/RemoveDeadValues.cpp | 49 | ||||
-rw-r--r-- | mlir/lib/Transforms/SROA.cpp | 4 | ||||
-rw-r--r-- | mlir/lib/Transforms/SymbolDCE.cpp | 30 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp | 9 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 926 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/FoldUtils.cpp | 7 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 42 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/InliningUtils.cpp | 13 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp | 12 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/RegionUtils.cpp | 33 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 126 |
14 files changed, 938 insertions, 336 deletions
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 09e5a02..8eaac30 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -177,11 +177,10 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) { assert(fromOp->getBlock() == toOp->getBlock()); - assert( - isa<MemoryEffectOpInterface>(fromOp) && - cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() && - isa<MemoryEffectOpInterface>(toOp) && - cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>()); + assert(hasEffect<MemoryEffects::Read>(fromOp) && + "expected read effect on fromOp"); + assert(hasEffect<MemoryEffects::Read>(toOp) && + "expected read effect on toOp"); Operation *nextOp = fromOp->getNextNode(); auto result = memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr)); @@ -245,11 +244,10 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, // Some simple use case of operation with memory side-effect are dealt with // here. Operations with no side-effect are done after. if (!isMemoryEffectFree(op)) { - auto memEffects = dyn_cast<MemoryEffectOpInterface>(op); // TODO: Only basic use case for operations with MemoryEffects::Read can be // eleminated now. More work needs to be done for more complicated patterns // and other side-effects. - if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>()) + if (!hasSingleEffect<MemoryEffects::Read>(op)) return failure(); // Look for an existing definition for the operation. 
diff --git a/mlir/lib/Transforms/InlinerPass.cpp b/mlir/lib/Transforms/InlinerPass.cpp index 703e517..77a9e6c 100644 --- a/mlir/lib/Transforms/InlinerPass.cpp +++ b/mlir/lib/Transforms/InlinerPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Analysis/CallGraph.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Inliner.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_INLINER @@ -120,8 +121,8 @@ static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall, return true; unsigned ratio = countOps(calleeRegion) * 100 / callerOps; - LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: " - << inliningThreshold << "%): " << ratio << "%\n"); + LDBG() << "Callee / caller operation ratio (max: " << inliningThreshold + << "%): " << ratio << "%"; return ratio <= inliningThreshold; } @@ -138,7 +139,7 @@ void InlinerPass::runOnOperation() { } // By default, assume that any inlining is profitable. - auto profitabilityCb = [=](const Inliner::ResolvedCall &call) { + auto profitabilityCb = [this](const Inliner::ResolvedCall &call) { return isProfitableToInline(call, inliningThreshold); }; diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index cf039c3..d36a3c1 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -19,6 +19,7 @@ #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/GenericIteratedDominanceFrontier.h" namespace mlir { @@ -632,8 +633,7 @@ MemorySlotPromoter::promoteSlot() { } } - LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr - << "\n"); + LDBG() << "Promoted memory slot: " << slot.ptr; if (statistics.promotedAmount) (*statistics.promotedAmount)++; diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 4ccb83f..0e84b6d 100644 --- 
a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG() << "Processing simple op: " << *op; if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { - LDBG() - << "Simple op is not memory effect free or has live results, skipping: " - << *op; + LDBG() << "Simple op is not memory effect free or has live results, " + "preserving it: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); return; } LDBG() << "Simple op has all dead results and is memory effect free, scheduling " "for removal: " - << *op; + << OpWithFlags(op, OpPrintingFlags().skipRegions()); cl.operations.push_back(op); collectNonLiveValues(nonLiveSet, op->getResults(), BitVector(op->getNumResults(), true)); @@ -345,8 +344,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // being returned, in order to optimize our IR. So, this demonstrates how we // can make our optimization strong by even removing a live return value (%0), // since it forwards only to non-live value(s) (%1#1). - Operation *lastReturnOp = funcOp.back().getTerminator(); - size_t numReturns = lastReturnOp->getNumOperands(); + size_t numReturns = funcOp.getNumResults(); BitVector nonLiveRets(numReturns, true); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); @@ -728,19 +726,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, /// Removes dead values collected in RDVFinalCleanupList. /// To be run once when all dead values have been collected. static void cleanUpDeadVals(RDVFinalCleanupList &list) { + LDBG() << "Starting cleanup of dead values..."; + // 1. 
Operations + LDBG() << "Cleaning up " << list.operations.size() << " operations"; for (auto &op : list.operations) { + LDBG() << "Erasing operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); op->dropAllUses(); op->erase(); } // 2. Values + LDBG() << "Cleaning up " << list.values.size() << " values"; for (auto &v : list.values) { + LDBG() << "Dropping all uses of value: " << v; v.dropAllUses(); } // 3. Functions + LDBG() << "Cleaning up " << list.functions.size() << " functions"; for (auto &f : list.functions) { + LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName(); + LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments"; + LDBG() << " Erasing " << f.nonLiveRets.count() + << " non-live return values"; // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. @@ -749,44 +759,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } // 4. Operands + LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { - if (o.op->getNumOperands() > 0) + if (o.op->getNumOperands() > 0) { + LDBG() << "Erasing " << o.nonLive.count() + << " non-live operands from operation: " + << OpWithFlags(o.op, OpPrintingFlags().skipRegions()); o.op->eraseOperands(o.nonLive); + } } // 5. Results + LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { + LDBG() << "Erasing " << r.nonLive.count() + << " non-live results from operation: " + << OpWithFlags(r.op, OpPrintingFlags().skipRegions()); dropUsesAndEraseResults(r.op, r.nonLive); } // 6. 
Blocks + LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; for (auto &b : list.blocks) { // blocks that are accessed via multiple codepaths processed once if (b.b->getNumArguments() != b.nonLiveArgs.size()) continue; + LDBG() << "Erasing " << b.nonLiveArgs.count() + << " non-live arguments from block: " << b.b; // it iterates backwards because erase invalidates all successor indexes for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { if (!b.nonLiveArgs[i]) continue; + LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); b.b->getArgument(i).dropAllUses(); b.b->eraseArgument(i); } } // 7. Successor Operands + LDBG() << "Cleaning up " << list.successorOperands.size() + << " successor operand lists"; for (auto &op : list.successorOperands) { SuccessorOperands successorOperands = op.branch.getSuccessorOperands(op.successorIndex); // blocks that are accessed via multiple codepaths processed once if (successorOperands.size() != op.nonLiveOperands.size()) continue; + LDBG() << "Erasing " << op.nonLiveOperands.count() + << " non-live successor operands from successor " + << op.successorIndex << " of branch: " + << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); // it iterates backwards because erase invalidates all successor indexes for (int i = successorOperands.size() - 1; i >= 0; --i) { if (!op.nonLiveOperands[i]) continue; + LDBG() << " Erasing successor operand " << i << ": " + << successorOperands[i]; successorOperands.erase(i); } } + + LDBG() << "Finished cleanup of dead values"; } struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> { diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 67f536a..859c030 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -12,6 +12,7 @@ #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" +#include 
"llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_SROA @@ -180,8 +181,7 @@ static void destructureSlot( assert(slot.ptr.use_empty() && "after destructuring, the original slot " "pointer should no longer be used"); - LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr - << "\n"); + LDBG() << "Destructured memory slot: " << slot.ptr; if (statistics.destructuredAmount) (*statistics.destructuredAmount)++; diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp index 0a925c4..87885be 100644 --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -13,8 +13,11 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" namespace mlir { #define GEN_PASS_DEF_SYMBOLDCE @@ -87,8 +90,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, SymbolTableCollection &symbolTable, bool symbolTableIsHidden, DenseSet<Operation *> &liveSymbols) { - LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName() - << "\n"); + LDBG() << "computeLiveness: " + << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions()); // A worklist of live operations to propagate uses from. SmallVector<Operation *, 16> worklist; @@ -116,7 +119,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, // consideration. while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); - LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n"); + LDBG() << "processing: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // If this is a symbol table, recursively compute its liveness. if (op->hasTrait<OpTrait::SymbolTable>()) { @@ -124,13 +128,14 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, // symbol, or if it is a private symbol. 
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op); bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate(); - LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName() - << " is hidden: " << symIsHidden << "\n"); + LDBG() << "\tsymbol table: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " is hidden: " << symIsHidden; if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols))) return failure(); } else { - LLVM_DEBUG(llvm::dbgs() - << "\tnon-symbol table: " << op->getName() << "\n"); + LDBG() << "\tnon-symbol table: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // If the op is not a symbol table, then, unless op itself is dead which // would be handled by DCE, we need to check all the regions and blocks // within the op to find the uses (e.g., consider visibility within op as @@ -160,20 +165,17 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, } SmallVector<Operation *, 4> resolvedSymbols; - LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n"); + LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions()); for (const SymbolTable::SymbolUse &use : *uses) { - LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n"); + LDBG() << "\tuse: " << use.getUser(); // Lookup the symbols referenced by this use. resolvedSymbols.clear(); if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(), resolvedSymbols))) // Ignore references to unknown symbols. continue; - LLVM_DEBUG({ - llvm::dbgs() << "\t\tresolved symbols: "; - llvm::interleaveComma(resolvedSymbols, llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << "\t\tresolved symbols: " + << llvm::interleaved(resolvedSymbols, ", "); // Mark each of the resolved symbols as live. 
for (Operation *resolvedSymbol : resolvedSymbols) diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp index cfd568f..19cf464 100644 --- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp +++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp @@ -21,7 +21,10 @@ #include "mlir/Transforms/ControlFlowSinkUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/Support/DebugLog.h" #include <vector> #define DEBUG_TYPE "cf-sink" @@ -84,13 +87,15 @@ bool Sinker::allUsersDominatedBy(Operation *op, Region *region) { void Sinker::tryToSinkPredecessors(Operation *user, Region *region, std::vector<Operation *> &stack) { - LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n")); + LDBG() << "Contained op: " + << OpWithFlags(user, OpPrintingFlags().skipRegions()); for (Value value : user->getOperands()) { Operation *op = value.getDefiningOp(); // Ignore block arguments and ops that are already inside the region. if (!op || op->getParentRegion() == region) continue; - LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n")); + LDBG() << "Try to sink:\n" + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // If the op's users are all in the region and it can be moved, then do so. if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index f23c619..5ba109d 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -121,17 +121,8 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped values with the desired types in the - /// mapping. 
- /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. - ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + /// Lookup a value in the mapping. + ValueVector lookup(const ValueVector &from) const; template <typename T> struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; @@ -185,54 +176,40 @@ private: }; } // namespace -ValueVector -ConversionValueMapping::lookupOrDefault(Value from, - TypeRange desiredTypes) const { - // Try to find the deepest values that have the desired types. If there is no - // such mapping, simply return the deepest values. - ValueVector desiredValue; - ValueVector current{from}; - do { - // Store the current value if the types match. - if (TypeRange(ValueRange(current)) == desiredTypes) - desiredValue = current; - - // If possible, Replace each value with (one or multiple) mapped values. - ValueVector next; - for (Value v : current) { - auto it = mapping.find({v}); - if (it != mapping.end()) { - llvm::append_range(next, it->second); - } else { - next.push_back(v); - } - } - if (next != current) { - // If at least one value was replaced, continue the lookup from there. - current = std::move(next); - continue; - } +/// Marker attribute for pure type conversions. I.e., mappings whose only +/// purpose is to resolve a type mismatch. (In contrast, mappings that point to +/// the replacement values of a "replaceOp" call, etc., are not pure type +/// conversions.) +static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__"; + +/// Return the operation that defines all values in the vector. Return nullptr +/// if the values are not defined by the same operation. 
+static Operation *getCommonDefiningOp(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = values.front().getDefiningOp(); + for (Value v : llvm::drop_begin(values)) { + if (v.getDefiningOp() != op) + return nullptr; + } + return op; +} - // Otherwise: Check if there is a mapping for the entire vector. Such - // mappings are materializations. (N:M mapping are not supported for value - // replacements.) - // - // Note: From a correctness point of view, materializations do not have to - // be stored (and looked up) in the mapping. But for performance reasons, - // we choose to reuse existing IR (when possible) instead of creating it - // multiple times. - auto it = mapping.find(current); - if (it == mapping.end()) { - // No mapping found: The lookup stops here. - break; - } - current = it->second; - } while (true); +/// A vector of values is a pure type conversion if all values are defined by +/// the same operation and the operation has the `kPureTypeConversionMarker` +/// attribute. +static bool isPureTypeConversion(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = getCommonDefiningOp(values); + return op && op->hasAttr(kPureTypeConversionMarker); +} - // If the desired values were found use them, otherwise default to the leaf - // values. - // Note: If `desiredTypes` is empty, this function always returns `current`. - return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); +ValueVector ConversionValueMapping::lookup(const ValueVector &from) const { + auto it = mapping.find(from); + if (it == mapping.end()) { + // No mapping found: The lookup stops here. 
+ return {}; + } + return it->second; } //===----------------------------------------------------------------------===// @@ -871,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) { namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { - explicit ConversionPatternRewriterImpl(MLIRContext *ctx, + explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config) - : context(ctx), config(config) {} + : rewriter(rewriter), config(config), + notifyingRewriter(rewriter.getContext(), config.listener) {} //===--------------------------------------------------------------------===// // State Management @@ -895,6 +873,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// failure. template <typename RewriteTy, typename... Args> void appendRewrite(Args &&...args) { + assert(config.allowPatternRollback && "appending rewrites is not allowed"); rewrites.push_back( std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); } @@ -909,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// is the tag used when describing a value within a diagnostic, e.g. /// "operand". LogicalResult remapValues(StringRef valueDiagTag, - std::optional<Location> inputLoc, - PatternRewriter &rewriter, ValueRange values, + std::optional<Location> inputLoc, ValueRange values, SmallVector<ValueVector> &remapped); /// Return "true" if the given operation is ignored, and does not need to be @@ -921,16 +899,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasOpReplaced(Operation *op) const; /// Lookup the most recently mapped values with the desired types in the - /// mapping. + /// mapping, taking into account only replacements. Perform a best-effort + /// search for existing materializations with the desired types. 
/// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. - ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + /// If `skipPureTypeConversions` is "true", materializations that are pure + /// type conversions are not considered. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}, + bool skipPureTypeConversions = false) const; /// Lookup the given value within the map, or return an empty vector if the /// value is not mapped. If it is mapped, this follows the same behavior @@ -943,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Convert the types of block arguments within the given region. FailureOr<Block *> - convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, + convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); /// Apply the given signature conversion on the given block. The new block @@ -954,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// translate between the origin argument types and those specified in the /// signature conversion. 
Block *applySignatureConversion( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion); /// Replace the results of the given operation with the given values and @@ -993,11 +966,18 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// If `valuesToMap` is set to a non-null Value, then that value is mapped to /// the results of the unresolved materialization in the conversion value /// mapping. + /// + /// If `isPureTypeConversion` is "true", the materialization is created only + /// to resolve a type mismatch. That means it is not a regular value + /// replacement issued by the user. (Replacement values that are created + /// "out of thin air" appear like unresolved materializations because they are + /// unrealized_conversion_cast ops. However, they must be treated like + /// regular value replacements.) ValueRange buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp = nullptr); + bool isPureTypeConversion = true); /// Find a replacement value for the given SSA value in the conversion value /// mapping. The replacement value must have the same type as the given SSA @@ -1078,14 +1058,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // State //===--------------------------------------------------------------------===// - /// MLIR context. - MLIRContext *context; + /// The rewriter that is used to perform the conversion. + ConversionPatternRewriter &rewriter; // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. ConversionValueMapping mapping; /// Ordered list of block operations (creations, splits, motions). 
+ /// This vector is maintained only if `allowPatternRollback` is set to + /// "true". Otherwise, all IR rewrites are materialized immediately and no + /// bookkeeping is needed. SmallVector<std::unique_ptr<IRRewrite>> rewrites; /// A set of operations that should no longer be considered for legalization. @@ -1109,6 +1092,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// by the current pattern. SetVector<Block *> patternInsertedBlocks; + /// A list of unresolved materializations that were created by the current + /// pattern. + DenseSet<UnrealizedConversionCastOp> patternMaterializations; + /// A mapping for looking up metadata of unresolved materializations. DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo> unresolvedMaterializations; @@ -1124,15 +1111,37 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// A set of erased operations. This set is utilized only if + /// `allowPatternRollback` is set to "false". Conceptually, this set is + /// similar to `replacedOps` (which is maintained when the flag is set to + /// "true"). However, erasing from a DenseSet is more efficient than erasing + /// from a SetVector. + DenseSet<Operation *> erasedOps; + + /// A set of erased blocks. This set is utilized only if + /// `allowPatternRollback` is set to "false". + DenseSet<Block *> erasedBlocks; + + /// A rewriter that notifies the listener (if any) about all IR + /// modifications. This rewriter is utilized only if `allowPatternRollback` + /// is set to "false". If the flag is set to "true", the listener is notified + /// with a separate mechanism (e.g., in `IRRewrite::commit`). + IRRewriter notifyingRewriter; + #ifndef NDEBUG + /// A set of replaced block arguments. This set is for debugging purposes + /// only and it is maintained only if `allowPatternRollback` is set to + /// "true". 
+ DenseSet<BlockArgument> replacedArgs; + /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra /// verification. SmallPtrSet<Operation *, 1> pendingRootUpdates; /// A raw output stream used to prefix the debug log. - llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(), - llvm::dbgs(), /*HasPendingNewline=*/false}; + llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(), + llvm::dbgs()}; /// A logger used to emit diagnostics during the conversion process. llvm::ScopedPrinter logger{os}; @@ -1160,11 +1169,8 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); - if (!repl) - return; - +static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, + Value repl) { if (isa<BlockArgument>(repl)) { rewriter.replaceAllUsesWith(arg, repl); return; @@ -1181,6 +1187,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + if (!repl) + return; + performReplaceBlockArg(rewriter, arg, repl); +} + void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { @@ -1243,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() { } void ConversionPatternRewriterImpl::applyRewrites() { - // Commit all rewrites. - IRRewriter rewriter(context, config.listener); + // Commit all rewrites. Use a new rewriter, so the modifications are not + // tracked for rollback purposes etc. 
+ IRRewriter irRewriter(rewriter.getContext(), config.listener); // Note: New rewrites may be added during the "commit" phase and the // `rewrites` vector may reallocate. for (size_t i = 0; i < rewrites.size(); ++i) - rewrites[i]->commit(rewriter); + rewrites[i]->commit(irRewriter); // Clean up all rewrites. SingleEraseRewriter eraseRewriter( - context, /*opErasedCallback=*/[&](Operation *op) { + rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) { if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) unresolvedMaterializations.erase(castOp); }); @@ -1264,10 +1278,101 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management //===----------------------------------------------------------------------===// -ValueVector -ConversionPatternRewriterImpl::lookupOrDefault(Value from, - TypeRange desiredTypes) const { - return mapping.lookupOrDefault(from, desiredTypes); +ValueVector ConversionPatternRewriterImpl::lookupOrDefault( + Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const { + // Helper function that looks up a single value. + auto lookup = [&](const ValueVector &values) -> ValueVector { + assert(!values.empty() && "expected non-empty value vector"); + + // If the pattern rollback is enabled, use the mapping to look up the + // values. + if (config.allowPatternRollback) + return mapping.lookup(values); + + // Otherwise, look up values by examining the IR. All replacements have + // already been materialized in IR. + Operation *op = getCommonDefiningOp(values); + if (!op) + return {}; + auto castOp = dyn_cast<UnrealizedConversionCastOp>(op); + if (!castOp) + return {}; + if (!this->unresolvedMaterializations.contains(castOp)) + return {}; + if (castOp.getOutputs() != values) + return {}; + return castOp.getInputs(); + }; + + // Helper function that looks up each value in `values` individually and then + // composes the results. If that fails, it tries to look up the entire vector + // at once. 
+ auto composedLookup = [&](const ValueVector &values) -> ValueVector { + // If possible, replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : values) { + ValueVector r = lookup({v}); + if (!r.empty()) { + llvm::append_range(next, r); + } else { + next.push_back(v); + } + } + if (next != values) { + // At least one value was replaced. + return next; + } + + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. + ValueVector r = lookup(values); + if (r.empty()) { + // No mapping found: The lookup stops here. + return {}; + } + return r; + }; + + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; + ValueVector current{from}; + ValueVector lastNonMaterialization{from}; + do { + // Store the current value if the types match. + bool match = TypeRange(ValueRange(current)) == desiredTypes; + if (skipPureTypeConversions) { + // Skip pure type conversions, if requested. + bool pureConversion = isPureTypeConversion(current); + match &= !pureConversion; + // Keep track of the last mapped value that was not a pure type + // conversion. + if (!pureConversion) + lastNonMaterialization = current; + } + if (match) + desiredValue = current; + + // Lookup next value in the mapping. + ValueVector next = composedLookup(current); + if (next.empty()) + break; + current = std::move(next); + } while (true); + + // If the desired values were found use them, otherwise default to the leaf + // values. (Skip pure type conversions, if requested.) 
+ if (!desiredTypes.empty()) + return desiredValue; + if (skipPureTypeConversions) + return lastNonMaterialization; + return current; } ValueVector @@ -1300,21 +1405,13 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state, void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, StringRef patternName) { for (auto &rewrite : - llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) { - if (!config.allowPatternRollback && - !isa<UnresolvedMaterializationRewrite>(rewrite)) { - // Unresolved materializations can always be rolled back (erased). - llvm::report_fatal_error("pattern '" + patternName + - "' rollback of IR modifications requested"); - } + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) rewrite->rollback(); - } rewrites.resize(numRewritesToKeep); } LogicalResult ConversionPatternRewriterImpl::remapValues( - StringRef valueDiagTag, std::optional<Location> inputLoc, - PatternRewriter &rewriter, ValueRange values, + StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values, SmallVector<ValueVector> &remapped) { remapped.reserve(llvm::size(values)); @@ -1324,16 +1421,19 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); if (!currentTypeConverter) { - // The current pattern does not have a type converter. I.e., it does not - // distinguish between legal and illegal types. For each operand, simply - // pass through the most recently mapped values. - remapped.push_back(lookupOrDefault(operand)); + // The current pattern does not have a type converter. Pass the most + // recently mapped values, excluding materializations. Materializations + // are intentionally excluded because their presence may depend on other + // patterns. Including materializations would make the lookup fragile + // and unpredictable. 
+ remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{}, + /*skipPureTypeConversions=*/true)); continue; } // If there is no legal conversion, fail to match this pattern. SmallVector<Type, 1> legalTypes; - if (failed(currentTypeConverter->convertType(origType, legalTypes))) { + if (failed(currentTypeConverter->convertType(operand, legalTypes))) { notifyMatchFailure(operandLoc, [=](Diagnostic &diag) { diag << "unable to convert type for " << valueDiagTag << " #" << it.index() << ", type was " << origType; @@ -1356,7 +1456,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = lookupOrDefault(operand); + repl = lookupOrDefault(operand, /*desiredTypes=*/{}, + /*skipPureTypeConversions=*/true); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1368,12 +1469,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation is ignored or was replaced. - return replacedOps.count(op) || ignoredOps.count(op); + return wasOpReplaced(op) || ignoredOps.count(op); } bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { // Check to see if this operation was replaced. 
- return replacedOps.count(op); + return replacedOps.count(op) || erasedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1381,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { //===----------------------------------------------------------------------===// FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, + Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { regionToConverter[region] = &converter; if (region->empty()) @@ -1397,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( if (!conversion) return failure(); // Convert the block with the computed signature. - applySignatureConversion(rewriter, &block, &converter, *conversion); + applySignatureConversion(&block, &converter, *conversion); } // Convert the entry block. If an entry signature conversion was provided, // use that one. Otherwise, compute the signature with the type converter. if (entryConversion) - return applySignatureConversion(rewriter, ®ion->front(), &converter, + return applySignatureConversion(®ion->front(), &converter, *entryConversion); std::optional<TypeConverter::SignatureConversion> conversion = converter.convertBlockSignature(®ion->front()); if (!conversion) return failure(); - return applySignatureConversion(rewriter, ®ion->front(), &converter, - *conversion); + return applySignatureConversion(®ion->front(), &converter, *conversion); } Block *ConversionPatternRewriterImpl::applySignatureConversion( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion) { #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // A block cannot be converted multiple times. 
@@ -1457,7 +1555,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // a bit more efficient, so we try to do that when possible. bool fastPath = !config.listener; if (fastPath) { - appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); + if (config.allowPatternRollback) + appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); newBlock->getOperations().splice(newBlock->end(), block->getOperations()); } else { while (!block->empty()) @@ -1482,7 +1581,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), /*valuesToMap=*/{}, /*inputs=*/ValueRange(), - /*outputTypes=*/origArgType, /*originalType=*/Type(), converter) + /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, + /*isPureTypeConversion=*/false) .front(); replaceUsesOfBlockArgument(origArg, mat, converter); continue; @@ -1504,7 +1604,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( replaceUsesOfBlockArgument(origArg, replArgs, converter); } - appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); + if (config.allowPatternRollback) + appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) 
@@ -1523,7 +1624,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp) { + bool isPureTypeConversion) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); assert(TypeRange(inputs) != outputTypes && @@ -1533,21 +1634,35 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // tracking the materialization like we do for other operations. OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); - auto convertOp = + UnrealizedConversionCastOp convertOp = UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); - if (!valuesToMap.empty()) - mapping.map(valuesToMap, convertOp.getResults()); - if (castOp) - *castOp = convertOp; + if (config.attachDebugMaterializationKind) { + StringRef kindStr = + kind == MaterializationKind::Source ? "source" : "target"; + convertOp->setAttr("__kind__", builder.getStringAttr(kindStr)); + } + if (isPureTypeConversion) + convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); + + // Register the materialization. 
unresolvedMaterializations[convertOp] = UnresolvedMaterializationInfo(converter, kind, originalType); - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, - std::move(valuesToMap)); + if (config.allowPatternRollback) { + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); + appendRewrite<UnresolvedMaterializationRewrite>(convertOp, + std::move(valuesToMap)); + } else { + patternMaterializations.insert(convertOp); + } return convertOp.getResults(); } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { + assert(config.allowPatternRollback && + "this code path is valid only in rollback mode"); + // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. @@ -1609,26 +1724,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(op->getParentOp()) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) && "attempting to insert into a block within a replaced/erased op"); + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyOperationInserted(op, previous); + if (wasDetached) { - // If the op was detached, it is most likely a newly created op. - // TODO: If the same op is inserted multiple times from a detached state, - // the rollback mechanism may erase the same op multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateOperationRewrite>(op); + // If the op was detached, it is most likely a newly created op. Add it the + // set of newly created ops, so that it will be legalized. 
If this op is + // not a newly created op, it will be legalized a second time, which is + // inefficient but harmless. patternNewOps.insert(op); + + if (config.allowPatternRollback) { + // TODO: If the same op is inserted multiple times from a detached + // state, the rollback mechanism may erase the same op multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateOperationRewrite>(op); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased operations that must be kept up to date. + erasedOps.erase(op); + } return; } // The op was moved from one place to another. - appendRewrite<MoveOperationRewrite>(op, previous); + if (config.allowPatternRollback) + appendRewrite<MoveOperationRewrite>(op, previous); +} + +/// Given that `fromRange` is about to be replaced with `toRange`, compute +/// replacement values with the types of `fromRange`. +static SmallVector<Value> +getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, + const SmallVector<SmallVector<Value>> &toRange, + const TypeConverter *converter) { + assert(!impl.config.allowPatternRollback && + "this code path is valid only in 'no rollback' mode"); + SmallVector<Value> repls; + for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) { + if (from.use_empty()) { + // The replaced value is dead. No replacement value is needed. + repls.push_back(Value()); + continue; + } + + if (to.empty()) { + // The replaced value is dropped. Materialize a replacement value "out of + // thin air". + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(from), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/ValueRange(), + /*outputTypes=*/from.getType(), /*originalType=*/Type(), + converter)[0]; + repls.push_back(srcMat); + continue; + } + + if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) { + // The replacement value already has the correct type. 
Use it directly. + repls.push_back(to[0]); + continue; + } + + // The replacement value has the wrong type. Build a source materialization + // to the original type. + // TODO: This is a bit inefficient. We should try to reuse existing + // materializations if possible. This would require an extension of the + // `lookupOrDefault` API. + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(to), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(), + /*originalType=*/Type(), converter)[0]; + repls.push_back(srcMat); + } + + return repls; } void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { - assert(newValues.size() == op->getNumResults()); + assert(newValues.size() == op->getNumResults() && + "incorrect number of replacement values"); + + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + SmallVector<Value> repls = getReplacementValues( + *this, op->getResults(), newValues, currentTypeConverter); + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + op->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + op->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Replace the op with the replacement values and notify the listener. 
+ notifyingRewriter.replaceOp(op, repls); + return; + } + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Check if replaced op is an unresolved materialization, i.e., an @@ -1650,7 +1858,7 @@ void ConversionPatternRewriterImpl::replaceOp( MaterializationKind::Source, computeInsertPoint(result), result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputTypes=*/result.getType(), /*originalType=*/Type(), - currentTypeConverter); + currentTypeConverter, /*isPureTypeConversion=*/false); continue; } @@ -1667,11 +1875,59 @@ void ConversionPatternRewriterImpl::replaceOp( void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( BlockArgument from, ValueRange to, const TypeConverter *converter) { + if (!config.allowPatternRollback) { + SmallVector<Value> toConv = llvm::to_vector(to); + SmallVector<Value> repls = + getReplacementValues(*this, from, {toConv}, converter); + IRRewriter r(from.getContext()); + Value repl = repls.front(); + if (!repl) + return; + + performReplaceBlockArg(r, from, repl); + return; + } + +#ifndef NDEBUG + // Make sure that a block argument is not replaced multiple times. In + // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current + // uses of the given block argument, but also all future uses that may be + // introduced by future pattern applications. Therefore, it does not make + // sense to call `replaceUsesOfBlockArgument` multiple times with the same + // block argument. Doing so would overwrite the mapping and mess with the + // internal state of the dialect conversion driver. 
+ assert(!replacedArgs.contains(from) && + "attempting to replace a block argument that was already replaced"); + replacedArgs.insert(from); +#endif // NDEBUG + appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter); mapping.map(from, to); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + block->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + block->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Erase the block and notify the listener. + notifyingRewriter.eraseBlock(block); + return; + } + assert(!wasOpReplaced(block->getParentOp()) && "attempting to erase a block within a replaced/erased op"); appendRewrite<EraseBlockRewrite>(block); @@ -1705,23 +1961,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(newParentOp) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) && "attempting to insert into a region within a replaced/erased op"); (void)newParentOp; + // In "no rollback" mode, the listener is always notified immediately. 
+ if (!config.allowPatternRollback && config.listener) + config.listener->notifyBlockInserted(block, previous, previousIt); + patternInsertedBlocks.insert(block); if (wasDetached) { // If the block was detached, it is most likely a newly created block. - // TODO: If the same block is inserted multiple times from a detached state, - // the rollback mechanism may erase the same block multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateBlockRewrite>(block); + if (config.allowPatternRollback) { + // TODO: If the same block is inserted multiple times from a detached + // state, the rollback mechanism may erase the same block multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateBlockRewrite>(block); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased blocks that must be kept up to date. + erasedBlocks.erase(block); + } return; } // The block was moved from one place to another. 
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt); + if (config.allowPatternRollback) + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1748,12 +2018,16 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( ConversionPatternRewriter::ConversionPatternRewriter( MLIRContext *ctx, const ConversionConfig &config) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(ctx, config)) { + impl(new detail::ConversionPatternRewriterImpl(*this, config)) { setListener(impl.get()); } ConversionPatternRewriter::~ConversionPatternRewriter() = default; +const ConversionConfig &ConversionPatternRewriter::getConfig() const { + return impl->config; +} + void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { assert(op && newOp && "expected non-null op"); replaceOp(op, newOp->getResults()); @@ -1821,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion( assert(!impl->wasOpReplaced(block->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->applySignatureConversion(*this, block, converter, conversion); + return impl->applySignatureConversion(block, converter, conversion); } FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( @@ -1830,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( assert(!impl->wasOpReplaced(region->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->convertRegionTypes(*this, region, converter, entryConversion); + return impl->convertRegionTypes(region, converter, entryConversion); } void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, @@ -1849,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value ConversionPatternRewriter::getRemappedValue(Value key) { 
SmallVector<ValueVector> remappedValues; - if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key, remappedValues))) return nullptr; assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); @@ -1862,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, if (keys.empty()) return success(); SmallVector<ValueVector> remapped; - if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys, remapped))) return failure(); for (const auto &values : remapped) { @@ -1895,9 +2169,9 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // ops should be moved one-by-one ("slow path"), so that a separate // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is // a bit more efficient, so we try to do that when possible. - bool fastPath = !impl->config.listener; + bool fastPath = !getConfig().listener; - if (fastPath) + if (fastPath && impl->config.allowPatternRollback) impl->inlineBlockBefore(source, dest, before); // Replace all uses of block arguments. @@ -1923,6 +2197,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + // Pattern rollback is not allowed: no extra bookkeeping is needed. 
+ PatternRewriter::startOpModification(op); + return; + } assert(!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"); #ifndef NDEBUG @@ -1932,20 +2211,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); - PatternRewriter::finalizeOpModification(op); impl->patternModifiedOps.insert(op); + if (!impl->config.allowPatternRollback) { + PatternRewriter::finalizeOpModification(op); + if (getConfig().listener) + getConfig().listener->notifyOperationModified(op); + return; + } // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif } void ConversionPatternRewriter::cancelOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + PatternRewriter::cancelOpModification(op); + return; + } #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); @@ -1970,17 +2258,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// -SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( +FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands( ArrayRef<ValueRange> operands) const { SmallVector<Value> oneToOneOperands; oneToOneOperands.reserve(operands.size()); for (ValueRange operand : operands) { if (operand.size() != 1) - llvm::report_fatal_error("pattern '" + getDebugName() + - "' does not support 1:N conversion"); + return failure(); + oneToOneOperands.push_back(operand.front()); } - return oneToOneOperands; + 
return std::move(oneToOneOperands); } LogicalResult @@ -1995,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op, // Remap the operands of the operation. SmallVector<ValueVector> remapped; - if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, + if (failed(rewriterImpl.remapValues("operand", op->getLoc(), op->getOperands(), remapped))) { return failure(); } @@ -2017,38 +2305,34 @@ class OperationLegalizer { public: using LegalizationAction = ConversionTarget::LegalizationAction; - OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns, - const ConversionConfig &config); + OperationLegalizer(ConversionPatternRewriter &rewriter, + const ConversionTarget &targetInfo, + const FrozenRewritePatternSet &patterns); /// Returns true if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; /// Attempt to legalize the given operation. Returns success if the operation /// was legalized, failure otherwise. - LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); + LogicalResult legalize(Operation *op); /// Returns the conversion target in use by the legalizer. const ConversionTarget &getTarget() { return target; } private: /// Attempt to legalize the given operation by folding it. - LogicalResult legalizeWithFold(Operation *op, - ConversionPatternRewriter &rewriter); + LogicalResult legalizeWithFold(Operation *op); /// Attempt to legalize the given operation by applying a pattern. Returns /// success if the operation was legalized, failure otherwise. - LogicalResult legalizeWithPattern(Operation *op, - ConversionPatternRewriter &rewriter); + LogicalResult legalizeWithPattern(Operation *op); /// Return true if the given pattern may be applied to the given operation, /// false otherwise. 
- bool canApplyPattern(Operation *op, const Pattern &pattern, - ConversionPatternRewriter &rewriter); + bool canApplyPattern(Operation *op, const Pattern &pattern); /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, - ConversionPatternRewriter &rewriter, const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, @@ -2057,18 +2341,12 @@ private: /// Legalizes the actions registered during the execution of a pattern. LogicalResult legalizePatternBlockRewrites(Operation *op, - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, const SetVector<Block *> &insertedBlocks, const SetVector<Operation *> &newOps); LogicalResult - legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - const SetVector<Operation *> &newOps); + legalizePatternCreatedOperations(const SetVector<Operation *> &newOps); LogicalResult - legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - const SetVector<Operation *> &modifiedOps); + legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps); //===--------------------------------------------------------------------===// // Cost Model @@ -2111,21 +2389,21 @@ private: /// The current set of patterns that have been applied. SmallPtrSet<const Pattern *, 8> appliedPatterns; + /// The rewriter to use when converting operations. + ConversionPatternRewriter &rewriter; + /// The legalization information provided by the target. const ConversionTarget ⌖ /// The pattern applicator to use for conversions. PatternApplicator applicator; - - /// Dialect conversion configuration. 
- const ConversionConfig &config; }; } // namespace -OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, - const FrozenRewritePatternSet &patterns, - const ConversionConfig &config) - : target(targetInfo), applicator(patterns), config(config) { +OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter, + const ConversionTarget &targetInfo, + const FrozenRewritePatternSet &patterns) + : rewriter(rewriter), target(targetInfo), applicator(patterns) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones. DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; @@ -2139,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const { return target.isIllegal(op); } -LogicalResult -OperationLegalizer::legalize(Operation *op, - ConversionPatternRewriter &rewriter) { +LogicalResult OperationLegalizer::legalize(Operation *op) { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -2203,19 +2479,21 @@ OperationLegalizer::legalize(Operation *op, return success(); } - // If the operation isn't legal, try to fold it in-place. - // TODO: Should we always try to do this, even if the op is - // already legal? - if (succeeded(legalizeWithFold(op, rewriter))) { - LLVM_DEBUG({ - logSuccess(logger, "operation was folded"); - logger.startLine() << logLineComment; - }); - return success(); + // If the operation is not legal, try to fold it in-place if the folding mode + // is 'BeforePatterns'. 'Never' will skip this. + const ConversionConfig &config = rewriter.getConfig(); + if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) { + if (succeeded(legalizeWithFold(op))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } } // Otherwise, we need to apply a legalization pattern to this operation. 
- if (succeeded(legalizeWithPattern(op, rewriter))) { + if (succeeded(legalizeWithPattern(op))) { LLVM_DEBUG({ logSuccess(logger, ""); logger.startLine() << logLineComment; @@ -2223,6 +2501,18 @@ OperationLegalizer::legalize(Operation *op, return success(); } + // If the operation can't be legalized via patterns, try to fold it in-place + // if the folding mode is 'AfterPatterns'. + if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) { + if (succeeded(legalizeWithFold(op))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } + } + LLVM_DEBUG({ logFailure(logger, "no matched legalization pattern"); logger.startLine() << logLineComment; @@ -2239,9 +2529,7 @@ static T moveAndReset(T &obj) { return result; } -LogicalResult -OperationLegalizer::legalizeWithFold(Operation *op, - ConversionPatternRewriter &rewriter) { +LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) { auto &rewriterImpl = rewriter.getImpl(); LLVM_DEBUG({ rewriterImpl.logger.startLine() << "* Fold {\n"; @@ -2275,18 +2563,18 @@ OperationLegalizer::legalizeWithFold(Operation *op, // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) - return legalize(op, rewriter); + return legalize(op); // Insert a replacement for 'op' with the folded replacement values. rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. for (Operation *newOp : newOps) { - if (failed(legalize(newOp, rewriter))) { + if (failed(legalize(newOp))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to legalize generated constant '{0}'", newOp->getName())); - if (!config.allowPatternRollback) { + if (!rewriter.getConfig().allowPatternRollback) { // Rolling back a folder is like rolling back a pattern. 
llvm::report_fatal_error( "op '" + opName + @@ -2302,10 +2590,34 @@ OperationLegalizer::legalizeWithFold(Operation *op, return success(); } -LogicalResult -OperationLegalizer::legalizeWithPattern(Operation *op, - ConversionPatternRewriter &rewriter) { +/// Report a fatal error indicating that newly produced or modified IR could +/// not be legalized. +static void +reportNewIrLegalizationFatalError(const Pattern &pattern, + const SetVector<Operation *> &newOps, + const SetVector<Operation *> &modifiedOps, + const SetVector<Block *> &insertedBlocks) { + auto newOpNames = llvm::map_range( + newOps, [](Operation *op) { return op->getName().getStringRef(); }); + auto modifiedOpNames = llvm::map_range( + modifiedOps, [](Operation *op) { return op->getName().getStringRef(); }); + StringRef detachedBlockStr = "(detached block)"; + auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) { + if (block->getParentOp()) + return block->getParentOp()->getName().getStringRef(); + return detachedBlockStr; + }); + llvm::report_fatal_error( + "pattern '" + pattern.getDebugName() + + "' produced IR that could not be legalized. " + "new ops: {" + + llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" + + llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" + + llvm::join(insertedBlockNames, ", ") + "}"); +} + +LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { auto &rewriterImpl = rewriter.getImpl(); + const ConversionConfig &config = rewriter.getConfig(); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS Operation *checkOp; @@ -2335,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that returns if the given pattern may be applied. 
auto canApply = [&](const Pattern &pattern) { - bool canApply = canApplyPattern(op, pattern, rewriter); + bool canApply = canApplyPattern(op, pattern); if (canApply && config.listener) config.listener->notifyPatternBegin(pattern, op); return canApply; @@ -2345,17 +2657,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (!rewriterImpl.config.allowPatternRollback) { - // Returning "failure" after modifying IR is not allowed. + // Erase all unresolved materializations. + for (auto op : rewriterImpl.patternMaterializations) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + rewriterImpl.patternMaterializations.clear(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Expensive pattern check that can detect API violations. if (checkOp) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + "' returned failure but IR did change"); } - } #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2379,18 +2697,28 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // successfully applied. auto onSuccess = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + if (!rewriterImpl.config.allowPatternRollback) { + // Eagerly erase unused materializations. 
+ for (auto op : rewriterImpl.patternMaterializations) { + if (op->use_empty()) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + } + rewriterImpl.patternMaterializations.clear(); + } SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector<Operation *> modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); SetVector<Block *> insertedBlocks = moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps, + auto result = legalizePatternResult(op, pattern, curState, newOps, modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { if (!rewriterImpl.config.allowPatternRollback) - llvm::report_fatal_error("pattern '" + pattern.getDebugName() + - "' produced IR that could not be legalized"); + reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps, + insertedBlocks); rewriterImpl.resetState(curState, pattern.getDebugName()); } if (config.listener) @@ -2403,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, onSuccess); } -bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, - ConversionPatternRewriter &rewriter) { +bool OperationLegalizer::canApplyPattern(Operation *op, + const Pattern &pattern) { LLVM_DEBUG({ auto &os = rewriter.getImpl().logger; os.getOStream() << "\n"; @@ -2426,11 +2754,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, } LogicalResult OperationLegalizer::legalizePatternResult( - Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - const RewriterState &curState, const SetVector<Operation *> &newOps, + Operation *op, const Pattern &pattern, const RewriterState &curState, + const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks) { - auto &impl = rewriter.getImpl(); + [[maybe_unused]] auto &impl = rewriter.getImpl(); 
assert(impl.pendingRootUpdates.empty() && "dangling root updates"); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS @@ -2448,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult( #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. - if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks, - newOps)) || - failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) || - failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) { + if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) || + failed(legalizePatternRootUpdates(modifiedOps)) || + failed(legalizePatternCreatedOperations(newOps))) { return failure(); } @@ -2460,15 +2787,17 @@ LogicalResult OperationLegalizer::legalizePatternResult( } LogicalResult OperationLegalizer::legalizePatternBlockRewrites( - Operation *op, ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - const SetVector<Block *> &insertedBlocks, + Operation *op, const SetVector<Block *> &insertedBlocks, const SetVector<Operation *> &newOps) { + ConversionPatternRewriterImpl &impl = rewriter.getImpl(); SmallPtrSet<Operation *, 16> alreadyLegalized; // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. for (Block *block : insertedBlocks) { + if (impl.erasedBlocks.contains(block)) + continue; + // Only check blocks outside of the current operation. 
Operation *parentOp = block->getParentOp(); if (!parentOp || parentOp == op || block->getNumArguments() == 0) @@ -2484,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( "block")); return failure(); } - impl.applySignatureConversion(rewriter, block, converter, *conversion); + impl.applySignatureConversion(block, converter, *conversion); continue; } @@ -2493,7 +2822,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // operation, and blocks in regions created by this pattern will already be // legalized later on. if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) { - if (failed(legalize(parentOp, rewriter))) { + if (failed(legalize(parentOp))) { LLVM_DEBUG(logFailure( impl.logger, "operation '{0}'({1}) became illegal after rewrite", parentOp->getName(), parentOp)); @@ -2505,11 +2834,10 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( } LogicalResult OperationLegalizer::legalizePatternCreatedOperations( - ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, const SetVector<Operation *> &newOps) { for (Operation *op : newOps) { - if (failed(legalize(op, rewriter))) { - LLVM_DEBUG(logFailure(impl.logger, + if (failed(legalize(op))) { + LLVM_DEBUG(logFailure(rewriter.getImpl().logger, "failed to legalize generated operation '{0}'({1})", op->getName(), op)); return failure(); @@ -2519,13 +2847,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations( } LogicalResult OperationLegalizer::legalizePatternRootUpdates( - ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, const SetVector<Operation *> &modifiedOps) { for (Operation *op : modifiedOps) { - if (failed(legalize(op, rewriter))) { - LLVM_DEBUG(logFailure( - impl.logger, "failed to legalize operation updated in-place '{0}'", - op->getName())); + if (failed(legalize(op))) { + LLVM_DEBUG( + logFailure(rewriter.getImpl().logger, + "failed to legalize operation updated in-place 
'{0}'", + op->getName())); return failure(); } } @@ -2745,11 +3073,11 @@ namespace mlir { // rewrite patterns. The conversion behaves differently depending on the // conversion mode. struct OperationConverter { - explicit OperationConverter(const ConversionTarget &target, + explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : config(config), opLegalizer(target, patterns, this->config), + : rewriter(ctx, config), opLegalizer(rewriter, target, patterns), mode(mode) {} /// Converts the given operations to the conversion target. @@ -2757,10 +3085,10 @@ struct OperationConverter { private: /// Converts an operation with the given rewriter. - LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); + LogicalResult convert(Operation *op); - /// Dialect conversion configuration. - ConversionConfig config; + /// The rewriter to use when converting operations. + ConversionPatternRewriter rewriter; /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -2770,10 +3098,11 @@ private: }; } // namespace mlir -LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, - Operation *op) { +LogicalResult OperationConverter::convert(Operation *op) { + const ConversionConfig &config = rewriter.getConfig(); + // Legalize the given operation. - if (failed(opLegalizer.legalize(op, rewriter))) { + if (failed(opLegalizer.legalize(op))) { // Handle the case of a failed conversion for each of the different modes. // Full conversions expect all operations to be converted. 
if (mode == OpConversionMode::Full) @@ -2849,7 +3178,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, } LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { - assert(!ops.empty() && "expected at least one operation"); const ConversionTarget &target = opLegalizer.getTarget(); // Compute the set of operations and blocks to convert. @@ -2868,11 +3196,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { } // Convert each operation and discard rewrites on failure. - ConversionPatternRewriter rewriter(ops.front()->getContext(), config); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); for (auto *op : toConvert) { - if (failed(convert(rewriter, op))) { + if (failed(convert(op))) { // Dialect conversion failed. if (rewriterImpl.config.allowPatternRollback) { // Rollback is allowed: restore the original IR. @@ -2902,14 +3229,21 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { SmallVector<UnrealizedConversionCastOp> remainingCastOps; reconcileUnrealizedCasts(allCastOps, &remainingCastOps); + // Drop markers. + for (UnrealizedConversionCastOp castOp : remainingCastOps) + castOp->removeAttr(kPureTypeConversionMarker); + // Try to legalize all unresolved materializations. - if (config.buildMaterializations) { - IRRewriter rewriter(rewriterImpl.context, config.listener); + if (rewriter.getConfig().buildMaterializations) { + // Use a new rewriter, so the modifications are not tracked for rollback + // purposes etc. 
+ IRRewriter irRewriter(rewriterImpl.rewriter.getContext(), + rewriter.getConfig().listener); for (UnrealizedConversionCastOp castOp : remainingCastOps) { auto it = materializations.find(castOp); assert(it != materializations.end() && "inconsistent state"); - if (failed( - legalizeUnresolvedMaterialization(rewriter, castOp, it->second))) + if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp, + it->second))) return failure(); } } @@ -3076,6 +3410,27 @@ LogicalResult TypeConverter::convertType(Type t, return failure(); } +LogicalResult TypeConverter::convertType(Value v, + SmallVectorImpl<Type> &results) const { + assert(v && "expected non-null value"); + + // If this type converter does not have context-aware type conversions, call + // the type-based overload, which has caching. + if (!hasContextAwareTypeConversions) + return convertType(v.getType(), results); + + // Walk the added converters in reverse order to apply the most recently + // registered first. + for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { + if (std::optional<LogicalResult> result = converter(v, results)) { + if (!succeeded(*result)) + return failure(); + return success(); + } + } + return failure(); +} + Type TypeConverter::convertType(Type t) const { // Use the multi-type result version to convert the type. SmallVector<Type, 1> results; @@ -3086,6 +3441,16 @@ Type TypeConverter::convertType(Type t) const { return results.size() == 1 ? results.front() : nullptr; } +Type TypeConverter::convertType(Value v) const { + // Use the multi-type result version to convert the type. + SmallVector<Type, 1> results; + if (failed(convertType(v, results))) + return nullptr; + + // Check to ensure that only one type was produced. + return results.size() == 1 ? 
results.front() : nullptr; +} + LogicalResult TypeConverter::convertTypes(TypeRange types, SmallVectorImpl<Type> &results) const { @@ -3095,21 +3460,38 @@ TypeConverter::convertTypes(TypeRange types, return success(); } +LogicalResult +TypeConverter::convertTypes(ValueRange values, + SmallVectorImpl<Type> &results) const { + for (Value value : values) + if (failed(convertType(value, results))) + return failure(); + return success(); +} + bool TypeConverter::isLegal(Type type) const { return convertType(type) == type; } + +bool TypeConverter::isLegal(Value value) const { + return convertType(value) == value.getType(); +} + bool TypeConverter::isLegal(Operation *op) const { - return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); + return isLegal(op->getOperands()) && isLegal(op->getResults()); } bool TypeConverter::isLegal(Region *region) const { - return llvm::all_of(*region, [this](Block &block) { - return isLegal(block.getArgumentTypes()); - }); + return llvm::all_of( + *region, [this](Block &block) { return isLegal(block.getArguments()); }); } bool TypeConverter::isSignatureLegal(FunctionType ty) const { - return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); + if (!isLegal(ty.getInputs())) + return false; + if (!isLegal(ty.getResults())) + return false; + return true; } LogicalResult @@ -3137,6 +3519,31 @@ TypeConverter::convertSignatureArgs(TypeRange types, return failure(); return success(); } +LogicalResult +TypeConverter::convertSignatureArg(unsigned inputNo, Value value, + SignatureConversion &result) const { + // Try to convert the given input type. + SmallVector<Type, 1> convertedTypes; + if (failed(convertType(value, convertedTypes))) + return failure(); + + // If this argument is being dropped, there is nothing left to do. + if (convertedTypes.empty()) + return success(); + + // Otherwise, add the new inputs. 
+ result.addInputs(inputNo, convertedTypes); + return success(); +} +LogicalResult +TypeConverter::convertSignatureArgs(ValueRange values, + SignatureConversion &result, + unsigned origInputOffset) const { + for (unsigned i = 0, e = values.size(); i != e; ++i) + if (failed(convertSignatureArg(origInputOffset + i, values[i], result))) + return failure(); + return success(); +} Value TypeConverter::materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, @@ -3180,7 +3587,7 @@ SmallVector<Value> TypeConverter::materializeTargetConversion( std::optional<TypeConverter::SignatureConversion> TypeConverter::convertBlockSignature(Block *block) const { SignatureConversion conversion(block->getNumArguments()); - if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) + if (failed(convertSignatureArgs(block->getArguments(), conversion))) return std::nullopt; return conversion; } @@ -3305,7 +3712,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands, newOp.addOperands(operands); SmallVector<Type> newResultTypes; - if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes))) + if (failed(converter.convertTypes(op->getResults(), newResultTypes))) return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); newOp.addTypes(newResultTypes); @@ -3578,7 +3985,8 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops, SmallVector<IRUnit> irUnits(ops.begin(), ops.end()); ctx->executeAction<ApplyConversionAction>( [&] { - OperationConverter opConverter(target, patterns, config, mode); + OperationConverter opConverter(ops.front()->getContext(), target, + patterns, config, mode); status = opConverter.convertOperations(ops); }, irUnits); diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp index e9adda0..5e07509 100644 --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -154,11 +154,14 @@ bool 
OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) { // already at the front of the block, or the previous operation is already a // constant we unique'd (i.e. one we inserted), then we don't need to do // anything. Otherwise, we move the constant to the insertion block. + // The location info is erased if the constant is moved to a different block. Block *insertBlock = &insertRegion->front(); - if (opBlock != insertBlock || (&insertBlock->front() != op && - !isFolderOwnedConstant(op->getPrevNode()))) { + if (opBlock != insertBlock) { op->moveBefore(&insertBlock->front()); op->setLoc(erasedFoldedLocation); + } else if (&insertBlock->front() != op && + !isFolderOwnedConstant(op->getPrevNode())) { + op->moveBefore(&insertBlock->front()); } folderConstOp = op; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 607b86c..0324588 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -15,6 +15,8 @@ #include "mlir/Config/mlir-config.h" #include "mlir/IR/Action.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" @@ -23,7 +25,7 @@ #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" @@ -178,9 +180,8 @@ static Operation *getDumpRootOp(Operation *op) { return op; } static void logSuccessfulFolding(Operation *op) { - llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n"; - op->dump(); - llvm::dbgs() << "\n\n"; + LDBG() << "// *** IR Dump After Successful Folding ***\n" + << OpWithFlags(op, 
OpPrintingFlags().elideLargeElementsAttrs()); } #endif // NDEBUG @@ -394,8 +395,12 @@ private: function_ref<void(Diagnostic &)> reasonCallback) override; #ifndef NDEBUG + /// A raw output stream used to prefix the debug log. + + llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(), + llvm::dbgs()}; /// A logger used to emit information during the application process. - llvm::ScopedPrinter logger{llvm::dbgs()}; + llvm::ScopedPrinter logger{os}; #endif /// The low-level pattern applicator. @@ -871,7 +876,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { ctx->executeAction<GreedyPatternRewriteIteration>( [&] { - continueRewrites = processWorklist(); + continueRewrites = false; + + // Erase unreachable blocks + // Operations like: + // %add = arith.addi %add, %add : i64 + // are legal in unreachable code. Unfortunately many patterns would be + // unsafe to apply on such IR and can lead to crashes or infinite + // loops. + continueRewrites |= + succeeded(eraseUnreachableBlocks(rewriter, region)); + + continueRewrites |= processWorklist(); // After applying patterns, make sure that the CFG of each of the // regions is kept up to date. 
@@ -917,10 +933,9 @@ mlir::applyPatternsGreedily(Region &region, RegionPatternRewriteDriver driver(region.getContext(), patterns, config, region); LogicalResult converged = std::move(driver).simplify(changed); - LLVM_DEBUG(if (failed(converged)) { - llvm::dbgs() << "The pattern rewrite did not converge after scanning " - << config.getMaxIterations() << " times\n"; - }); + if (failed(converged)) + LDBG() << "The pattern rewrite did not converge after scanning " + << config.getMaxIterations() << " times"; return converged; } @@ -1052,9 +1067,8 @@ LogicalResult mlir::applyOpPatternsGreedily( LogicalResult converged = std::move(driver).simplify(ops, changed); if (allErased) *allErased = surviving.empty(); - LLVM_DEBUG(if (failed(converged)) { - llvm::dbgs() << "The pattern rewrite did not converge after " - << config.getMaxNumRewrites() << " rewrites"; - }); + if (failed(converged)) + LDBG() << "The pattern rewrite did not converge after " + << config.getMaxNumRewrites() << " rewrites"; return converged; } diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index eeb4052..73107cf 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -13,10 +13,12 @@ #include "mlir/Transforms/InliningUtils.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/CallInterfaces.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -182,13 +184,16 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src, IRMapping &valueMapping) { for (auto &block : *src) { for (auto &op : block) { + // UnrealizedConversionCastOp is inlineable but cannot implement the + // inliner interface due to layering constraints. + if (isa<UnrealizedConversionCastOp>(op)) + continue; + // Check this operation. 
if (!interface.isLegalToInline(&op, insertRegion, shouldCloneInlinedRegion, valueMapping)) { - LLVM_DEBUG({ - llvm::dbgs() << "* Illegal to inline because of op: "; - op.dump(); - }); + LDBG() << "* Illegal to inline because of op: " + << OpWithFlags(&op, OpPrintingFlags().skipRegions()); return false; } // Check any nested regions. diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index cb3f2c5..111f58e 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -13,11 +13,13 @@ #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SubsetOpInterface.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <queue> #define DEBUG_TYPE "licm" @@ -64,8 +66,7 @@ size_t mlir::moveLoopInvariantCode( size_t numMoved = 0; for (Region *region : regions) { - LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" - << *region->getParentOp() << "\n"); + LDBG() << "Original loop:\n" << *region->getParentOp(); std::queue<Operation *> worklist; // Add top-level operations in the loop body to the worklist. 
@@ -83,12 +84,13 @@ size_t mlir::moveLoopInvariantCode( if (op->getParentRegion() != region) continue; - LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n"); + LDBG() << "Checking op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); if (!shouldMoveOutOfRegion(op, region) || !canBeHoisted(op, definedOutside)) continue; - LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n"); + LDBG() << "Moving loop-invariant op: " << *op; moveOutOfRegion(op, region); ++numMoved; @@ -322,7 +324,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg) { assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg"); - auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg); + BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg); int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it); MatchingSubsets subsets; if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg))) diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index a1d975d..31ae1d1 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -23,12 +23,15 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include <deque> #include <iterator> using namespace mlir; +#define DEBUG_TYPE "region-utils" + void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region) { for (auto &use : llvm::make_early_inc_range(orig.getUses())) { @@ -182,19 +185,34 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove( // TODO: We could likely merge this with the DCE algorithm below. 
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef<Region> regions) { + LDBG() << "Starting eraseUnreachableBlocks with " << regions.size() + << " regions"; + // Set of blocks found to be reachable within a given region. llvm::df_iterator_default_set<Block *, 16> reachable; // If any blocks were found to be dead. - bool erasedDeadBlocks = false; + int erasedDeadBlocks = 0; SmallVector<Region *, 1> worklist; worklist.reserve(regions.size()); for (Region &region : regions) worklist.push_back(&region); + + LDBG(2) << "Initial worklist size: " << worklist.size(); + while (!worklist.empty()) { Region *region = worklist.pop_back_val(); - if (region->empty()) + if (region->empty()) { + LDBG(2) << "Skipping empty region"; continue; + } + + LDBG(2) << "Processing region with " << region->getBlocks().size() + << " blocks"; + if (region->getParentOp()) + LDBG(2) << " -> for operation: " + << OpWithFlags(region->getParentOp(), + OpPrintingFlags().skipRegions()); // If this is a single block region, just collect the nested regions. if (region->hasOneBlock()) { @@ -209,13 +227,17 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, for (Block *block : depth_first_ext(&region->front(), reachable)) (void)block /* Mark all reachable blocks */; + LDBG(2) << "Found " << reachable.size() << " reachable blocks out of " + << region->getBlocks().size() << " total blocks"; + // Collect all of the dead blocks and push the live regions onto the // worklist. 
for (Block &block : llvm::make_early_inc_range(*region)) { if (!reachable.count(&block)) { + LDBG() << "Erasing unreachable block: " << &block; block.dropAllDefinedValueUses(); rewriter.eraseBlock(&block); - erasedDeadBlocks = true; + ++erasedDeadBlocks; continue; } @@ -226,7 +248,10 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, } } - return success(erasedDeadBlocks); + LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks + << " dead blocks"; + + return success(erasedDeadBlocks > 0); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index ee5c642..1382550 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -13,18 +13,40 @@ #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Rewrite/PatternApplicator.h" -#include "llvm/Support/Debug.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "walk-rewriter" namespace mlir { +// Find all reachable blocks in the region and add them to the visitedBlocks +// set. +static void findReachableBlocks(Region &region, + DenseSet<Block *> &reachableBlocks) { + Block *entryBlock = &region.front(); + reachableBlocks.insert(entryBlock); + // Traverse the CFG and add all reachable blocks to the blockList. 
+ SmallVector<Block *> worklist({entryBlock}); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + Operation *terminator = &block->back(); + for (Block *successor : terminator->getSuccessors()) { + if (reachableBlocks.contains(successor)) + continue; + worklist.push_back(successor); + reachableBlocks.insert(successor); + } + } +} + namespace { struct WalkAndApplyPatternsAction final : tracing::ActionImpl<WalkAndApplyPatternsAction> { @@ -88,20 +110,104 @@ void walkAndApplyPatterns(Operation *op, PatternApplicator applicator(patterns); applicator.applyDefaultCostModel(); + // Iterator on all reachable operations in the region. + // Also keep track if we visited the nested regions of the current op + // already to drive the post-order traversal. + struct RegionReachableOpIterator { + RegionReachableOpIterator(Region *region) : region(region) { + regionIt = region->begin(); + if (regionIt != region->end()) + blockIt = regionIt->begin(); + if (!llvm::hasSingleElement(*region)) + findReachableBlocks(*region, reachableBlocks); + } + // Advance the iterator to the next reachable operation. + void advance() { + assert(regionIt != region->end()); + hasVisitedRegions = false; + if (blockIt == regionIt->end()) { + ++regionIt; + while (regionIt != region->end() && + !reachableBlocks.contains(&*regionIt)) + ++regionIt; + if (regionIt != region->end()) + blockIt = regionIt->begin(); + return; + } + ++blockIt; + if (blockIt != regionIt->end()) { + LDBG() << "Incrementing block iterator, next op: " + << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions()); + } + } + // The region we're iterating over. + Region *region; + // The Block currently being iterated over. + Region::iterator regionIt; + // The Operation currently being iterated over. + Block::iterator blockIt; + // The set of blocks that are reachable in the current region. + DenseSet<Block *> reachableBlocks; + // Whether we've visited the nested regions of the current op already. 
+ bool hasVisitedRegions = false; + }; + + // Worklist of regions to visit to drive the post-order traversal. + SmallVector<RegionReachableOpIterator> worklist; + + LDBG() << "Starting walk-based pattern rewrite driver"; ctx->executeAction<WalkAndApplyPatternsAction>( [&] { + // Perform a post-order traversal of the regions, visiting each + // reachable operation. for (Region &region : op->getRegions()) { - region.walk([&](Operation *visitedOp) { - LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( - llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n";); + assert(worklist.empty()); + if (region.empty()) + continue; + + // Prime the worklist with the entry block of this region. + worklist.push_back({&region}); + while (!worklist.empty()) { + RegionReachableOpIterator &it = worklist.back(); + if (it.regionIt == it.region->end()) { + // We're done with this region. + worklist.pop_back(); + continue; + } + if (it.blockIt == it.regionIt->end()) { + // We're done with this block. + it.advance(); + continue; + } + Operation *op = &*it.blockIt; + // If we haven't visited the nested regions of this op yet, + // enqueue them. + if (!it.hasVisitedRegions) { + it.hasVisitedRegions = true; + for (Region &nestedRegion : llvm::reverse(op->getRegions())) { + if (nestedRegion.empty()) + continue; + worklist.push_back({&nestedRegion}); + } + } + // If we're not at the back of the worklist, we've enqueued some + // nested region for processing. We'll come back to this op later + // (post-order) + if (&it != &worklist.back()) + continue; + + // Preemptively increment the iterator, in case the current op + // would be erased. 
+ it.advance(); + + LDBG() << "Visiting op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - erasedListener.visitedOp = visitedOp; + erasedListener.visitedOp = op; #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); - } - }); + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + LDBG() << "\tOp matched and rewritten"; + } } }, {op}); |