aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Transforms/CSE.cpp6
-rw-r--r--mlir/lib/Transforms/Canonicalizer.cpp1
-rw-r--r--mlir/lib/Transforms/OpStats.cpp2
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp65
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp229
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp1
-rw-r--r--mlir/lib/Transforms/Utils/Inliner.cpp33
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp4
9 files changed, 226 insertions, 116 deletions
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 3a8088b..058039e 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -37,5 +37,4 @@ add_mlir_library(MLIRTransforms
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
- MLIRUBDialect
)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 8e03f62..09e5a02 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -19,7 +19,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/RecyclingAllocator.h"
@@ -239,9 +238,8 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
// Don't simplify operations with regions that have multiple blocks.
// TODO: We need additional tests to verify that we handle such IR correctly.
- if (!llvm::all_of(op->getRegions(), [](Region &r) {
- return r.getBlocks().empty() || llvm::hasSingleElement(r.getBlocks());
- }))
+ if (!llvm::all_of(op->getRegions(),
+ [](Region &r) { return r.empty() || r.hasOneBlock(); }))
return failure();
// Some simple use case of operation with memory side-effect are dealt with
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 4b0ac28..7a99fe8 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -13,7 +13,6 @@
#include "mlir/Transforms/Passes.h"
-#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp
index 0dc3fe9..9ebf310 100644
--- a/mlir/lib/Transforms/OpStats.cpp
+++ b/mlir/lib/Transforms/OpStats.cpp
@@ -8,10 +8,8 @@
#include "mlir/Transforms/Passes.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
-#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 608bdcb..4ccb83f 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -36,6 +36,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
@@ -51,6 +52,7 @@
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <cstddef>
#include <memory>
@@ -58,8 +60,6 @@
#include <vector>
#define DEBUG_TYPE "remove-dead-values"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace mlir {
#define GEN_PASS_DEF_REMOVEDEADVALUES
@@ -119,21 +119,21 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
RunLivenessAnalysis &la) {
for (Value value : values) {
if (nonLiveSet.contains(value)) {
- LDBG("Value " << value << " is already marked non-live (dead)");
+ LDBG() << "Value " << value << " is already marked non-live (dead)";
continue;
}
const Liveness *liveness = la.getLiveness(value);
if (!liveness) {
- LDBG("Value " << value
- << " has no liveness info, conservatively considered live");
+ LDBG() << "Value " << value
+ << " has no liveness info, conservatively considered live";
return true;
}
if (liveness->isLive) {
- LDBG("Value " << value << " is live according to liveness analysis");
+ LDBG() << "Value " << value << " is live according to liveness analysis";
return true;
} else {
- LDBG("Value " << value << " is dead according to liveness analysis");
+ LDBG() << "Value " << value << " is dead according to liveness analysis";
}
}
return false;
@@ -148,8 +148,8 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
for (auto [index, value] : llvm::enumerate(values)) {
if (nonLiveSet.contains(value)) {
lives.reset(index);
- LDBG("Value " << value << " is already marked non-live (dead) at index "
- << index);
+ LDBG() << "Value " << value
+ << " is already marked non-live (dead) at index " << index;
continue;
}
@@ -161,17 +161,17 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
// (because they weren't erased) and also their liveness is null because
// liveness analysis ran before their creation.
if (!liveness) {
- LDBG("Value " << value << " at index " << index
- << " has no liveness info, conservatively considered live");
+ LDBG() << "Value " << value << " at index " << index
+ << " has no liveness info, conservatively considered live";
continue;
}
if (!liveness->isLive) {
lives.reset(index);
- LDBG("Value " << value << " at index " << index
- << " is dead according to liveness analysis");
+ LDBG() << "Value " << value << " at index " << index
+ << " is dead according to liveness analysis";
} else {
- LDBG("Value " << value << " at index " << index
- << " is live according to liveness analysis");
+ LDBG() << "Value " << value << " at index " << index
+ << " is live according to liveness analysis";
}
}
@@ -187,8 +187,8 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
if (!nonLive[index])
continue;
nonLiveSet.insert(result);
- LDBG("Marking value " << result << " as non-live (dead) at index "
- << index);
+ LDBG() << "Marking value " << result << " as non-live (dead) at index "
+ << index;
}
}
@@ -258,16 +258,18 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG("Processing simple op: " << *op);
+ LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
- LDBG("Simple op is not memory effect free or has live results, skipping: "
- << *op);
+ LDBG()
+ << "Simple op is not memory effect free or has live results, skipping: "
+ << *op;
return;
}
- LDBG("Simple op has all dead results and is memory effect free, scheduling "
- "for removal: "
- << *op);
+ LDBG()
+ << "Simple op has all dead results and is memory effect free, scheduling "
+ "for removal: "
+ << *op;
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
@@ -286,10 +288,10 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG("Processing function op: " << funcOp.getOperation()->getName());
+ LDBG() << "Processing function op: " << funcOp.getOperation()->getName();
if (funcOp.isPublic() || funcOp.isExternal()) {
- LDBG("Function is public or external, skipping: "
- << funcOp.getOperation()->getName());
+ LDBG() << "Function is public or external, skipping: "
+ << funcOp.getOperation()->getName();
return;
}
@@ -345,8 +347,6 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// since it forwards only to non-live value(s) (%1#1).
Operation *lastReturnOp = funcOp.back().getTerminator();
size_t numReturns = lastReturnOp->getNumOperands();
- if (numReturns == 0)
- return;
BitVector nonLiveRets(numReturns, true);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
@@ -368,6 +368,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
// Do (5) and (6).
+ if (numReturns == 0)
+ return;
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
@@ -409,9 +411,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n");
+ LDBG() << "Processing region branch op: "
+ << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions());
// Mark live results of `regionBranchOp` in `liveResults`.
auto markLiveResults = [&](BitVector &liveResults) {
liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
@@ -697,7 +698,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG("Processing branch op: " << *branchOp);
+ LDBG() << "Processing branch op: " << *branchOp;
unsigned numSuccessors = branchOp->getNumSuccessors();
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4c4ce3c..08803e0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -14,8 +14,10 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -130,11 +132,6 @@ struct ConversionValueMapping {
/// value.
ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
- /// Lookup the given value within the map, or return an empty vector if the
- /// value is not mapped. If it is mapped, this follows the same behavior
- /// as `lookupOrDefault`.
- ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
-
template <typename T>
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -237,15 +234,6 @@ ConversionValueMapping::lookupOrDefault(Value from,
return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
}
-ValueVector ConversionValueMapping::lookupOrNull(Value from,
- TypeRange desiredTypes) const {
- ValueVector result = lookupOrDefault(from, desiredTypes);
- if (result == ValueVector{from} ||
- (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
- return {};
- return result;
-}
-
//===----------------------------------------------------------------------===//
// Rewriter and Translation State
//===----------------------------------------------------------------------===//
@@ -521,9 +509,11 @@ private:
class MoveBlockRewrite : public BlockRewrite {
public:
MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block,
- Region *region, Block *insertBeforeBlock)
- : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region),
- insertBeforeBlock(insertBeforeBlock) {}
+ Region *previousRegion, Region::iterator previousIt)
+ : BlockRewrite(Kind::MoveBlock, rewriterImpl, block),
+ region(previousRegion),
+ insertBeforeBlock(previousIt == previousRegion->end() ? nullptr
+ : &*previousIt) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::MoveBlock;
@@ -630,9 +620,12 @@ protected:
class MoveOperationRewrite : public OperationRewrite {
public:
MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Operation *op, Block *block, Operation *insertBeforeOp)
- : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
- insertBeforeOp(insertBeforeOp) {}
+ Operation *op, OpBuilder::InsertPoint previous)
+ : OperationRewrite(Kind::MoveOperation, rewriterImpl, op),
+ block(previous.getBlock()),
+ insertBeforeOp(previous.getPoint() == previous.getBlock()->end()
+ ? nullptr
+ : &*previous.getPoint()) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::MoveOperation;
@@ -926,6 +919,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Return "true" if the given operation was replaced or erased.
bool wasOpReplaced(Operation *op) const;
+ /// Lookup the most recently mapped values with the desired types in the
+ /// mapping.
+ ///
+ /// Special cases:
+ /// - If the desired type range is empty, simply return the most recently
+ /// mapped values.
+ /// - If there is no mapping to the desired types, also return the most
+ /// recently mapped values.
+ /// - If there is no mapping for the given values at all, return the given
+ /// value.
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+
+ /// Lookup the given value within the map, or return an empty vector if the
+ /// value is not mapped. If it is mapped, this follows the same behavior
+ /// as `lookupOrDefault`.
+ ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
+
//===--------------------------------------------------------------------===//
// IR Rewrites / Type Conversion
//===--------------------------------------------------------------------===//
@@ -1248,6 +1258,22 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
//===----------------------------------------------------------------------===//
+ValueVector
+ConversionPatternRewriterImpl::lookupOrDefault(Value from,
+ TypeRange desiredTypes) const {
+ return mapping.lookupOrDefault(from, desiredTypes);
+}
+
+ValueVector
+ConversionPatternRewriterImpl::lookupOrNull(Value from,
+ TypeRange desiredTypes) const {
+ ValueVector result = lookupOrDefault(from, desiredTypes);
+ if (result == ValueVector{from} ||
+ (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes))
+ return {};
+ return result;
+}
+
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
}
@@ -1295,7 +1321,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped values.
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back(lookupOrDefault(operand));
continue;
}
@@ -1314,7 +1340,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
continue;
}
- ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
+ ValueVector repl = lookupOrDefault(operand, legalTypes);
if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) {
// Mapped values have the correct type or there is an existing
// materialization. Or the operand is not mapped at all and has the
@@ -1324,7 +1350,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
}
// Create a materialization for the most recently mapped values.
- repl = mapping.lookupOrDefault(operand);
+ repl = lookupOrDefault(operand);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1502,7 +1528,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
- builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
+ UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
if (!valuesToMap.empty())
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
@@ -1519,7 +1545,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
- ValueVector repl = mapping.lookupOrNull(value, value.getType());
+ ValueVector repl = lookupOrNull(value, value.getType());
if (!repl.empty())
return repl.front();
@@ -1535,7 +1561,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
// No replacement value was found. Get the latest replacement value
// (regardless of the type) and build a source materialization to the
// original type.
- repl = mapping.lookupOrNull(value);
+ repl = lookupOrNull(value);
if (repl.empty()) {
// No replacement value is registered in the mapping. This means that the
// value is dropped and no longer needed. (If the value were still needed,
@@ -1568,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
void ConversionPatternRewriterImpl::notifyOperationInserted(
Operation *op, OpBuilder::InsertPoint previous) {
+ // If no previous insertion point is provided, the op used to be detached.
+ bool wasDetached = !previous.isSet();
LLVM_DEBUG({
- logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
- << ")\n";
+ logger.startLine() << "** Insert : '" << op->getName() << "' (" << op
+ << ")";
+ if (wasDetached)
+ logger.getOStream() << " (was detached)";
+ logger.getOStream() << "\n";
});
assert(!wasOpReplaced(op->getParentOp()) &&
"attempting to insert into a block within a replaced/erased op");
- if (!previous.isSet()) {
- // This is a newly created op.
+ if (wasDetached) {
+ // If the op was detached, it is most likely a newly created op.
+ // TODO: If the same op is inserted multiple times from a detached state,
+ // the rollback mechanism may erase the same op multiple times. This is a
+ // bug in the rollback-based dialect conversion driver.
appendRewrite<CreateOperationRewrite>(op);
patternNewOps.insert(op);
return;
}
- Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
- ? nullptr
- : &*previous.getPoint();
- appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
+
+ // The op was moved from one place to another.
+ appendRewrite<MoveOperationRewrite>(op, previous);
}
void ConversionPatternRewriterImpl::replaceOp(
@@ -1649,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
- assert(!wasOpReplaced(block->getParentOp()) &&
- "attempting to insert into a region within a replaced/erased op");
+ // If no previous insertion point is provided, the block used to be detached.
+ bool wasDetached = !previous;
+ Operation *newParentOp = block->getParentOp();
LLVM_DEBUG(
{
- Operation *parent = block->getParentOp();
+ Operation *parent = newParentOp;
if (parent) {
logger.startLine() << "** Insert Block into : '" << parent->getName()
- << "'(" << parent << ")\n";
+ << "' (" << parent << ")";
} else {
logger.startLine()
- << "** Insert Block into detached Region (nullptr parent op)'\n";
+ << "** Insert Block into detached Region (nullptr parent op)";
}
+ if (wasDetached)
+ logger.getOStream() << " (was detached)";
+ logger.getOStream() << "\n";
});
+ assert(!wasOpReplaced(newParentOp) &&
+ "attempting to insert into a region within a replaced/erased op");
+ (void)newParentOp;
patternInsertedBlocks.insert(block);
- if (!previous) {
- // This is a newly created block.
+ if (wasDetached) {
+ // If the block was detached, it is most likely a newly created block.
+ // TODO: If the same block is inserted multiple times from a detached state,
+ // the rollback mechanism may erase the same block multiple times. This is a
+ // bug in the rollback-based dialect conversion driver.
appendRewrite<CreateBlockRewrite>(block);
return;
}
- Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt;
- appendRewrite<MoveBlockRewrite>(block, previous, prevBlock);
+
+ // The block was moved from one place to another.
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1716,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> newVals =
llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
return v ? SmallVector<Value>{v} : SmallVector<Value>();
@@ -1731,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
impl->replaceOp(op, std::move(newValues));
}
@@ -1739,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
impl->replaceOp(op, std::move(nullRepls));
}
@@ -1845,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
moveOpBefore(&source->front(), dest, before);
}
+ // If the current insertion point is within the source block, adjust the
+ // insertion point to the destination block.
+ if (getInsertionBlock() == source)
+ setInsertionPoint(dest, getInsertionPoint());
+
// Erase the source block.
eraseBlock(source);
}
@@ -1976,6 +2043,7 @@ private:
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
ConversionPatternRewriter &rewriter,
+ const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks);
@@ -2092,8 +2160,9 @@ OperationLegalizer::legalize(Operation *op,
// If the operation has no regions, just print it here.
if (!isIgnored && op->getNumRegions() == 0) {
- op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
- logger.getOStream() << "\n\n";
+ logger.startLine() << OpWithFlags(op,
+ OpPrintingFlags().printGenericOpForm())
+ << "\n";
}
});
@@ -2172,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
});
- (void)rewriterImpl;
+
+ // Clear pattern state, so that the next pattern application starts with a
+ // clean slate. (The op/block sets are populated by listener notifications.)
+ auto cleanup = llvm::make_scope_exit([&]() {
+ rewriterImpl.patternNewOps.clear();
+ rewriterImpl.patternModifiedOps.clear();
+ rewriterImpl.patternInsertedBlocks.clear();
+ });
+
+ // Upon failure, undo all changes made by the folder.
+ RewriterState curState = rewriterImpl.getCurrentState();
// Try to fold the operation.
StringRef opName = op->getName().getStringRef();
SmallVector<Value, 2> replacementValues;
SmallVector<Operation *, 2> newOps;
rewriter.setInsertionPoint(op);
+ rewriter.startOpModification(op);
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+ rewriter.cancelOpModification(op);
return failure();
}
+ rewriter.finalizeOpModification(op);
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
if (failed(legalize(newOp, rewriter))) {
@@ -2201,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
"op '" + opName +
"' folder rollback of IR modifications requested");
}
- // Legalization failed: erase all materialized constants.
- for (Operation *op : newOps)
- rewriter.eraseOp(op);
+ rewriterImpl.resetState(
+ curState, std::string(op->getName().getStringRef()) + " folder");
return failure();
}
}
- // Insert a replacement for 'op' with the folded replacement values.
- rewriter.replaceOp(op, replacementValues);
-
LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
return success();
}
@@ -2220,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ Operation *checkOp;
+ std::optional<OperationFingerPrint> topLevelFingerPrint;
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // The op may be getting erased, so we have to check the parent op.
+ // (In rare cases, a pattern may even erase the parent op, which will cause
+ // a crash here. Expensive checks are "best effort".) Skip the check if the
+ // op does not have a parent op.
+ if ((checkOp = op->getParentOp())) {
+ if (!op->getContext()->isMultithreadingEnabled()) {
+ topLevelFingerPrint = OperationFingerPrint(checkOp);
+ } else {
+ // Another thread may be modifying a sibling operation. Therefore, the
+ // fingerprinting mechanism of the parent op works only in
+ // single-threaded mode.
+ LLVM_DEBUG({
+ rewriterImpl.logger.startLine()
+ << "WARNING: Multi-threadeding is enabled. Some dialect "
+ "conversion expensive checks are skipped in multithreading "
+ "mode!\n";
+ });
+ }
+ }
+ }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
bool canApply = canApplyPattern(op, pattern, rewriter);
@@ -2232,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Returning "failure" after modifying IR is not allowed.
+ if (checkOp) {
+ OperationFingerPrint fingerPrintAfterPattern(checkOp);
+ if (fingerPrintAfterPattern != *topLevelFingerPrint)
+ llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+ "' returned failure but IR did change");
+ }
+ }
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2260,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, newOps,
+ auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2303,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const SetVector<Operation *> &newOps,
+ const RewriterState &curState, const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
@@ -2319,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return hasRewrite<ModifyOperationRewrite>(newRewrites, op);
};
if (!replacedRoot() && !updatedRootInPlace())
- llvm::report_fatal_error("expected pattern to replace the root operation");
+ llvm::report_fatal_error(
+ "expected pattern to replace the root operation or modify it in place");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index b82d850..607b86c 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -22,6 +22,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index b639e87f..26c965c 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -21,7 +21,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "inlining"
@@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//
-#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
-#endif
/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
@@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
LLVM_DEBUG({
- llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+ LDBG() << "* Inliner: Initial calls in SCC are: {";
for (unsigned i = 0, e = calls.size(); i < e; ++i)
- llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
- llvm::dbgs() << "}\n";
+ LDBG() << " " << i << ". " << calls[i].call << ",";
+ LDBG() << "}";
});
// Try to inline each of the call operations. Don't cache the end iterator
@@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
CallOpInterface call = it.call;
LLVM_DEBUG({
if (doInline)
- llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Inlining call: " << i << ". " << call;
else
- llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Not inlining call: " << i << ". " << call;
});
if (!doInline)
continue;
@@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
- LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
+ LDBG() << "** Failed to inline";
continue;
}
inlinedAnyCalls = true;
@@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
auto historyToString = [](InlineHistoryT h) {
return h.has_value() ? std::to_string(*h) : "root";
};
- (void)historyToString;
- LLVM_DEBUG(llvm::dbgs()
- << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
- << getNodeName(call) << ", " << historyToString(inlineHistoryID)
- << "]\n");
+ LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
+ << getNodeName(call) << ", " << historyToString(inlineHistoryID)
+ << "]";
for (unsigned k = prevSize; k != calls.size(); ++k) {
callHistory.push_back(newInlineHistoryID);
- LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
- << "}\n with historyID = " << newInlineHistoryID
- << ", added due to inlining of\n call {" << call
- << "}\n with historyID = "
- << historyToString(inlineHistoryID) << "\n");
+ LDBG() << "* new call " << k << " {" << calls[k].call
+ << "}\n with historyID = " << newInlineHistoryID
+ << ", added due to inlining of\n call {" << call
+ << "}\n with historyID = " << historyToString(inlineHistoryID);
}
// If the inlining was successful, Merge the new uses into the source node.
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 353f8a5..a1d975d 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -417,7 +417,7 @@ static LogicalResult deleteDeadness(RewriterBase &rewriter,
for (Region &region : regions) {
if (region.empty())
continue;
- bool hasSingleBlock = llvm::hasSingleElement(region);
+ bool hasSingleBlock = region.hasOneBlock();
// Delete every operation that is not live. Graph regions may have cycles
// in the use-def graph, so we must explicitly dropAllUses() from each
@@ -850,7 +850,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
/// failure otherwise.
static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
Region &region) {
- if (region.empty() || llvm::hasSingleElement(region))
+ if (region.empty() || region.hasOneBlock())
return failure();
// Identify sets of blocks, other than the entry block, that branch to the