diff options
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r-- | mlir/lib/Transforms/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Transforms/CSE.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Transforms/Canonicalizer.cpp | 1 | ||||
-rw-r--r-- | mlir/lib/Transforms/OpStats.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Transforms/RemoveDeadValues.cpp | 65 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 229 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 1 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/Inliner.cpp | 33 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/RegionUtils.cpp | 4 |
9 files changed, 226 insertions, 116 deletions
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 3a8088b..058039e 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -37,5 +37,4 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils - MLIRUBDialect ) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 8e03f62..09e5a02 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -19,7 +19,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/Hashing.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/RecyclingAllocator.h" @@ -239,9 +238,8 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, // Don't simplify operations with regions that have multiple blocks. // TODO: We need additional tests to verify that we handle such IR correctly. - if (!llvm::all_of(op->getRegions(), [](Region &r) { - return r.getBlocks().empty() || llvm::hasSingleElement(r.getBlocks()); - })) + if (!llvm::all_of(op->getRegions(), + [](Region &r) { return r.empty() || r.hasOneBlock(); })) return failure(); // Some simple use case of operation with memory side-effect are dealt with diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 4b0ac28..7a99fe8 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -13,7 +13,6 @@ #include "mlir/Transforms/Passes.h" -#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp index 0dc3fe9..9ebf310 100644 --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -8,10 +8,8 @@ #include "mlir/Transforms/Passes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/Support/Format.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 608bdcb..4ccb83f 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -36,6 +36,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" @@ -51,6 +52,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <cstddef> #include <memory> @@ -58,8 +60,6 @@ #include <vector> #define DEBUG_TYPE "remove-dead-values" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir { #define GEN_PASS_DEF_REMOVEDEADVALUES @@ -119,21 +119,21 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, RunLivenessAnalysis &la) { for (Value value : values) { if (nonLiveSet.contains(value)) { - LDBG("Value " << value << " is already marked non-live (dead)"); + LDBG() << "Value " << value << " is already marked non-live (dead)"; continue; } const Liveness *liveness = la.getLiveness(value); if (!liveness) { - LDBG("Value " << value - << " has no liveness info, conservatively considered live"); + LDBG() << "Value " << value + << " has no liveness info, conservatively considered live"; return true; } if (liveness->isLive) { - LDBG("Value " << value << " is live according to liveness analysis"); + LDBG() << "Value " << value << " is live according to liveness analysis"; return true; } else { - LDBG("Value " << value << " is dead according to liveness analysis"); + LDBG() << "Value " << value << " is dead according to liveness analysis"; } } return false; @@ -148,8 +148,8 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet, for (auto [index, value] : llvm::enumerate(values)) { if (nonLiveSet.contains(value)) { lives.reset(index); - LDBG("Value " << value << " is already marked non-live (dead) at index " - << index); + LDBG() << "Value " << value + << " is already marked non-live (dead) at index " << index; continue; } @@ -161,17 +161,17 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet, // (because they weren't erased) and also their liveness is null because // liveness analysis ran before their creation. if (!liveness) { - LDBG("Value " << value << " at index " << index - << " has no liveness info, conservatively considered live"); + LDBG() << "Value " << value << " at index " << index + << " has no liveness info, conservatively considered live"; continue; } if (!liveness->isLive) { lives.reset(index); - LDBG("Value " << value << " at index " << index - << " is dead according to liveness analysis"); + LDBG() << "Value " << value << " at index " << index + << " is dead according to liveness analysis"; } else { - LDBG("Value " << value << " at index " << index - << " is live according to liveness analysis"); + LDBG() << "Value " << value << " at index " << index + << " is live according to liveness analysis"; } } @@ -187,8 +187,8 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range, if (!nonLive[index]) continue; nonLiveSet.insert(result); - LDBG("Marking value " << result << " as non-live (dead) at index " - << index); + LDBG() << "Marking value " << result << " as non-live (dead) at index " + << index; } } @@ -258,16 +258,18 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing simple op: " << *op); + LDBG() << "Processing simple op: " << *op; if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { - LDBG("Simple op is not memory effect free or has live results, skipping: " - << *op); + LDBG() + << "Simple op is not memory effect free or has live results, skipping: " + << *op; return; } - LDBG("Simple op has all dead results and is memory effect free, scheduling " - "for removal: " - << *op); + LDBG() + << "Simple op has all dead results and is memory effect free, scheduling " + "for removal: " + << *op; cl.operations.push_back(op); collectNonLiveValues(nonLiveSet, op->getResults(), BitVector(op->getNumResults(), true)); @@ -286,10 +288,10 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, static void processFuncOp(FunctionOpInterface funcOp, Operation *module, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing function op: " << funcOp.getOperation()->getName()); + LDBG() << "Processing function op: " << funcOp.getOperation()->getName(); if (funcOp.isPublic() || funcOp.isExternal()) { - LDBG("Function is public or external, skipping: " - << funcOp.getOperation()->getName()); + LDBG() << "Function is public or external, skipping: " + << funcOp.getOperation()->getName(); return; } @@ -345,8 +347,6 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // since it forwards only to non-live value(s) (%1#1). Operation *lastReturnOp = funcOp.back().getTerminator(); size_t numReturns = lastReturnOp->getNumOperands(); - if (numReturns == 0) - return; BitVector nonLiveRets(numReturns, true); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); @@ -368,6 +368,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets}); // Do (5) and (6). + if (numReturns == 0) + return; for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa<CallOpInterface>(callOp) && "expected a call-like user"); @@ -409,9 +411,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print( - llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"); + LDBG() << "Processing region branch op: " + << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions()); // Mark live results of `regionBranchOp` in `liveResults`. auto markLiveResults = [&](BitVector &liveResults) { liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); @@ -697,7 +698,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing branch op: " << *branchOp); + LDBG() << "Processing branch op: " << *branchOp; unsigned numSuccessors = branchOp->getNumSuccessors(); for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4c4ce3c..08803e0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -14,8 +14,10 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" +#include "mlir/IR/Operation.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -130,11 +132,6 @@ struct ConversionValueMapping { /// value. ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; - /// Lookup the given value within the map, or return an empty vector if the - /// value is not mapped. If it is mapped, this follows the same behavior - /// as `lookupOrDefault`. - ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; - template <typename T> struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; @@ -237,15 +234,6 @@ ConversionValueMapping::lookupOrDefault(Value from, return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); } -ValueVector ConversionValueMapping::lookupOrNull(Value from, - TypeRange desiredTypes) const { - ValueVector result = lookupOrDefault(from, desiredTypes); - if (result == ValueVector{from} || - (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) - return {}; - return result; -} - //===----------------------------------------------------------------------===// // Rewriter and Translation State //===----------------------------------------------------------------------===// @@ -521,9 +509,11 @@ private: class MoveBlockRewrite : public BlockRewrite { public: MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, - Region *region, Block *insertBeforeBlock) - : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), - insertBeforeBlock(insertBeforeBlock) {} + Region *previousRegion, Region::iterator previousIt) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), + region(previousRegion), + insertBeforeBlock(previousIt == previousRegion->end() ? nullptr + : &*previousIt) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveBlock; @@ -630,9 +620,12 @@ protected: class MoveOperationRewrite : public OperationRewrite { public: MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Operation *op, Block *block, Operation *insertBeforeOp) - : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), - insertBeforeOp(insertBeforeOp) {} + Operation *op, OpBuilder::InsertPoint previous) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), + block(previous.getBlock()), + insertBeforeOp(previous.getPoint() == previous.getBlock()->end() + ? nullptr + : &*previous.getPoint()) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveOperation; @@ -926,6 +919,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Return "true" if the given operation was replaced or erased. bool wasOpReplaced(Operation *op) const; + /// Lookup the most recently mapped values with the desired types in the + /// mapping. + /// + /// Special cases: + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given + /// value. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; + //===--------------------------------------------------------------------===// // IR Rewrites / Type Conversion //===--------------------------------------------------------------------===// @@ -1248,6 +1258,22 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management //===----------------------------------------------------------------------===// +ValueVector +ConversionPatternRewriterImpl::lookupOrDefault(Value from, + TypeRange desiredTypes) const { + return mapping.lookupOrDefault(from, desiredTypes); +} + +ValueVector +ConversionPatternRewriterImpl::lookupOrNull(Value from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + if (result == ValueVector{from} || + (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) + return {}; + return result; +} + RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); } @@ -1295,7 +1321,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped values. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back(lookupOrDefault(operand)); continue; } @@ -1314,7 +1340,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( continue; } - ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); + ValueVector repl = lookupOrDefault(operand, legalTypes); if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { // Mapped values have the correct type or there is an existing // materialization. Or the operand is not mapped at all and has the @@ -1324,7 +1350,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = mapping.lookupOrDefault(operand); + repl = lookupOrDefault(operand); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1502,7 +1528,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = - builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs); + UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); if (!valuesToMap.empty()) mapping.map(valuesToMap, convertOp.getResults()); if (castOp) @@ -1519,7 +1545,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. - ValueVector repl = mapping.lookupOrNull(value, value.getType()); + ValueVector repl = lookupOrNull(value, value.getType()); if (!repl.empty()) return repl.front(); @@ -1535,7 +1561,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // No replacement value was found. Get the latest replacement value // (regardless of the type) and build a source materialization to the // original type. - repl = mapping.lookupOrNull(value); + repl = lookupOrNull(value); if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, @@ -1568,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { + // If no previous insertion point is provided, the op used to be detached. + bool wasDetached = !previous.isSet(); LLVM_DEBUG({ - logger.startLine() << "** Insert : '" << op->getName() << "'(" << op - << ")\n"; + logger.startLine() << "** Insert : '" << op->getName() << "' (" << op + << ")"; + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); assert(!wasOpReplaced(op->getParentOp()) && "attempting to insert into a block within a replaced/erased op"); - if (!previous.isSet()) { - // This is a newly created op. + if (wasDetached) { + // If the op was detached, it is most likely a newly created op. + // TODO: If the same op is inserted multiple times from a detached state, + // the rollback mechanism may erase the same op multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateOperationRewrite>(op); patternNewOps.insert(op); return; } - Operation *prevOp = previous.getPoint() == previous.getBlock()->end() - ? nullptr - : &*previous.getPoint(); - appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); + + // The op was moved from one place to another. + appendRewrite<MoveOperationRewrite>(op, previous); } void ConversionPatternRewriterImpl::replaceOp( @@ -1649,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) { void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { - assert(!wasOpReplaced(block->getParentOp()) && - "attempting to insert into a region within a replaced/erased op"); + // If no previous insertion point is provided, the block used to be detached. + bool wasDetached = !previous; + Operation *newParentOp = block->getParentOp(); LLVM_DEBUG( { - Operation *parent = block->getParentOp(); + Operation *parent = newParentOp; if (parent) { logger.startLine() << "** Insert Block into : '" << parent->getName() - << "'(" << parent << ")\n"; + << "' (" << parent << ")"; } else { logger.startLine() - << "** Insert Block into detached Region (nullptr parent op)'\n"; + << "** Insert Block into detached Region (nullptr parent op)"; } + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); + assert(!wasOpReplaced(newParentOp) && + "attempting to insert into a region within a replaced/erased op"); + (void)newParentOp; patternInsertedBlocks.insert(block); - if (!previous) { - // This is a newly created block. + if (wasDetached) { + // If the block was detached, it is most likely a newly created block. + // TODO: If the same block is inserted multiple times from a detached state, + // the rollback mechanism may erase the same block multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateBlockRewrite>(block); return; } - Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; - appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); + + // The block was moved from one place to another. + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1716,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> { return v ? SmallVector<Value>{v} : SmallVector<Value>(); @@ -1731,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1739,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1845,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); } @@ -1976,6 +2043,7 @@ private: /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks); @@ -2092,8 +2160,9 @@ OperationLegalizer::legalize(Operation *op, // If the operation has no regions, just print it here. if (!isIgnored && op->getNumRegions() == 0) { - op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); - logger.getOStream() << "\n\n"; + logger.startLine() << OpWithFlags(op, + OpPrintingFlags().printGenericOpForm()) + << "\n"; } }); @@ -2172,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); }); - (void)rewriterImpl; + + // Clear pattern state, so that the next pattern application starts with a + // clean slate. (The op/block sets are populated by listener notifications.) + auto cleanup = llvm::make_scope_exit([&]() { + rewriterImpl.patternNewOps.clear(); + rewriterImpl.patternModifiedOps.clear(); + rewriterImpl.patternInsertedBlocks.clear(); + }); + + // Upon failure, undo all changes made by the folder. + RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. StringRef opName = op->getName().getStringRef(); SmallVector<Value, 2> replacementValues; SmallVector<Operation *, 2> newOps; rewriter.setInsertionPoint(op); + rewriter.startOpModification(op); if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); + rewriter.cancelOpModification(op); return failure(); } + rewriter.finalizeOpModification(op); // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) return legalize(op, rewriter); + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + // Recursively legalize any new constant operations. for (Operation *newOp : newOps) { if (failed(legalize(newOp, rewriter))) { @@ -2201,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op, "op '" + opName + "' folder rollback of IR modifications requested"); } - // Legalization failed: erase all materialized constants. - for (Operation *op : newOps) - rewriter.eraseOp(op); + rewriterImpl.resetState( + curState, std::string(op->getName().getStringRef()) + " folder"); return failure(); } } - // Insert a replacement for 'op' with the folded replacement values. - rewriter.replaceOp(op, replacementValues); - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); return success(); } @@ -2220,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + Operation *checkOp; + std::optional<OperationFingerPrint> topLevelFingerPrint; + if (!rewriterImpl.config.allowPatternRollback) { + // The op may be getting erased, so we have to check the parent op. + // (In rare cases, a pattern may even erase the parent op, which will cause + // a crash here. Expensive checks are "best effort".) Skip the check if the + // op does not have a parent op. + if ((checkOp = op->getParentOp())) { + if (!op->getContext()->isMultithreadingEnabled()) { + topLevelFingerPrint = OperationFingerPrint(checkOp); + } else { + // Another thread may be modifying a sibling operation. Therefore, the + // fingerprinting mechanism of the parent op works only in + // single-threaded mode. + LLVM_DEBUG({ + rewriterImpl.logger.startLine() + << "WARNING: Multi-threadeding is enabled. Some dialect " + "conversion expensive checks are skipped in multithreading " + "mode!\n"; + }); + } + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { bool canApply = canApplyPattern(op, pattern, rewriter); @@ -2232,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (!rewriterImpl.config.allowPatternRollback) { + // Returning "failure" after modifying IR is not allowed. + if (checkOp) { + OperationFingerPrint fingerPrintAfterPattern(checkOp); + if (fingerPrintAfterPattern != *topLevelFingerPrint) + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' returned failure but IR did change"); + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2260,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, moveAndReset(rewriterImpl.patternModifiedOps); SetVector<Block *> insertedBlocks = moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, rewriter, newOps, + auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps, modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { @@ -2303,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - const SetVector<Operation *> &newOps, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks) { auto &impl = rewriter.getImpl(); @@ -2319,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult( return hasRewrite<ModifyOperationRewrite>(newRewrites, op); }; if (!replacedRoot() && !updatedRootInPlace()) - llvm::report_fatal_error("expected pattern to replace the root operation"); + llvm::report_fatal_error( + "expected pattern to replace the root operation or modify it in place"); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index b82d850..607b86c 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index b639e87f..26c965c 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -21,7 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "inlining" @@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, // InlinerInterfaceImpl //===----------------------------------------------------------------------===// -#ifndef NDEBUG static std::string getNodeName(CallOpInterface op) { if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) return debugString(op); return "_unnamed_callee_"; } -#endif /// Return true if the specified `inlineHistoryID` indicates an inline history /// that already includes `node`. @@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); LLVM_DEBUG({ - llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; + LDBG() << "* Inliner: Initial calls in SCC are: {"; for (unsigned i = 0, e = calls.size(); i < e; ++i) - llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; - llvm::dbgs() << "}\n"; + LDBG() << " " << i << ". " << calls[i].call << ","; + LDBG() << "}"; }); // Try to inline each of the call operations. Don't cache the end iterator @@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, CallOpInterface call = it.call; LLVM_DEBUG({ if (doInline) - llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Inlining call: " << i << ". " << call; else - llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Not inlining call: " << i << ". " << call; }); if (!doInline) continue; @@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, cast<CallableOpInterface>(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { - LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); + LDBG() << "** Failed to inline"; continue; } inlinedAnyCalls = true; @@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, auto historyToString = [](InlineHistoryT h) { return h.has_value() ? std::to_string(*h) : "root"; }; - (void)historyToString; - LLVM_DEBUG(llvm::dbgs() - << "* new inlineHistory entry: " << newInlineHistoryID << ". [" - << getNodeName(call) << ", " << historyToString(inlineHistoryID) - << "]\n"); + LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". [" + << getNodeName(call) << ", " << historyToString(inlineHistoryID) + << "]"; for (unsigned k = prevSize; k != calls.size(); ++k) { callHistory.push_back(newInlineHistoryID); - LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call - << "}\n with historyID = " << newInlineHistoryID - << ", added due to inlining of\n call {" << call - << "}\n with historyID = " - << historyToString(inlineHistoryID) << "\n"); + LDBG() << "* new call " << k << " {" << calls[k].call + << "}\n with historyID = " << newInlineHistoryID + << ", added due to inlining of\n call {" << call + << "}\n with historyID = " << historyToString(inlineHistoryID); } // If the inlining was successful, Merge the new uses into the source node. diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index 353f8a5..a1d975d 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -417,7 +417,7 @@ static LogicalResult deleteDeadness(RewriterBase &rewriter, for (Region ®ion : regions) { if (region.empty()) continue; - bool hasSingleBlock = llvm::hasSingleElement(region); + bool hasSingleBlock = region.hasOneBlock(); // Delete every operation that is not live. Graph regions may have cycles // in the use-def graph, so we must explicitly dropAllUses() from each @@ -850,7 +850,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) { /// failure otherwise. static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter, Region ®ion) { - if (region.empty() || llvm::hasSingleElement(region)) + if (region.empty() || region.hasOneBlock()) return failure(); // Identify sets of blocks, other than the entry block, that branch to the |