aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite/ByteCode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp105
1 files changed, 73 insertions, 32 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 388d6dc..9cc51da 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -34,21 +34,23 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
+ PDLPatternConfigSet *configSet,
ByteCodeAddr rewriterAddr) {
+ PatternBenefit benefit = matchOp.getBenefit();
+ MLIRContext *ctx = matchOp.getContext();
+
+ // Collect the set of generated operations.
SmallVector<StringRef, 8> generatedOps;
if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
generatedOps =
llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
- PatternBenefit benefit = matchOp.getBenefit();
- MLIRContext *ctx = matchOp.getContext();
-
// Check to see if this is pattern matches a specific operation type.
if (Optional<StringRef> rootKind = matchOp.getRootKind())
- return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
+ return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
generatedOps);
- return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
- generatedOps);
+ return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
+ benefit, ctx, generatedOps);
}
//===----------------------------------------------------------------------===//
@@ -194,14 +196,15 @@ public:
ByteCodeField &maxValueRangeMemoryIndex,
ByteCodeField &maxLoopLevel,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
- llvm::StringMap<PDLRewriteFunction> &rewriteFns)
+ llvm::StringMap<PDLRewriteFunction> &rewriteFns,
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex),
maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
- maxLoopLevel(maxLoopLevel) {
+ maxLoopLevel(maxLoopLevel), configMap(configMap) {
for (const auto &it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (const auto &it : llvm::enumerate(rewriteFns))
@@ -328,6 +331,9 @@ private:
ByteCodeField &maxTypeRangeMemoryIndex;
ByteCodeField &maxValueRangeMemoryIndex;
ByteCodeField &maxLoopLevel;
+
+ /// A map of pattern configurations.
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
};
/// This class provides utilities for writing a bytecode stream.
@@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
ByteCodeField patternIndex = patterns.size();
patterns.emplace_back(PDLByteCodePattern::create(
- op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
+ op, configMap.lookup(op),
+ rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
writer.append(OpCode::RecordMatch, patternIndex,
SuccessorRange(op.getOperation()), op.getMatchedOps());
writer.appendPDLValueList(op.getInputs());
@@ -1014,13 +1021,16 @@ void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
// PDLByteCode
//===----------------------------------------------------------------------===//
-PDLByteCode::PDLByteCode(ModuleOp module,
- llvm::StringMap<PDLConstraintFunction> constraintFns,
- llvm::StringMap<PDLRewriteFunction> rewriteFns) {
+PDLByteCode::PDLByteCode(
+ ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
+ llvm::StringMap<PDLConstraintFunction> constraintFns,
+ llvm::StringMap<PDLRewriteFunction> rewriteFns)
+ : configs(std::move(configs)) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
- maxLoopLevel, constraintFns, rewriteFns);
+ maxLoopLevel, constraintFns, rewriteFns, configMap);
generator.generate(module);
// Initialize the external functions.
@@ -1076,14 +1086,15 @@ public:
/// Start executing the code at the current bytecode index. `matches` is an
/// optional field provided when this function is executed in a matching
/// context.
- void execute(PatternRewriter &rewriter,
- SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
- Optional<Location> mainRewriteLoc = {});
+ LogicalResult
+ execute(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
+ Optional<Location> mainRewriteLoc = {});
private:
/// Internal implementation of executing each of the bytecode commands.
void executeApplyConstraint(PatternRewriter &rewriter);
- void executeApplyRewrite(PatternRewriter &rewriter);
+ LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
void executeAreEqual();
void executeAreRangesEqual();
void executeBranch();
@@ -1345,7 +1356,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
selectJump(succeeded(constraintFn(rewriter, args)));
}
-void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
+LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
SmallVector<PDLValue, 16> args;
@@ -1359,7 +1370,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
// Execute the rewrite function.
ByteCodeField numResults = read();
ByteCodeRewriteResultList results(numResults);
- rewriteFn(rewriter, results, args);
+ LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
@@ -1395,6 +1406,13 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it));
+
+ // Process the result of the rewrite.
+ if (failed(rewriteResult)) {
+ LLVM_DEBUG(llvm::dbgs() << " - Failed");
+ return failure();
+ }
+ return success();
}
void ByteCodeExecutor::executeAreEqual() {
@@ -2017,10 +2035,10 @@ void ByteCodeExecutor::executeSwitchTypes() {
});
}
-void ByteCodeExecutor::execute(
- PatternRewriter &rewriter,
- SmallVectorImpl<PDLByteCode::MatchResult> *matches,
- Optional<Location> mainRewriteLoc) {
+LogicalResult
+ByteCodeExecutor::execute(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> *matches,
+ Optional<Location> mainRewriteLoc) {
while (true) {
// Print the location of the operation being executed.
LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
@@ -2031,7 +2049,8 @@ void ByteCodeExecutor::execute(
executeApplyConstraint(rewriter);
break;
case ApplyRewrite:
- executeApplyRewrite(rewriter);
+ if (failed(executeApplyRewrite(rewriter)))
+ return failure();
break;
case AreEqual:
executeAreEqual();
@@ -2078,7 +2097,7 @@ void ByteCodeExecutor::execute(
case Finalize:
executeFinalize();
LLVM_DEBUG(llvm::dbgs() << "\n");
- return;
+ return success();
case ForEach:
executeForEach();
break;
@@ -2166,8 +2185,6 @@ void ByteCodeExecutor::execute(
}
}
-/// Run the pattern matcher on the given root operation, collecting the matched
-/// patterns in `matches`.
void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<MatchResult> &matches,
PDLByteCodeMutableState &state) const {
@@ -2181,7 +2198,8 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
- executor.execute(rewriter, &matches);
+ LogicalResult executeResult = executor.execute(rewriter, &matches);
+ assert(succeeded(executeResult) && "unexpected matcher execution failure");
// Order the found matches by benefit.
std::stable_sort(matches.begin(), matches.end(),
@@ -2190,9 +2208,13 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
});
}
-/// Run the rewriter of the given pattern on the root operation `op`.
-void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
- PDLByteCodeMutableState &state) const {
+LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
+ const MatchResult &match,
+ PDLByteCodeMutableState &state) const {
+ auto *configSet = match.pattern->getConfigSet();
+ if (configSet)
+ configSet->notifyRewriteBegin(rewriter);
+
// The arguments of the rewrite function are stored at the start of the
// memory buffer.
llvm::copy(match.values, state.memory.begin());
@@ -2204,5 +2226,24 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
rewriterByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
- executor.execute(rewriter, /*matches=*/nullptr, match.location);
+ LogicalResult result =
+ executor.execute(rewriter, /*matches=*/nullptr, match.location);
+
+ if (configSet)
+ configSet->notifyRewriteEnd(rewriter);
+
+ // If the rewrite failed, check if the pattern rewriter can recover. If it
+ // can, we can signal to the pattern applicator to keep trying patterns. If it
+ // doesn't, we need to bail. Bailing here should be fine, given that we have
+ // no means to propagate such a failure to the user, and it also indicates a
+ // bug in the user code (i.e. failable rewrites should not be used with
+ // pattern rewriters that don't support it).
+ if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
+ LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
+ llvm::report_fatal_error(
+ "Native PDL Rewrite failed, but the pattern "
+ "rewriter doesn't support recovery. Failable pattern rewrites should "
+ "not be used with pattern rewriters that do not support them.");
+ }
+ return result;
}