diff options
author | River Riddle <riddleriver@gmail.com> | 2022-09-08 16:59:39 -0700 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2022-11-08 01:57:57 -0800 |
commit | 8c66344ee9f67f76b3cb6b3345a46345a2d3975a (patch) | |
tree | 65727605f79bacc7595e6dd21fe05625fc02a0ec /mlir/lib/Rewrite/ByteCode.cpp | |
parent | f3a86a23c10db44cbcb432d56201475721c64bdc (diff) | |
download | llvm-8c66344ee9f67f76b3cb6b3345a46345a2d3975a.zip llvm-8c66344ee9f67f76b3cb6b3345a46345a2d3975a.tar.gz llvm-8c66344ee9f67f76b3cb6b3345a46345a2d3975a.tar.bz2 |
[mlir:PDL] Add support for DialectConversion with pattern configurations
Up until now PDL(L) has not supported dialect conversion because we had no
way of remapping values or integrating with type conversions. This commit
rectifies that by adding a new "pattern configuration" concept to PDL. This
essentially allows for attaching external configurations to patterns, which
can hook into pattern events (for now just the scope of a rewrite, but we
could also pass configs to native rewrites as well). This allows for injecting
the type converter into the conversion pattern rewriter.
Differential Revision: https://reviews.llvm.org/D133142
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r-- | mlir/lib/Rewrite/ByteCode.cpp | 105 |
1 files changed, 73 insertions, 32 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 388d6dc..9cc51da 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -34,21 +34,23 @@ using namespace mlir::detail; //===----------------------------------------------------------------------===// PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, + PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr) { + PatternBenefit benefit = matchOp.getBenefit(); + MLIRContext *ctx = matchOp.getContext(); + + // Collect the set of generated operations. SmallVector<StringRef, 8> generatedOps; if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) generatedOps = llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); - PatternBenefit benefit = matchOp.getBenefit(); - MLIRContext *ctx = matchOp.getContext(); - // Check to see if this is pattern matches a specific operation type. if (Optional<StringRef> rootKind = matchOp.getRootKind()) - return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, + return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, generatedOps); - return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, - generatedOps); + return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), + benefit, ctx, generatedOps); } //===----------------------------------------------------------------------===// @@ -194,14 +196,15 @@ public: ByteCodeField &maxValueRangeMemoryIndex, ByteCodeField &maxLoopLevel, llvm::StringMap<PDLConstraintFunction> &constraintFns, - llvm::StringMap<PDLRewriteFunction> &rewriteFns) + llvm::StringMap<PDLRewriteFunction> &rewriteFns, + const DenseMap<Operation *, PDLPatternConfigSet *> &configMap) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), - maxLoopLevel(maxLoopLevel) { + maxLoopLevel(maxLoopLevel), configMap(configMap) { for (const auto &it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (const auto &it : llvm::enumerate(rewriteFns)) @@ -328,6 +331,9 @@ private: ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; ByteCodeField &maxLoopLevel; + + /// A map of pattern configurations. + const DenseMap<Operation *, PDLPatternConfigSet *> &configMap; }; /// This class provides utilities for writing a bytecode stream. @@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { ByteCodeField patternIndex = patterns.size(); patterns.emplace_back(PDLByteCodePattern::create( - op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); + op, configMap.lookup(op), + rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op.getOperation()), op.getMatchedOps()); writer.appendPDLValueList(op.getInputs()); @@ -1014,13 +1021,16 @@ void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { // PDLByteCode //===----------------------------------------------------------------------===// -PDLByteCode::PDLByteCode(ModuleOp module, - llvm::StringMap<PDLConstraintFunction> constraintFns, - llvm::StringMap<PDLRewriteFunction> rewriteFns) { +PDLByteCode::PDLByteCode( + ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs, + const DenseMap<Operation *, PDLPatternConfigSet *> &configMap, + llvm::StringMap<PDLConstraintFunction> constraintFns, + llvm::StringMap<PDLRewriteFunction> rewriteFns) + : configs(std::move(configs)) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, - maxLoopLevel, constraintFns, rewriteFns); + maxLoopLevel, constraintFns, rewriteFns, configMap); generator.generate(module); // Initialize the external functions. @@ -1076,14 +1086,15 @@ public: /// Start executing the code at the current bytecode index. `matches` is an /// optional field provided when this function is executed in a matching /// context. - void execute(PatternRewriter &rewriter, - SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, - Optional<Location> mainRewriteLoc = {}); + LogicalResult + execute(PatternRewriter &rewriter, + SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, + Optional<Location> mainRewriteLoc = {}); private: /// Internal implementation of executing each of the bytecode commands. void executeApplyConstraint(PatternRewriter &rewriter); - void executeApplyRewrite(PatternRewriter &rewriter); + LogicalResult executeApplyRewrite(PatternRewriter &rewriter); void executeAreEqual(); void executeAreRangesEqual(); void executeBranch(); @@ -1345,7 +1356,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { selectJump(succeeded(constraintFn(rewriter, args))); } -void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { +LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; SmallVector<PDLValue, 16> args; @@ -1359,7 +1370,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { // Execute the rewrite function. ByteCodeField numResults = read(); ByteCodeRewriteResultList results(numResults); - rewriteFn(rewriter, results, args); + LogicalResult rewriteResult = rewriteFn(rewriter, results, args); assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); @@ -1395,6 +1406,13 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { allocatedTypeRangeMemory.push_back(std::move(it)); for (auto &it : results.getAllocatedValueRanges()) allocatedValueRangeMemory.push_back(std::move(it)); + + // Process the result of the rewrite. + if (failed(rewriteResult)) { + LLVM_DEBUG(llvm::dbgs() << " - Failed"); + return failure(); + } + return success(); } void ByteCodeExecutor::executeAreEqual() { @@ -2017,10 +2035,10 @@ void ByteCodeExecutor::executeSwitchTypes() { }); } -void ByteCodeExecutor::execute( - PatternRewriter &rewriter, - SmallVectorImpl<PDLByteCode::MatchResult> *matches, - Optional<Location> mainRewriteLoc) { +LogicalResult +ByteCodeExecutor::execute(PatternRewriter &rewriter, + SmallVectorImpl<PDLByteCode::MatchResult> *matches, + Optional<Location> mainRewriteLoc) { while (true) { // Print the location of the operation being executed. LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); @@ -2031,7 +2049,8 @@ void ByteCodeExecutor::execute( executeApplyConstraint(rewriter); break; case ApplyRewrite: - executeApplyRewrite(rewriter); + if (failed(executeApplyRewrite(rewriter))) + return failure(); break; case AreEqual: executeAreEqual(); @@ -2078,7 +2097,7 @@ void ByteCodeExecutor::execute( case Finalize: executeFinalize(); LLVM_DEBUG(llvm::dbgs() << "\n"); - return; + return success(); case ForEach: executeForEach(); break; @@ -2166,8 +2185,6 @@ void ByteCodeExecutor::execute( } } -/// Run the pattern matcher on the given root operation, collecting the matched -/// patterns in `matches`. void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl<MatchResult> &matches, PDLByteCodeMutableState &state) const { @@ -2181,7 +2198,8 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); - executor.execute(rewriter, &matches); + LogicalResult executeResult = executor.execute(rewriter, &matches); + assert(succeeded(executeResult) && "unexpected matcher execution failure"); // Order the found matches by benefit. std::stable_sort(matches.begin(), matches.end(), @@ -2190,9 +2208,13 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, }); } -/// Run the rewriter of the given pattern on the root operation `op`. -void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, - PDLByteCodeMutableState &state) const { +LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter, + const MatchResult &match, + PDLByteCodeMutableState &state) const { + auto *configSet = match.pattern->getConfigSet(); + if (configSet) + configSet->notifyRewriteBegin(rewriter); + // The arguments of the rewrite function are stored at the start of the // memory buffer. llvm::copy(match.values, state.memory.begin()); @@ -2204,5 +2226,24 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, state.allocatedValueRangeMemory, state.loopIndex, uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); - executor.execute(rewriter, /*matches=*/nullptr, match.location); + LogicalResult result = + executor.execute(rewriter, /*matches=*/nullptr, match.location); + + if (configSet) + configSet->notifyRewriteEnd(rewriter); + + // If the rewrite failed, check if the pattern rewriter can recover. If it + // can, we can signal to the pattern applicator to keep trying patterns. If it + // doesn't, we need to bail. Bailing here should be fine, given that we have + // no means to propagate such a failure to the user, and it also indicates a + // bug in the user code (i.e. failable rewrites should not be used with + // pattern rewriters that don't support it). + if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) { + LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting"); + llvm::report_fatal_error( + "Native PDL Rewrite failed, but the pattern " + "rewriter doesn't support recovery. Failable pattern rewrites should " + "not be used with pattern rewriters that do not support them."); + } + return result; } |