aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite/ByteCode.cpp
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2022-09-08 16:59:39 -0700
committerRiver Riddle <riddleriver@gmail.com>2022-11-08 01:57:57 -0800
commit8c66344ee9f67f76b3cb6b3345a46345a2d3975a (patch)
tree65727605f79bacc7595e6dd21fe05625fc02a0ec /mlir/lib/Rewrite/ByteCode.cpp
parentf3a86a23c10db44cbcb432d56201475721c64bdc (diff)
downloadllvm-8c66344ee9f67f76b3cb6b3345a46345a2d3975a.zip
llvm-8c66344ee9f67f76b3cb6b3345a46345a2d3975a.tar.gz
llvm-8c66344ee9f67f76b3cb6b3345a46345a2d3975a.tar.bz2
[mlir:PDL] Add support for DialectConversion with pattern configurations
Up until now PDL(L) has not supported dialect conversion because we had no way of remapping values or integrating with type conversions. This commit rectifies that by adding a new "pattern configuration" concept to PDL. This essentially allows for attaching external configurations to patterns, which can hook into pattern events (for now just the scope of a rewrite, but we could also pass configs to native rewrites as well). This allows for injecting the type converter into the conversion pattern rewriter. Differential Revision: https://reviews.llvm.org/D133142
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp105
1 files changed, 73 insertions, 32 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 388d6dc..9cc51da 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -34,21 +34,23 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
+ PDLPatternConfigSet *configSet,
ByteCodeAddr rewriterAddr) {
+ PatternBenefit benefit = matchOp.getBenefit();
+ MLIRContext *ctx = matchOp.getContext();
+
+ // Collect the set of generated operations.
SmallVector<StringRef, 8> generatedOps;
if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
generatedOps =
llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
- PatternBenefit benefit = matchOp.getBenefit();
- MLIRContext *ctx = matchOp.getContext();
-
// Check to see if this is pattern matches a specific operation type.
if (Optional<StringRef> rootKind = matchOp.getRootKind())
- return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
+ return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
generatedOps);
- return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
- generatedOps);
+ return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
+ benefit, ctx, generatedOps);
}
//===----------------------------------------------------------------------===//
@@ -194,14 +196,15 @@ public:
ByteCodeField &maxValueRangeMemoryIndex,
ByteCodeField &maxLoopLevel,
llvm::StringMap<PDLConstraintFunction> &constraintFns,
- llvm::StringMap<PDLRewriteFunction> &rewriteFns)
+ llvm::StringMap<PDLRewriteFunction> &rewriteFns,
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex),
maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
- maxLoopLevel(maxLoopLevel) {
+ maxLoopLevel(maxLoopLevel), configMap(configMap) {
for (const auto &it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (const auto &it : llvm::enumerate(rewriteFns))
@@ -328,6 +331,9 @@ private:
ByteCodeField &maxTypeRangeMemoryIndex;
ByteCodeField &maxValueRangeMemoryIndex;
ByteCodeField &maxLoopLevel;
+
+ /// A map of pattern configurations.
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
};
/// This class provides utilities for writing a bytecode stream.
@@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
ByteCodeField patternIndex = patterns.size();
patterns.emplace_back(PDLByteCodePattern::create(
- op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
+ op, configMap.lookup(op),
+ rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
writer.append(OpCode::RecordMatch, patternIndex,
SuccessorRange(op.getOperation()), op.getMatchedOps());
writer.appendPDLValueList(op.getInputs());
@@ -1014,13 +1021,16 @@ void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
// PDLByteCode
//===----------------------------------------------------------------------===//
-PDLByteCode::PDLByteCode(ModuleOp module,
- llvm::StringMap<PDLConstraintFunction> constraintFns,
- llvm::StringMap<PDLRewriteFunction> rewriteFns) {
+PDLByteCode::PDLByteCode(
+ ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
+ const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
+ llvm::StringMap<PDLConstraintFunction> constraintFns,
+ llvm::StringMap<PDLRewriteFunction> rewriteFns)
+ : configs(std::move(configs)) {
Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex,
maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
- maxLoopLevel, constraintFns, rewriteFns);
+ maxLoopLevel, constraintFns, rewriteFns, configMap);
generator.generate(module);
// Initialize the external functions.
@@ -1076,14 +1086,15 @@ public:
/// Start executing the code at the current bytecode index. `matches` is an
/// optional field provided when this function is executed in a matching
/// context.
- void execute(PatternRewriter &rewriter,
- SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
- Optional<Location> mainRewriteLoc = {});
+ LogicalResult
+ execute(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
+ Optional<Location> mainRewriteLoc = {});
private:
/// Internal implementation of executing each of the bytecode commands.
void executeApplyConstraint(PatternRewriter &rewriter);
- void executeApplyRewrite(PatternRewriter &rewriter);
+ LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
void executeAreEqual();
void executeAreRangesEqual();
void executeBranch();
@@ -1345,7 +1356,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
selectJump(succeeded(constraintFn(rewriter, args)));
}
-void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
+LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
SmallVector<PDLValue, 16> args;
@@ -1359,7 +1370,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
// Execute the rewrite function.
ByteCodeField numResults = read();
ByteCodeRewriteResultList results(numResults);
- rewriteFn(rewriter, results, args);
+ LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
@@ -1395,6 +1406,13 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it));
+
+ // Process the result of the rewrite.
+ if (failed(rewriteResult)) {
+ LLVM_DEBUG(llvm::dbgs() << " - Failed");
+ return failure();
+ }
+ return success();
}
void ByteCodeExecutor::executeAreEqual() {
@@ -2017,10 +2035,10 @@ void ByteCodeExecutor::executeSwitchTypes() {
});
}
-void ByteCodeExecutor::execute(
- PatternRewriter &rewriter,
- SmallVectorImpl<PDLByteCode::MatchResult> *matches,
- Optional<Location> mainRewriteLoc) {
+LogicalResult
+ByteCodeExecutor::execute(PatternRewriter &rewriter,
+ SmallVectorImpl<PDLByteCode::MatchResult> *matches,
+ Optional<Location> mainRewriteLoc) {
while (true) {
// Print the location of the operation being executed.
LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
@@ -2031,7 +2049,8 @@ void ByteCodeExecutor::execute(
executeApplyConstraint(rewriter);
break;
case ApplyRewrite:
- executeApplyRewrite(rewriter);
+ if (failed(executeApplyRewrite(rewriter)))
+ return failure();
break;
case AreEqual:
executeAreEqual();
@@ -2078,7 +2097,7 @@ void ByteCodeExecutor::execute(
case Finalize:
executeFinalize();
LLVM_DEBUG(llvm::dbgs() << "\n");
- return;
+ return success();
case ForEach:
executeForEach();
break;
@@ -2166,8 +2185,6 @@ void ByteCodeExecutor::execute(
}
}
-/// Run the pattern matcher on the given root operation, collecting the matched
-/// patterns in `matches`.
void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<MatchResult> &matches,
PDLByteCodeMutableState &state) const {
@@ -2181,7 +2198,8 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
- executor.execute(rewriter, &matches);
+ LogicalResult executeResult = executor.execute(rewriter, &matches);
+ assert(succeeded(executeResult) && "unexpected matcher execution failure");
// Order the found matches by benefit.
std::stable_sort(matches.begin(), matches.end(),
@@ -2190,9 +2208,13 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
});
}
-/// Run the rewriter of the given pattern on the root operation `op`.
-void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
- PDLByteCodeMutableState &state) const {
+LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
+ const MatchResult &match,
+ PDLByteCodeMutableState &state) const {
+ auto *configSet = match.pattern->getConfigSet();
+ if (configSet)
+ configSet->notifyRewriteBegin(rewriter);
+
// The arguments of the rewrite function are stored at the start of the
// memory buffer.
llvm::copy(match.values, state.memory.begin());
@@ -2204,5 +2226,24 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
rewriterByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions);
- executor.execute(rewriter, /*matches=*/nullptr, match.location);
+ LogicalResult result =
+ executor.execute(rewriter, /*matches=*/nullptr, match.location);
+
+ if (configSet)
+ configSet->notifyRewriteEnd(rewriter);
+
+ // If the rewrite failed, check if the pattern rewriter can recover. If it
+ // can, we can signal to the pattern applicator to keep trying patterns. If it
+ // doesn't, we need to bail. Bailing here should be fine, given that we have
+ // no means to propagate such a failure to the user, and it also indicates a
+ // bug in the user code (i.e. failable rewrites should not be used with
+ // pattern rewriters that don't support it).
+ if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
+ LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
+ llvm::report_fatal_error(
+ "Native PDL Rewrite failed, but the pattern "
+ "rewriter doesn't support recovery. Failable pattern rewrites should "
+ "not be used with pattern rewriters that do not support them.");
+ }
+ return result;
}