aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/CSE.cpp12
-rw-r--r--mlir/lib/Transforms/InlinerPass.cpp7
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp4
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp49
-rw-r--r--mlir/lib/Transforms/SROA.cpp4
-rw-r--r--mlir/lib/Transforms/SymbolDCE.cpp30
-rw-r--r--mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp9
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp926
-rw-r--r--mlir/lib/Transforms/Utils/FoldUtils.cpp7
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp42
-rw-r--r--mlir/lib/Transforms/Utils/InliningUtils.cpp13
-rw-r--r--mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp12
-rw-r--r--mlir/lib/Transforms/Utils/RegionUtils.cpp33
-rw-r--r--mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp126
14 files changed, 938 insertions, 336 deletions
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
index 09e5a02..8eaac30 100644
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -177,11 +177,10 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op,
bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp,
Operation *toOp) {
assert(fromOp->getBlock() == toOp->getBlock());
- assert(
- isa<MemoryEffectOpInterface>(fromOp) &&
- cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() &&
- isa<MemoryEffectOpInterface>(toOp) &&
- cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>());
+ assert(hasEffect<MemoryEffects::Read>(fromOp) &&
+ "expected read effect on fromOp");
+ assert(hasEffect<MemoryEffects::Read>(toOp) &&
+ "expected read effect on toOp");
Operation *nextOp = fromOp->getNextNode();
auto result =
memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr));
@@ -245,11 +244,10 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues,
// Some simple use case of operation with memory side-effect are dealt with
// here. Operations with no side-effect are done after.
if (!isMemoryEffectFree(op)) {
- auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
// TODO: Only basic use case for operations with MemoryEffects::Read can be
// eleminated now. More work needs to be done for more complicated patterns
// and other side-effects.
- if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>())
+ if (!hasSingleEffect<MemoryEffects::Read>(op))
return failure();
// Look for an existing definition for the operation.
diff --git a/mlir/lib/Transforms/InlinerPass.cpp b/mlir/lib/Transforms/InlinerPass.cpp
index 703e517..77a9e6c 100644
--- a/mlir/lib/Transforms/InlinerPass.cpp
+++ b/mlir/lib/Transforms/InlinerPass.cpp
@@ -18,6 +18,7 @@
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Inliner.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_INLINER
@@ -120,8 +121,8 @@ static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall,
return true;
unsigned ratio = countOps(calleeRegion) * 100 / callerOps;
- LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: "
- << inliningThreshold << "%): " << ratio << "%\n");
+ LDBG() << "Callee / caller operation ratio (max: " << inliningThreshold
+ << "%): " << ratio << "%";
return ratio <= inliningThreshold;
}
@@ -138,7 +139,7 @@ void InlinerPass::runOnOperation() {
}
// By default, assume that any inlining is profitable.
- auto profitabilityCb = [=](const Inliner::ResolvedCall &call) {
+ auto profitabilityCb = [this](const Inliner::ResolvedCall &call) {
return isProfitableToInline(call, inliningThreshold);
};
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index cf039c3..d36a3c1 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
namespace mlir {
@@ -632,8 +633,7 @@ MemorySlotPromoter::promoteSlot() {
}
}
- LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr
- << "\n");
+ LDBG() << "Promoted memory slot: " << slot.ptr;
if (statistics.promotedAmount)
(*statistics.promotedAmount)++;
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 4ccb83f..0e84b6d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
- LDBG()
- << "Simple op is not memory effect free or has live results, skipping: "
- << *op;
+ LDBG() << "Simple op is not memory effect free or has live results, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return;
}
LDBG()
<< "Simple op has all dead results and is memory effect free, scheduling "
"for removal: "
- << *op;
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
@@ -345,8 +344,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
// being returned, in order to optimize our IR. So, this demonstrates how we
// can make our optimization strong by even removing a live return value (%0),
// since it forwards only to non-live value(s) (%1#1).
- Operation *lastReturnOp = funcOp.back().getTerminator();
- size_t numReturns = lastReturnOp->getNumOperands();
+ size_t numReturns = funcOp.getNumResults();
BitVector nonLiveRets(numReturns, true);
for (SymbolTable::SymbolUse use : uses) {
Operation *callOp = use.getUser();
@@ -728,19 +726,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+ LDBG() << "Starting cleanup of dead values...";
+
// 1. Operations
+ LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (auto &op : list.operations) {
+ LDBG() << "Erasing operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
op->dropAllUses();
op->erase();
}
// 2. Values
+ LDBG() << "Cleaning up " << list.values.size() << " values";
for (auto &v : list.values) {
+ LDBG() << "Dropping all uses of value: " << v;
v.dropAllUses();
}
// 3. Functions
+ LDBG() << "Cleaning up " << list.functions.size() << " functions";
for (auto &f : list.functions) {
+ LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
+ LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
+ LDBG() << " Erasing " << f.nonLiveRets.count()
+ << " non-live return values";
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
@@ -749,44 +759,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
// 4. Operands
+ LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0)
+ if (o.op->getNumOperands() > 0) {
+ LDBG() << "Erasing " << o.nonLive.count()
+ << " non-live operands from operation: "
+ << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
o.op->eraseOperands(o.nonLive);
+ }
}
// 5. Results
+ LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
+ LDBG() << "Erasing " << r.nonLive.count()
+ << " non-live results from operation: "
+ << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
dropUsesAndEraseResults(r.op, r.nonLive);
}
// 6. Blocks
+ LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
for (auto &b : list.blocks) {
// blocks that are accessed via multiple codepaths processed once
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
+ LDBG() << "Erasing " << b.nonLiveArgs.count()
+ << " non-live arguments from block: " << b.b;
// it iterates backwards because erase invalidates all successor indexes
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
+ LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
}
// 7. Successor Operands
+ LDBG() << "Cleaning up " << list.successorOperands.size()
+ << " successor operand lists";
for (auto &op : list.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.successorIndex);
// blocks that are accessed via multiple codepaths processed once
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
+ LDBG() << "Erasing " << op.nonLiveOperands.count()
+ << " non-live successor operands from successor "
+ << op.successorIndex << " of branch: "
+ << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
// it iterates backwards because erase invalidates all successor indexes
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
+ LDBG() << " Erasing successor operand " << i << ": "
+ << successorOperands[i];
successorOperands.erase(i);
}
}
+
+ LDBG() << "Finished cleanup of dead values";
}
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 67f536a..859c030 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -12,6 +12,7 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_SROA
@@ -180,8 +181,7 @@ static void destructureSlot(
assert(slot.ptr.use_empty() && "after destructuring, the original slot "
"pointer should no longer be used");
- LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
- << "\n");
+ LDBG() << "Destructured memory slot: " << slot.ptr;
if (statistics.destructuredAmount)
(*statistics.destructuredAmount)++;
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 0a925c4..87885be 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -13,8 +13,11 @@
#include "mlir/Transforms/Passes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
namespace mlir {
#define GEN_PASS_DEF_SYMBOLDCE
@@ -87,8 +90,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
SymbolTableCollection &symbolTable,
bool symbolTableIsHidden,
DenseSet<Operation *> &liveSymbols) {
- LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
- << "\n");
+ LDBG() << "computeLiveness: "
+ << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions());
// A worklist of live operations to propagate uses from.
SmallVector<Operation *, 16> worklist;
@@ -116,7 +119,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// consideration.
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
- LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
+ LDBG() << "processing: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If this is a symbol table, recursively compute its liveness.
if (op->hasTrait<OpTrait::SymbolTable>()) {
@@ -124,13 +128,14 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
// symbol, or if it is a private symbol.
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
- LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
- << " is hidden: " << symIsHidden << "\n");
+ LDBG() << "\tsymbol table: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " is hidden: " << symIsHidden;
if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
return failure();
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "\tnon-symbol table: " << op->getName() << "\n");
+ LDBG() << "\tnon-symbol table: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the op is not a symbol table, then, unless op itself is dead which
// would be handled by DCE, we need to check all the regions and blocks
// within the op to find the uses (e.g., consider visibility within op as
@@ -160,20 +165,17 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
}
SmallVector<Operation *, 4> resolvedSymbols;
- LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
+ LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions());
for (const SymbolTable::SymbolUse &use : *uses) {
- LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
+ LDBG() << "\tuse: " << use.getUser();
// Lookup the symbols referenced by this use.
resolvedSymbols.clear();
if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
resolvedSymbols)))
// Ignore references to unknown symbols.
continue;
- LLVM_DEBUG({
- llvm::dbgs() << "\t\tresolved symbols: ";
- llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "\t\tresolved symbols: "
+ << llvm::interleaved(resolvedSymbols, ", ");
// Mark each of the resolved symbols as live.
for (Operation *resolvedSymbol : resolvedSymbols)
diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
index cfd568f..19cf464 100644
--- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
+++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp
@@ -21,7 +21,10 @@
#include "mlir/Transforms/ControlFlowSinkUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/Support/DebugLog.h"
#include <vector>
#define DEBUG_TYPE "cf-sink"
@@ -84,13 +87,15 @@ bool Sinker::allUsersDominatedBy(Operation *op, Region *region) {
void Sinker::tryToSinkPredecessors(Operation *user, Region *region,
std::vector<Operation *> &stack) {
- LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n"));
+ LDBG() << "Contained op: "
+ << OpWithFlags(user, OpPrintingFlags().skipRegions());
for (Value value : user->getOperands()) {
Operation *op = value.getDefiningOp();
// Ignore block arguments and ops that are already inside the region.
if (!op || op->getParentRegion() == region)
continue;
- LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n"));
+ LDBG() << "Try to sink:\n"
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
// If the op's users are all in the region and it can be moved, then do so.
if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c619..5ba109d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -121,17 +121,8 @@ struct ConversionValueMapping {
/// false positives.
bool isMappedTo(Value value) const { return mappedTo.contains(value); }
- /// Lookup the most recently mapped values with the desired types in the
- /// mapping.
- ///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
- ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+ /// Lookup a value in the mapping.
+ ValueVector lookup(const ValueVector &from) const;
template <typename T>
struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -185,54 +176,40 @@ private:
};
} // namespace
-ValueVector
-ConversionValueMapping::lookupOrDefault(Value from,
- TypeRange desiredTypes) const {
- // Try to find the deepest values that have the desired types. If there is no
- // such mapping, simply return the deepest values.
- ValueVector desiredValue;
- ValueVector current{from};
- do {
- // Store the current value if the types match.
- if (TypeRange(ValueRange(current)) == desiredTypes)
- desiredValue = current;
-
- // If possible, Replace each value with (one or multiple) mapped values.
- ValueVector next;
- for (Value v : current) {
- auto it = mapping.find({v});
- if (it != mapping.end()) {
- llvm::append_range(next, it->second);
- } else {
- next.push_back(v);
- }
- }
- if (next != current) {
- // If at least one value was replaced, continue the lookup from there.
- current = std::move(next);
- continue;
- }
+/// Marker attribute for pure type conversions. I.e., mappings whose only
+/// purpose is to resolve a type mismatch. (In contrast, mappings that point to
+/// the replacement values of a "replaceOp" call, etc., are not pure type
+/// conversions.)
+static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = values.front().getDefiningOp();
+ for (Value v : llvm::drop_begin(values)) {
+ if (v.getDefiningOp() != op)
+ return nullptr;
+ }
+ return op;
+}
- // Otherwise: Check if there is a mapping for the entire vector. Such
- // mappings are materializations. (N:M mapping are not supported for value
- // replacements.)
- //
- // Note: From a correctness point of view, materializations do not have to
- // be stored (and looked up) in the mapping. But for performance reasons,
- // we choose to reuse existing IR (when possible) instead of creating it
- // multiple times.
- auto it = mapping.find(current);
- if (it == mapping.end()) {
- // No mapping found: The lookup stops here.
- break;
- }
- current = it->second;
- } while (true);
+/// A vector of values is a pure type conversion if all values are defined by
+/// the same operation and the operation has the `kPureTypeConversionMarker`
+/// attribute.
+static bool isPureTypeConversion(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = getCommonDefiningOp(values);
+ return op && op->hasAttr(kPureTypeConversionMarker);
+}
- // If the desired values were found use them, otherwise default to the leaf
- // values.
- // Note: If `desiredTypes` is empty, this function always returns `current`.
- return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
+ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
+ auto it = mapping.find(from);
+ if (it == mapping.end()) {
+ // No mapping found: The lookup stops here.
+ return {};
+ }
+ return it->second;
}
//===----------------------------------------------------------------------===//
@@ -871,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
- : context(ctx), config(config) {}
+ : rewriter(rewriter), config(config),
+ notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -895,6 +873,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// failure.
template <typename RewriteTy, typename... Args>
void appendRewrite(Args &&...args) {
+ assert(config.allowPatternRollback && "appending rewrites is not allowed");
rewrites.push_back(
std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
}
@@ -909,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// is the tag used when describing a value within a diagnostic, e.g.
/// "operand".
LogicalResult remapValues(StringRef valueDiagTag,
- std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
@@ -921,16 +899,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasOpReplaced(Operation *op) const;
/// Lookup the most recently mapped values with the desired types in the
- /// mapping.
+ /// mapping, taking into account only replacements. Perform a best-effort
+ /// search for existing materializations with the desired types.
///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
- ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+ /// If `skipPureTypeConversions` is "true", materializations that are pure
+ /// type conversions are not considered.
+ ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {},
+ bool skipPureTypeConversions = false) const;
/// Lookup the given value within the map, or return an empty vector if the
/// value is not mapped. If it is mapped, this follows the same behavior
@@ -943,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -954,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
/// Replace the results of the given operation with the given values and
@@ -993,11 +966,18 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// If `valuesToMap` is set to a non-null Value, then that value is mapped to
/// the results of the unresolved materialization in the conversion value
/// mapping.
+ ///
+ /// If `isPureTypeConversion` is "true", the materialization is created only
+ /// to resolve a type mismatch. That means it is not a regular value
+ /// replacement issued by the user. (Replacement values that are created
+ /// "out of thin air" appear like unresolved materializations because they are
+ /// unrealized_conversion_cast ops. However, they must be treated like
+ /// regular value replacements.)
ValueRange buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp = nullptr);
+ bool isPureTypeConversion = true);
/// Find a replacement value for the given SSA value in the conversion value
/// mapping. The replacement value must have the same type as the given SSA
@@ -1078,14 +1058,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// MLIR context.
- MLIRContext *context;
+ /// The rewriter that is used to perform the conversion.
+ ConversionPatternRewriter &rewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
/// Ordered list of block operations (creations, splits, motions).
+ /// This vector is maintained only if `allowPatternRollback` is set to
+ /// "true". Otherwise, all IR rewrites are materialized immediately and no
+ /// bookkeeping is needed.
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
/// A set of operations that should no longer be considered for legalization.
@@ -1109,6 +1092,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
+ /// A list of unresolved materializations that were created by the current
+ /// pattern.
+ DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
/// A mapping for looking up metadata of unresolved materializations.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
@@ -1124,15 +1111,37 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// A set of erased operations. This set is utilized only if
+ /// `allowPatternRollback` is set to "false". Conceptually, this set is
+ /// similar to `replacedOps` (which is maintained when the flag is set to
+ /// "true"). However, erasing from a DenseSet is more efficient than erasing
+ /// from a SetVector.
+ DenseSet<Operation *> erasedOps;
+
+ /// A set of erased blocks. This set is utilized only if
+ /// `allowPatternRollback` is set to "false".
+ DenseSet<Block *> erasedBlocks;
+
+ /// A rewriter that notifies the listener (if any) about all IR
+ /// modifications. This rewriter is utilized only if `allowPatternRollback`
+ /// is set to "false". If the flag is set to "true", the listener is notified
+ /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+ IRRewriter notifyingRewriter;
+
#ifndef NDEBUG
+ /// A set of replaced block arguments. This set is for debugging purposes
+ /// only and it is maintained only if `allowPatternRollback` is set to
+ /// "true".
+ DenseSet<BlockArgument> replacedArgs;
+
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
/// verification.
SmallPtrSet<Operation *, 1> pendingRootUpdates;
/// A raw output stream used to prefix the debug log.
- llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(),
- llvm::dbgs(), /*HasPendingNewline=*/false};
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit diagnostics during the conversion process.
llvm::ScopedPrinter logger{os};
@@ -1160,11 +1169,8 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
- if (!repl)
- return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
rewriter.replaceAllUsesWith(arg, repl);
return;
@@ -1181,6 +1187,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ if (!repl)
+ return;
+ performReplaceBlockArg(rewriter, arg, repl);
+}
+
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1243,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Commit all rewrites.
- IRRewriter rewriter(context, config.listener);
+ // Commit all rewrites. Use a new rewriter, so the modifications are not
+ // tracked for rollback purposes etc.
+ IRRewriter irRewriter(rewriter.getContext(), config.listener);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ rewrites[i]->commit(irRewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
- context, /*opErasedCallback=*/[&](Operation *op) {
+ rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
@@ -1264,10 +1278,101 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// State Management
//===----------------------------------------------------------------------===//
-ValueVector
-ConversionPatternRewriterImpl::lookupOrDefault(Value from,
- TypeRange desiredTypes) const {
- return mapping.lookupOrDefault(from, desiredTypes);
+ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
+ Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+ // Helper function that looks up a single value.
+ auto lookup = [&](const ValueVector &values) -> ValueVector {
+ assert(!values.empty() && "expected non-empty value vector");
+
+ // If the pattern rollback is enabled, use the mapping to look up the
+ // values.
+ if (config.allowPatternRollback)
+ return mapping.lookup(values);
+
+ // Otherwise, look up values by examining the IR. All replacements have
+ // already been materialized in IR.
+ Operation *op = getCommonDefiningOp(values);
+ if (!op)
+ return {};
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+ if (!castOp)
+ return {};
+ if (!this->unresolvedMaterializations.contains(castOp))
+ return {};
+ if (castOp.getOutputs() != values)
+ return {};
+ return castOp.getInputs();
+ };
+
+ // Helper function that looks up each value in `values` individually and then
+ // composes the results. If that fails, it tries to look up the entire vector
+ // at once.
+ auto composedLookup = [&](const ValueVector &values) -> ValueVector {
+ // If possible, replace each value with (one or multiple) mapped values.
+ ValueVector next;
+ for (Value v : values) {
+ ValueVector r = lookup({v});
+ if (!r.empty()) {
+ llvm::append_range(next, r);
+ } else {
+ next.push_back(v);
+ }
+ }
+ if (next != values) {
+ // At least one value was replaced.
+ return next;
+ }
+
+ // Otherwise: Check if there is a mapping for the entire vector. Such
+ // mappings are materializations. (N:M mapping are not supported for value
+ // replacements.)
+ //
+ // Note: From a correctness point of view, materializations do not have to
+ // be stored (and looked up) in the mapping. But for performance reasons,
+ // we choose to reuse existing IR (when possible) instead of creating it
+ // multiple times.
+ ValueVector r = lookup(values);
+ if (r.empty()) {
+ // No mapping found: The lookup stops here.
+ return {};
+ }
+ return r;
+ };
+
+ // Try to find the deepest values that have the desired types. If there is no
+ // such mapping, simply return the deepest values.
+ ValueVector desiredValue;
+ ValueVector current{from};
+ ValueVector lastNonMaterialization{from};
+ do {
+ // Store the current value if the types match.
+ bool match = TypeRange(ValueRange(current)) == desiredTypes;
+ if (skipPureTypeConversions) {
+ // Skip pure type conversions, if requested.
+ bool pureConversion = isPureTypeConversion(current);
+ match &= !pureConversion;
+ // Keep track of the last mapped value that was not a pure type
+ // conversion.
+ if (!pureConversion)
+ lastNonMaterialization = current;
+ }
+ if (match)
+ desiredValue = current;
+
+ // Lookup next value in the mapping.
+ ValueVector next = composedLookup(current);
+ if (next.empty())
+ break;
+ current = std::move(next);
+ } while (true);
+
+ // If the desired values were found use them, otherwise default to the leaf
+ // values. (Skip pure type conversions, if requested.)
+ if (!desiredTypes.empty())
+ return desiredValue;
+ if (skipPureTypeConversions)
+ return lastNonMaterialization;
+ return current;
}
ValueVector
@@ -1300,21 +1405,13 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state,
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
StringRef patternName) {
for (auto &rewrite :
- llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
- if (!config.allowPatternRollback &&
- !isa<UnresolvedMaterializationRewrite>(rewrite)) {
- // Unresolved materializations can always be rolled back (erased).
- llvm::report_fatal_error("pattern '" + patternName +
- "' rollback of IR modifications requested");
- }
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
rewrite->rollback();
- }
rewrites.resize(numRewritesToKeep);
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- StringRef valueDiagTag, std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
@@ -1324,16 +1421,19 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
if (!currentTypeConverter) {
- // The current pattern does not have a type converter. I.e., it does not
- // distinguish between legal and illegal types. For each operand, simply
- // pass through the most recently mapped values.
- remapped.push_back(lookupOrDefault(operand));
+ // The current pattern does not have a type converter. Pass the most
+ // recently mapped values, excluding materializations. Materializations
+ // are intentionally excluded because their presence may depend on other
+ // patterns. Including materializations would make the lookup fragile
+ // and unpredictable.
+ remapped.push_back(lookupOrDefault(operand, /*desiredTypes=*/{},
+ /*skipPureTypeConversions=*/true));
continue;
}
// If there is no legal conversion, fail to match this pattern.
SmallVector<Type, 1> legalTypes;
- if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
+ if (failed(currentTypeConverter->convertType(operand, legalTypes))) {
notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
@@ -1356,7 +1456,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
}
// Create a materialization for the most recently mapped values.
- repl = lookupOrDefault(operand);
+ repl = lookupOrDefault(operand, /*desiredTypes=*/{},
+ /*skipPureTypeConversions=*/true);
ValueRange castValues = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
/*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1368,12 +1469,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation is ignored or was replaced.
- return replacedOps.count(op) || ignoredOps.count(op);
+ return wasOpReplaced(op) || ignoredOps.count(op);
}
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Check to see if this operation was replaced.
- return replacedOps.count(op);
+ return replacedOps.count(op) || erasedOps.count(op);
}
//===----------------------------------------------------------------------===//
@@ -1381,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1397,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (!conversion)
return failure();
// Convert the block with the computed signature.
- applySignatureConversion(rewriter, &block, &converter, *conversion);
+ applySignatureConversion(&block, &converter, *conversion);
}
// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
- return applySignatureConversion(rewriter, &region->front(), &converter,
+ return applySignatureConversion(&region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
- return applySignatureConversion(rewriter, &region->front(), &converter,
- *conversion);
+ return applySignatureConversion(&region->front(), &converter, *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// A block cannot be converted multiple times.
@@ -1457,7 +1555,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// a bit more efficient, so we try to do that when possible.
bool fastPath = !config.listener;
if (fastPath) {
- appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ if (config.allowPatternRollback)
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
newBlock->getOperations().splice(newBlock->end(), block->getOperations());
} else {
while (!block->empty())
@@ -1482,7 +1581,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
OpBuilder::InsertPoint(newBlock, newBlock->begin()),
origArg.getLoc(),
/*valuesToMap=*/{}, /*inputs=*/ValueRange(),
- /*outputTypes=*/origArgType, /*originalType=*/Type(), converter)
+ /*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
+ /*isPureTypeConversion=*/false)
.front();
replaceUsesOfBlockArgument(origArg, mat, converter);
continue;
@@ -1504,7 +1604,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
- appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+ if (config.allowPatternRollback)
+ appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1523,7 +1624,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
Type originalType, const TypeConverter *converter,
- UnrealizedConversionCastOp *castOp) {
+ bool isPureTypeConversion) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
assert(TypeRange(inputs) != outputTypes &&
@@ -1533,21 +1634,35 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// tracking the materialization like we do for other operations.
OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
- auto convertOp =
+ UnrealizedConversionCastOp convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
- if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
- if (castOp)
- *castOp = convertOp;
+ if (config.attachDebugMaterializationKind) {
+ StringRef kindStr =
+ kind == MaterializationKind::Source ? "source" : "target";
+ convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
+ }
+ if (isPureTypeConversion)
+ convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
+
+ // Register the materialization.
unresolvedMaterializations[convertOp] =
UnresolvedMaterializationInfo(converter, kind, originalType);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
- std::move(valuesToMap));
+ if (config.allowPatternRollback) {
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
+ } else {
+ patternMaterializations.insert(convertOp);
+ }
return convertOp.getResults();
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
+ assert(config.allowPatternRollback &&
+ "this code path is valid only in rollback mode");
+
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
@@ -1609,26 +1724,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(op->getParentOp()) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
"attempting to insert into a block within a replaced/erased op");
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyOperationInserted(op, previous);
+
if (wasDetached) {
- // If the op was detached, it is most likely a newly created op.
- // TODO: If the same op is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same op multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateOperationRewrite>(op);
+ // If the op was detached, it is most likely a newly created op. Add it the
+ // set of newly created ops, so that it will be legalized. If this op is
+ // not a newly created op, it will be legalized a second time, which is
+ // inefficient but harmless.
patternNewOps.insert(op);
+
+ if (config.allowPatternRollback) {
+ // TODO: If the same op is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same op multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateOperationRewrite>(op);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased operations that must be kept up to date.
+ erasedOps.erase(op);
+ }
return;
}
// The op was moved from one place to another.
- appendRewrite<MoveOperationRewrite>(op, previous);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveOperationRewrite>(op, previous);
+}
+
+/// Given that `fromRange` is about to be replaced with `toRange`, compute
+/// replacement values with the types of `fromRange`.
+static SmallVector<Value>
+getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
+ const SmallVector<SmallVector<Value>> &toRange,
+ const TypeConverter *converter) {
+ assert(!impl.config.allowPatternRollback &&
+ "this code path is valid only in 'no rollback' mode");
+ SmallVector<Value> repls;
+ for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+ if (from.use_empty()) {
+ // The replaced value is dead. No replacement value is needed.
+ repls.push_back(Value());
+ continue;
+ }
+
+ if (to.empty()) {
+ // The replaced value is dropped. Materialize a replacement value "out of
+ // thin air".
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+ /*outputTypes=*/from.getType(), /*originalType=*/Type(),
+ converter)[0];
+ repls.push_back(srcMat);
+ continue;
+ }
+
+ if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
+ // The replacement value already has the correct type. Use it directly.
+ repls.push_back(to[0]);
+ continue;
+ }
+
+ // The replacement value has the wrong type. Build a source materialization
+ // to the original type.
+ // TODO: This is a bit inefficient. We should try to reuse existing
+ // materializations if possible. This would require an extension of the
+ // `lookupOrDefault` API.
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+ /*originalType=*/Type(), converter)[0];
+ repls.push_back(srcMat);
+ }
+
+ return repls;
}
void ConversionPatternRewriterImpl::replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
- assert(newValues.size() == op->getNumResults());
+ assert(newValues.size() == op->getNumResults() &&
+ "incorrect number of replacement values");
+
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ SmallVector<Value> repls = getReplacementValues(
+ *this, op->getResults(), newValues, currentTypeConverter);
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ op->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ op->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Replace the op with the replacement values and notify the listener.
+ notifyingRewriter.replaceOp(op, repls);
+ return;
+ }
+
assert(!ignoredOps.contains(op) && "operation was already replaced");
// Check if replaced op is an unresolved materialization, i.e., an
@@ -1650,7 +1858,7 @@ void ConversionPatternRewriterImpl::replaceOp(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
/*outputTypes=*/result.getType(), /*originalType=*/Type(),
- currentTypeConverter);
+ currentTypeConverter, /*isPureTypeConversion=*/false);
continue;
}
@@ -1667,11 +1875,59 @@ void ConversionPatternRewriterImpl::replaceOp(
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
BlockArgument from, ValueRange to, const TypeConverter *converter) {
+ if (!config.allowPatternRollback) {
+ SmallVector<Value> toConv = llvm::to_vector(to);
+ SmallVector<Value> repls =
+ getReplacementValues(*this, from, {toConv}, converter);
+ IRRewriter r(from.getContext());
+ Value repl = repls.front();
+ if (!repl)
+ return;
+
+ performReplaceBlockArg(r, from, repl);
+ return;
+ }
+
+#ifndef NDEBUG
+ // Make sure that a block argument is not replaced multiple times. In
+ // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current
+ // uses of the given block argument, but also all future uses that may be
+ // introduced by future pattern applications. Therefore, it does not make
+ // sense to call `replaceUsesOfBlockArgument` multiple times with the same
+ // block argument. Doing so would overwrite the mapping and mess with the
+ // internal state of the dialect conversion driver.
+ assert(!replacedArgs.contains(from) &&
+ "attempting to replace a block argument that was already replaced");
+ replacedArgs.insert(from);
+#endif // NDEBUG
+
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ block->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ block->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Erase the block and notify the listener.
+ notifyingRewriter.eraseBlock(block);
+ return;
+ }
+
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
@@ -1705,23 +1961,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(newParentOp) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
"attempting to insert into a region within a replaced/erased op");
(void)newParentOp;
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyBlockInserted(block, previous, previousIt);
+
patternInsertedBlocks.insert(block);
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
- // TODO: If the same block is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same block multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateBlockRewrite>(block);
+ if (config.allowPatternRollback) {
+ // TODO: If the same block is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same block multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateBlockRewrite>(block);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased blocks that must be kept up to date.
+ erasedBlocks.erase(block);
+ }
return;
}
// The block was moved from one place to another.
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1748,12 +2018,16 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
+const ConversionConfig &ConversionPatternRewriter::getConfig() const {
+ return impl->config;
+}
+
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
@@ -1821,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, block, converter, conversion);
+ return impl->applySignatureConversion(block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1830,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(*this, region, converter, entryConversion);
+ return impl->convertRegionTypes(region, converter, entryConversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -1849,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<ValueVector> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
remappedValues)))
return nullptr;
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -1862,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
if (keys.empty())
return success();
SmallVector<ValueVector> remapped;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
remapped)))
return failure();
for (const auto &values : remapped) {
@@ -1895,9 +2169,9 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// ops should be moved one-by-one ("slow path"), so that a separate
// `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
// a bit more efficient, so we try to do that when possible.
- bool fastPath = !impl->config.listener;
+ bool fastPath = !getConfig().listener;
- if (fastPath)
+ if (fastPath && impl->config.allowPatternRollback)
impl->inlineBlockBefore(source, dest, before);
// Replace all uses of block arguments.
@@ -1923,6 +2197,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}
void ConversionPatternRewriter::startOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ // Pattern rollback is not allowed: no extra bookkeeping is needed.
+ PatternRewriter::startOpModification(op);
+ return;
+ }
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
#ifndef NDEBUG
@@ -1932,20 +2211,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
- assert(!impl->wasOpReplaced(op) &&
- "attempting to modify a replaced/erased op");
- PatternRewriter::finalizeOpModification(op);
impl->patternModifiedOps.insert(op);
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::finalizeOpModification(op);
+ if (getConfig().listener)
+ getConfig().listener->notifyOperationModified(op);
+ return;
+ }
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
+ assert(!impl->wasOpReplaced(op) &&
+ "attempting to modify a replaced/erased op");
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
#endif
}
void ConversionPatternRewriter::cancelOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::cancelOpModification(op);
+ return;
+ }
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
@@ -1970,17 +2258,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
-SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
ArrayRef<ValueRange> operands) const {
SmallVector<Value> oneToOneOperands;
oneToOneOperands.reserve(operands.size());
for (ValueRange operand : operands) {
if (operand.size() != 1)
- llvm::report_fatal_error("pattern '" + getDebugName() +
- "' does not support 1:N conversion");
+ return failure();
+
oneToOneOperands.push_back(operand.front());
}
- return oneToOneOperands;
+ return std::move(oneToOneOperands);
}
LogicalResult
@@ -1995,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<ValueVector> remapped;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
op->getOperands(), remapped))) {
return failure();
}
@@ -2017,38 +2305,34 @@ class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
- OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns,
- const ConversionConfig &config);
+ OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
+ const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+ LogicalResult legalize(Operation *op);
/// Returns the conversion target in use by the legalizer.
const ConversionTarget &getTarget() { return target; }
private:
/// Attempt to legalize the given operation by folding it.
- LogicalResult legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithFold(Operation *op);
/// Attempt to legalize the given operation by applying a pattern. Returns
/// success if the operation was legalized, failure otherwise.
- LogicalResult legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithPattern(Operation *op);
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
- bool canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter);
+ bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
@@ -2057,18 +2341,12 @@ private:
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &newOps);
+ legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &modifiedOps);
+ legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2111,21 +2389,21 @@ private:
/// The current set of patterns that have been applied.
SmallPtrSet<const Pattern *, 8> appliedPatterns;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter &rewriter;
+
/// The legalization information provided by the target.
const ConversionTarget &target;
/// The pattern applicator to use for conversions.
PatternApplicator applicator;
-
- /// Dialect conversion configuration.
- const ConversionConfig &config;
};
} // namespace
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns,
- const ConversionConfig &config)
- : target(targetInfo), applicator(patterns), config(config) {
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
+ const FrozenRewritePatternSet &patterns)
+ : rewriter(rewriter), target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2139,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
return target.isIllegal(op);
}
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -2203,19 +2479,21 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
- // If the operation isn't legal, try to fold it in-place.
- // TODO: Should we always try to do this, even if the op is
- // already legal?
- if (succeeded(legalizeWithFold(op, rewriter))) {
- LLVM_DEBUG({
- logSuccess(logger, "operation was folded");
- logger.startLine() << logLineComment;
- });
- return success();
+ // If the operation is not legal, try to fold it in-place if the folding mode
+ // is 'BeforePatterns'. 'Never' will skip this.
+ const ConversionConfig &config = rewriter.getConfig();
+ if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
+ if (succeeded(legalizeWithFold(op))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
}
// Otherwise, we need to apply a legalization pattern to this operation.
- if (succeeded(legalizeWithPattern(op, rewriter))) {
+ if (succeeded(legalizeWithPattern(op))) {
LLVM_DEBUG({
logSuccess(logger, "");
logger.startLine() << logLineComment;
@@ -2223,6 +2501,18 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation can't be legalized via patterns, try to fold it in-place
+ // if the folding mode is 'AfterPatterns'.
+ if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
+ if (succeeded(legalizeWithFold(op))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
+ }
+
LLVM_DEBUG({
logFailure(logger, "no matched legalization pattern");
logger.startLine() << logLineComment;
@@ -2239,9 +2529,7 @@ static T moveAndReset(T &obj) {
return result;
}
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2275,18 +2563,18 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
- return legalize(op, rewriter);
+ return legalize(op);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
- if (failed(legalize(newOp, rewriter))) {
+ if (failed(legalize(newOp))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
- if (!config.allowPatternRollback) {
+ if (!rewriter.getConfig().allowPatternRollback) {
// Rolling back a folder is like rolling back a pattern.
llvm::report_fatal_error(
"op '" + opName +
@@ -2302,10 +2590,34 @@ OperationLegalizer::legalizeWithFold(Operation *op,
return success();
}
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter) {
+/// Report a fatal error indicating that newly produced or modified IR could
+/// not be legalized.
+static void
+reportNewIrLegalizationFatalError(const Pattern &pattern,
+ const SetVector<Operation *> &newOps,
+ const SetVector<Operation *> &modifiedOps,
+ const SetVector<Block *> &insertedBlocks) {
+ auto newOpNames = llvm::map_range(
+ newOps, [](Operation *op) { return op->getName().getStringRef(); });
+ auto modifiedOpNames = llvm::map_range(
+ modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
+ StringRef detachedBlockStr = "(detached block)";
+ auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
+ if (block->getParentOp())
+ return block->getParentOp()->getName().getStringRef();
+ return detachedBlockStr;
+ });
+ llvm::report_fatal_error(
+ "pattern '" + pattern.getDebugName() +
+ "' produced IR that could not be legalized. " + "new ops: {" +
+ llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
+ llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
+ llvm::join(insertedBlockNames, ", ") + "}");
+}
+
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
+ const ConversionConfig &config = rewriter.getConfig();
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
Operation *checkOp;
@@ -2335,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- bool canApply = canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
@@ -2345,17 +2657,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (!rewriterImpl.config.allowPatternRollback) {
- // Returning "failure" after modifying IR is not allowed.
+ // Erase all unresolved materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ rewriterImpl.patternMaterializations.clear();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // Expensive pattern check that can detect API violations.
if (checkOp) {
OperationFingerPrint fingerPrintAfterPattern(checkOp);
if (fingerPrintAfterPattern != *topLevelFingerPrint)
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
"' returned failure but IR did change");
}
- }
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ }
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2379,18 +2697,28 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Eagerly erase unused materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ if (op->use_empty()) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ }
+ rewriterImpl.patternMaterializations.clear();
+ }
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+ auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
- llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
- "' produced IR that could not be legalized");
+ reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
+ insertedBlocks);
rewriterImpl.resetState(curState, pattern.getDebugName());
}
if (config.listener)
@@ -2403,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+ const Pattern &pattern) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
@@ -2426,11 +2754,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
}
LogicalResult OperationLegalizer::legalizePatternResult(
- Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const RewriterState &curState, const SetVector<Operation *> &newOps,
+ Operation *op, const Pattern &pattern, const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
- auto &impl = rewriter.getImpl();
+ [[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -2448,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
- newOps)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+ if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+ failed(legalizePatternRootUpdates(modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2460,15 +2787,17 @@ LogicalResult OperationLegalizer::legalizePatternResult(
}
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Block *> &insertedBlocks,
+ Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
+ ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
for (Block *block : insertedBlocks) {
+ if (impl.erasedBlocks.contains(block))
+ continue;
+
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
@@ -2484,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
"block"));
return failure();
}
- impl.applySignatureConversion(rewriter, block, converter, *conversion);
+ impl.applySignatureConversion(block, converter, *conversion);
continue;
}
@@ -2493,7 +2822,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// operation, and blocks in regions created by this pattern will already be
// legalized later on.
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp, rewriter))) {
+ if (failed(legalize(parentOp))) {
LLVM_DEBUG(logFailure(
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
parentOp->getName(), parentOp));
@@ -2505,11 +2834,10 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
}
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(impl.logger,
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
"failed to legalize generated operation '{0}'({1})",
op->getName(), op));
return failure();
@@ -2519,13 +2847,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
}
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
const SetVector<Operation *> &modifiedOps) {
for (Operation *op : modifiedOps) {
- if (failed(legalize(op, rewriter))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "failed to legalize operation updated in-place '{0}'",
- op->getName()));
+ if (failed(legalize(op))) {
+ LLVM_DEBUG(
+ logFailure(rewriter.getImpl().logger,
+ "failed to legalize operation updated in-place '{0}'",
+ op->getName()));
return failure();
}
}
@@ -2745,11 +3073,11 @@ namespace mlir {
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
struct OperationConverter {
- explicit OperationConverter(const ConversionTarget &target,
+ explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : config(config), opLegalizer(target, patterns, this->config),
+ : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
mode(mode) {}
/// Converts the given operations to the conversion target.
@@ -2757,10 +3085,10 @@ struct OperationConverter {
private:
/// Converts an operation with the given rewriter.
- LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+ LogicalResult convert(Operation *op);
- /// Dialect conversion configuration.
- ConversionConfig config;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter rewriter;
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
@@ -2770,10 +3098,11 @@ private:
};
} // namespace mlir
-LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
- Operation *op) {
+LogicalResult OperationConverter::convert(Operation *op) {
+ const ConversionConfig &config = rewriter.getConfig();
+
// Legalize the given operation.
- if (failed(opLegalizer.legalize(op, rewriter))) {
+ if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
if (mode == OpConversionMode::Full)
@@ -2849,7 +3178,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
}
LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
- assert(!ops.empty() && "expected at least one operation");
const ConversionTarget &target = opLegalizer.getTarget();
// Compute the set of operations and blocks to convert.
@@ -2868,11 +3196,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
for (auto *op : toConvert) {
- if (failed(convert(rewriter, op))) {
+ if (failed(convert(op))) {
// Dialect conversion failed.
if (rewriterImpl.config.allowPatternRollback) {
// Rollback is allowed: restore the original IR.
@@ -2902,14 +3229,21 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+ // Drop markers.
+ for (UnrealizedConversionCastOp castOp : remainingCastOps)
+ castOp->removeAttr(kPureTypeConversionMarker);
+
// Try to legalize all unresolved materializations.
- if (config.buildMaterializations) {
- IRRewriter rewriter(rewriterImpl.context, config.listener);
+ if (rewriter.getConfig().buildMaterializations) {
+ // Use a new rewriter, so the modifications are not tracked for rollback
+ // purposes etc.
+ IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
+ rewriter.getConfig().listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
- if (failed(
- legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
+ if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
+ it->second)))
return failure();
}
}
@@ -3076,6 +3410,27 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
+LogicalResult TypeConverter::convertType(Value v,
+ SmallVectorImpl<Type> &results) const {
+ assert(v && "expected non-null value");
+
+ // If this type converter does not have context-aware type conversions, call
+ // the type-based overload, which has caching.
+ if (!hasContextAwareTypeConversions)
+ return convertType(v.getType(), results);
+
+ // Walk the added converters in reverse order to apply the most recently
+ // registered first.
+ for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) {
+ if (std::optional<LogicalResult> result = converter(v, results)) {
+ if (!succeeded(*result))
+ return failure();
+ return success();
+ }
+ }
+ return failure();
+}
+
Type TypeConverter::convertType(Type t) const {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -3086,6 +3441,16 @@ Type TypeConverter::convertType(Type t) const {
return results.size() == 1 ? results.front() : nullptr;
}
+Type TypeConverter::convertType(Value v) const {
+ // Use the multi-type result version to convert the type.
+ SmallVector<Type, 1> results;
+ if (failed(convertType(v, results)))
+ return nullptr;
+
+ // Check to ensure that only one type was produced.
+ return results.size() == 1 ? results.front() : nullptr;
+}
+
LogicalResult
TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const {
@@ -3095,21 +3460,38 @@ TypeConverter::convertTypes(TypeRange types,
return success();
}
+LogicalResult
+TypeConverter::convertTypes(ValueRange values,
+ SmallVectorImpl<Type> &results) const {
+ for (Value value : values)
+ if (failed(convertType(value, results)))
+ return failure();
+ return success();
+}
+
bool TypeConverter::isLegal(Type type) const {
return convertType(type) == type;
}
+
+bool TypeConverter::isLegal(Value value) const {
+ return convertType(value) == value.getType();
+}
+
bool TypeConverter::isLegal(Operation *op) const {
- return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
+ return isLegal(op->getOperands()) && isLegal(op->getResults());
}
bool TypeConverter::isLegal(Region *region) const {
- return llvm::all_of(*region, [this](Block &block) {
- return isLegal(block.getArgumentTypes());
- });
+ return llvm::all_of(
+ *region, [this](Block &block) { return isLegal(block.getArguments()); });
}
bool TypeConverter::isSignatureLegal(FunctionType ty) const {
- return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
+ if (!isLegal(ty.getInputs()))
+ return false;
+ if (!isLegal(ty.getResults()))
+ return false;
+ return true;
}
LogicalResult
@@ -3137,6 +3519,31 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return failure();
return success();
}
+LogicalResult
+TypeConverter::convertSignatureArg(unsigned inputNo, Value value,
+ SignatureConversion &result) const {
+ // Try to convert the given input type.
+ SmallVector<Type, 1> convertedTypes;
+ if (failed(convertType(value, convertedTypes)))
+ return failure();
+
+ // If this argument is being dropped, there is nothing left to do.
+ if (convertedTypes.empty())
+ return success();
+
+ // Otherwise, add the new inputs.
+ result.addInputs(inputNo, convertedTypes);
+ return success();
+}
+LogicalResult
+TypeConverter::convertSignatureArgs(ValueRange values,
+ SignatureConversion &result,
+ unsigned origInputOffset) const {
+ for (unsigned i = 0, e = values.size(); i != e; ++i)
+ if (failed(convertSignatureArg(origInputOffset + i, values[i], result)))
+ return failure();
+ return success();
+}
Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
Location loc, Type resultType,
@@ -3180,7 +3587,7 @@ SmallVector<Value> TypeConverter::materializeTargetConversion(
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
- if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion)))
+ if (failed(convertSignatureArgs(block->getArguments(), conversion)))
return std::nullopt;
return conversion;
}
@@ -3305,7 +3712,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands,
newOp.addOperands(operands);
SmallVector<Type> newResultTypes;
- if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes)))
+ if (failed(converter.convertTypes(op->getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(loc, "couldn't convert return types");
newOp.addTypes(newResultTypes);
@@ -3578,7 +3985,8 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
ctx->executeAction<ApplyConversionAction>(
[&] {
- OperationConverter opConverter(target, patterns, config, mode);
+ OperationConverter opConverter(ops.front()->getContext(), target,
+ patterns, config, mode);
status = opConverter.convertOperations(ops);
},
irUnits);
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index e9adda0..5e07509 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -154,11 +154,14 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
// already at the front of the block, or the previous operation is already a
// constant we unique'd (i.e. one we inserted), then we don't need to do
// anything. Otherwise, we move the constant to the insertion block.
+ // The location info is erased if the constant is moved to a different block.
Block *insertBlock = &insertRegion->front();
- if (opBlock != insertBlock || (&insertBlock->front() != op &&
- !isFolderOwnedConstant(op->getPrevNode()))) {
+ if (opBlock != insertBlock) {
op->moveBefore(&insertBlock->front());
op->setLoc(erasedFoldedLocation);
+ } else if (&insertBlock->front() != op &&
+ !isFolderOwnedConstant(op->getPrevNode())) {
+ op->moveBefore(&insertBlock->front());
}
folderConstOp = op;
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 607b86c..0324588 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -15,6 +15,8 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
@@ -23,7 +25,7 @@
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
@@ -178,9 +180,8 @@ static Operation *getDumpRootOp(Operation *op) {
return op;
}
static void logSuccessfulFolding(Operation *op) {
- llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
- op->dump();
- llvm::dbgs() << "\n\n";
+ LDBG() << "// *** IR Dump After Successful Folding ***\n"
+ << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs());
}
#endif // NDEBUG
@@ -394,8 +395,12 @@ private:
function_ref<void(Diagnostic &)> reasonCallback) override;
#ifndef NDEBUG
+ /// A raw output stream used to prefix the debug log.
+
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
+ llvm::ScopedPrinter logger{os};
#endif
/// The low-level pattern applicator.
@@ -871,7 +876,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
ctx->executeAction<GreedyPatternRewriteIteration>(
[&] {
- continueRewrites = processWorklist();
+ continueRewrites = false;
+
+ // Erase unreachable blocks
+ // Operations like:
+ // %add = arith.addi %add, %add : i64
+ // are legal in unreachable code. Unfortunately many patterns would be
+ // unsafe to apply on such IR and can lead to crashes or infinite
+ // loops.
+ continueRewrites |=
+ succeeded(eraseUnreachableBlocks(rewriter, region));
+
+ continueRewrites |= processWorklist();
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
@@ -917,10 +933,9 @@ mlir::applyPatternsGreedily(Region &region,
RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
region);
LogicalResult converged = std::move(driver).simplify(changed);
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after scanning "
- << config.getMaxIterations() << " times\n";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after scanning "
+ << config.getMaxIterations() << " times";
return converged;
}
@@ -1052,9 +1067,8 @@ LogicalResult mlir::applyOpPatternsGreedily(
LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
*allErased = surviving.empty();
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after "
- << config.getMaxNumRewrites() << " rewrites";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after "
+ << config.getMaxNumRewrites() << " rewrites";
return converged;
}
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index eeb4052..73107cf 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -13,10 +13,12 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
@@ -182,13 +184,16 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
IRMapping &valueMapping) {
for (auto &block : *src) {
for (auto &op : block) {
+ // UnrealizedConversionCastOp is inlineable but cannot implement the
+ // inliner interface due to layering constraints.
+ if (isa<UnrealizedConversionCastOp>(op))
+ continue;
+
// Check this operation.
if (!interface.isLegalToInline(&op, insertRegion,
shouldCloneInlinedRegion, valueMapping)) {
- LLVM_DEBUG({
- llvm::dbgs() << "* Illegal to inline because of op: ";
- op.dump();
- });
+ LDBG() << "* Illegal to inline because of op: "
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
return false;
}
// Check any nested regions.
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index cb3f2c5..111f58e 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -13,11 +13,13 @@
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <queue>
#define DEBUG_TYPE "licm"
@@ -64,8 +66,7 @@ size_t mlir::moveLoopInvariantCode(
size_t numMoved = 0;
for (Region *region : regions) {
- LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
- << *region->getParentOp() << "\n");
+ LDBG() << "Original loop:\n" << *region->getParentOp();
std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
@@ -83,12 +84,13 @@ size_t mlir::moveLoopInvariantCode(
if (op->getParentRegion() != region)
continue;
- LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
+ LDBG() << "Checking op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
if (!shouldMoveOutOfRegion(op, region) ||
!canBeHoisted(op, definedOutside))
continue;
- LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
+ LDBG() << "Moving loop-invariant op: " << *op;
moveOutOfRegion(op, region);
++numMoved;
@@ -322,7 +324,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
- auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
+ BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
MatchingSubsets subsets;
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index a1d975d..31ae1d1 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -23,12 +23,15 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <deque>
#include <iterator>
using namespace mlir;
+#define DEBUG_TYPE "region-utils"
+
void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
Region &region) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
@@ -182,19 +185,34 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
// TODO: We could likely merge this with the DCE algorithm below.
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
+ LDBG() << "Starting eraseUnreachableBlocks with " << regions.size()
+ << " regions";
+
// Set of blocks found to be reachable within a given region.
llvm::df_iterator_default_set<Block *, 16> reachable;
// If any blocks were found to be dead.
- bool erasedDeadBlocks = false;
+ int erasedDeadBlocks = 0;
SmallVector<Region *, 1> worklist;
worklist.reserve(regions.size());
for (Region &region : regions)
worklist.push_back(&region);
+
+ LDBG(2) << "Initial worklist size: " << worklist.size();
+
while (!worklist.empty()) {
Region *region = worklist.pop_back_val();
- if (region->empty())
+ if (region->empty()) {
+ LDBG(2) << "Skipping empty region";
continue;
+ }
+
+ LDBG(2) << "Processing region with " << region->getBlocks().size()
+ << " blocks";
+ if (region->getParentOp())
+ LDBG(2) << " -> for operation: "
+ << OpWithFlags(region->getParentOp(),
+ OpPrintingFlags().skipRegions());
// If this is a single block region, just collect the nested regions.
if (region->hasOneBlock()) {
@@ -209,13 +227,17 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
for (Block *block : depth_first_ext(&region->front(), reachable))
(void)block /* Mark all reachable blocks */;
+ LDBG(2) << "Found " << reachable.size() << " reachable blocks out of "
+ << region->getBlocks().size() << " total blocks";
+
// Collect all of the dead blocks and push the live regions onto the
// worklist.
for (Block &block : llvm::make_early_inc_range(*region)) {
if (!reachable.count(&block)) {
+ LDBG() << "Erasing unreachable block: " << &block;
block.dropAllDefinedValueUses();
rewriter.eraseBlock(&block);
- erasedDeadBlocks = true;
+ ++erasedDeadBlocks;
continue;
}
@@ -226,7 +248,10 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
}
}
- return success(erasedDeadBlocks);
+ LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks
+ << " dead blocks";
+
+ return success(erasedDeadBlocks > 0);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642..1382550 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,18 +13,40 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
namespace mlir {
+// Find all reachable blocks in the region and add them to the visitedBlocks
+// set.
+static void findReachableBlocks(Region &region,
+ DenseSet<Block *> &reachableBlocks) {
+ Block *entryBlock = &region.front();
+ reachableBlocks.insert(entryBlock);
+ // Traverse the CFG and add all reachable blocks to the blockList.
+ SmallVector<Block *> worklist({entryBlock});
+ while (!worklist.empty()) {
+ Block *block = worklist.pop_back_val();
+ Operation *terminator = &block->back();
+ for (Block *successor : terminator->getSuccessors()) {
+ if (reachableBlocks.contains(successor))
+ continue;
+ worklist.push_back(successor);
+ reachableBlocks.insert(successor);
+ }
+ }
+}
+
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -88,20 +110,104 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
+ // Iterator on all reachable operations in the region.
+ // Also keep track if we visited the nested regions of the current op
+ // already to drive the post-order traversal.
+ struct RegionReachableOpIterator {
+ RegionReachableOpIterator(Region *region) : region(region) {
+ regionIt = region->begin();
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ if (!llvm::hasSingleElement(*region))
+ findReachableBlocks(*region, reachableBlocks);
+ }
+ // Advance the iterator to the next reachable operation.
+ void advance() {
+ assert(regionIt != region->end());
+ hasVisitedRegions = false;
+ if (blockIt == regionIt->end()) {
+ ++regionIt;
+ while (regionIt != region->end() &&
+ !reachableBlocks.contains(&*regionIt))
+ ++regionIt;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ return;
+ }
+ ++blockIt;
+ if (blockIt != regionIt->end()) {
+ LDBG() << "Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+ // The region we're iterating over.
+ Region *region;
+ // The Block currently being iterated over.
+ Region::iterator regionIt;
+ // The Operation currently being iterated over.
+ Block::iterator blockIt;
+ // The set of blocks that are reachable in the current region.
+ DenseSet<Block *> reachableBlocks;
+ // Whether we've visited the nested regions of the current op already.
+ bool hasVisitedRegions = false;
+ };
+
+ // Worklist of regions to visit to drive the post-order traversal.
+ SmallVector<RegionReachableOpIterator> worklist;
+
+ LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
+ // Perform a post-order traversal of the regions, visiting each
+ // reachable operation.
for (Region &region : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
+ assert(worklist.empty());
+ if (region.empty())
+ continue;
+
+ // Prime the worklist with the entry block of this region.
+ worklist.push_back({&region});
+ while (!worklist.empty()) {
+ RegionReachableOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (it.blockIt == it.regionIt->end()) {
+ // We're done with this block.
+ it.advance();
+ continue;
+ }
+ Operation *op = &*it.blockIt;
+ // If we haven't visited the nested regions of this op yet,
+ // enqueue them.
+ if (!it.hasVisitedRegions) {
+ it.hasVisitedRegions = true;
+ for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion});
+ }
+ }
+ // If we're not at the back of the worklist, we've enqueued some
+ // nested region for processing. We'll come back to this op later
+ // (post-order)
+ if (&it != &worklist.back())
+ continue;
+
+ // Preemptively increment the iterator, in case the current op
+ // would be erased.
+ it.advance();
+
+ LDBG() << "Visiting op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- erasedListener.visitedOp = visitedOp;
+ erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ LDBG() << "\tOp matched and rewritten";
+ }
}
},
{op});