diff options
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r-- | mlir/lib/Rewrite/ByteCode.cpp | 105 |
1 files changed, 73 insertions, 32 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 388d6dc..9cc51da 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -34,21 +34,23 @@ using namespace mlir::detail; //===----------------------------------------------------------------------===// PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, + PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr) { + PatternBenefit benefit = matchOp.getBenefit(); + MLIRContext *ctx = matchOp.getContext(); + + // Collect the set of generated operations. SmallVector<StringRef, 8> generatedOps; if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) generatedOps = llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); - PatternBenefit benefit = matchOp.getBenefit(); - MLIRContext *ctx = matchOp.getContext(); - // Check to see if this is pattern matches a specific operation type. if (Optional<StringRef> rootKind = matchOp.getRootKind()) - return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, + return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx, generatedOps); - return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, - generatedOps); + return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(), + benefit, ctx, generatedOps); } //===----------------------------------------------------------------------===// @@ -194,14 +196,15 @@ public: ByteCodeField &maxValueRangeMemoryIndex, ByteCodeField &maxLoopLevel, llvm::StringMap<PDLConstraintFunction> &constraintFns, - llvm::StringMap<PDLRewriteFunction> &rewriteFns) + llvm::StringMap<PDLRewriteFunction> &rewriteFns, + const DenseMap<Operation *, PDLPatternConfigSet *> &configMap) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), - maxLoopLevel(maxLoopLevel) { + maxLoopLevel(maxLoopLevel), configMap(configMap) { for (const auto &it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (const auto &it : llvm::enumerate(rewriteFns)) @@ -328,6 +331,9 @@ private: ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; ByteCodeField &maxLoopLevel; + + /// A map of pattern configurations. + const DenseMap<Operation *, PDLPatternConfigSet *> &configMap; }; /// This class provides utilities for writing a bytecode stream. @@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { ByteCodeField patternIndex = patterns.size(); patterns.emplace_back(PDLByteCodePattern::create( - op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); + op, configMap.lookup(op), + rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op.getOperation()), op.getMatchedOps()); writer.appendPDLValueList(op.getInputs()); @@ -1014,13 +1021,16 @@ void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { // PDLByteCode //===----------------------------------------------------------------------===// -PDLByteCode::PDLByteCode(ModuleOp module, - llvm::StringMap<PDLConstraintFunction> constraintFns, - llvm::StringMap<PDLRewriteFunction> rewriteFns) { +PDLByteCode::PDLByteCode( + ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs, + const DenseMap<Operation *, PDLPatternConfigSet *> &configMap, + llvm::StringMap<PDLConstraintFunction> constraintFns, + llvm::StringMap<PDLRewriteFunction> rewriteFns) + : configs(std::move(configs)) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, - maxLoopLevel, constraintFns, rewriteFns); + maxLoopLevel, constraintFns, rewriteFns, configMap); generator.generate(module); // Initialize the external functions. @@ -1076,14 +1086,15 @@ public: /// Start executing the code at the current bytecode index. `matches` is an /// optional field provided when this function is executed in a matching /// context. - void execute(PatternRewriter &rewriter, - SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, - Optional<Location> mainRewriteLoc = {}); + LogicalResult + execute(PatternRewriter &rewriter, + SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, + Optional<Location> mainRewriteLoc = {}); private: /// Internal implementation of executing each of the bytecode commands. void executeApplyConstraint(PatternRewriter &rewriter); - void executeApplyRewrite(PatternRewriter &rewriter); + LogicalResult executeApplyRewrite(PatternRewriter &rewriter); void executeAreEqual(); void executeAreRangesEqual(); void executeBranch(); @@ -1345,7 +1356,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { selectJump(succeeded(constraintFn(rewriter, args))); } -void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { +LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; SmallVector<PDLValue, 16> args; @@ -1359,7 +1370,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { // Execute the rewrite function. ByteCodeField numResults = read(); ByteCodeRewriteResultList results(numResults); - rewriteFn(rewriter, results, args); + LogicalResult rewriteResult = rewriteFn(rewriter, results, args); assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); @@ -1395,6 +1406,13 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { allocatedTypeRangeMemory.push_back(std::move(it)); for (auto &it : results.getAllocatedValueRanges()) allocatedValueRangeMemory.push_back(std::move(it)); + + // Process the result of the rewrite. + if (failed(rewriteResult)) { + LLVM_DEBUG(llvm::dbgs() << " - Failed"); + return failure(); + } + return success(); } void ByteCodeExecutor::executeAreEqual() { @@ -2017,10 +2035,10 @@ void ByteCodeExecutor::executeSwitchTypes() { }); } -void ByteCodeExecutor::execute( - PatternRewriter &rewriter, - SmallVectorImpl<PDLByteCode::MatchResult> *matches, - Optional<Location> mainRewriteLoc) { +LogicalResult +ByteCodeExecutor::execute(PatternRewriter &rewriter, + SmallVectorImpl<PDLByteCode::MatchResult> *matches, + Optional<Location> mainRewriteLoc) { while (true) { // Print the location of the operation being executed. LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); @@ -2031,7 +2049,8 @@ void ByteCodeExecutor::execute( executeApplyConstraint(rewriter); break; case ApplyRewrite: - executeApplyRewrite(rewriter); + if (failed(executeApplyRewrite(rewriter))) + return failure(); break; case AreEqual: executeAreEqual(); @@ -2078,7 +2097,7 @@ void ByteCodeExecutor::execute( case Finalize: executeFinalize(); LLVM_DEBUG(llvm::dbgs() << "\n"); - return; + return success(); case ForEach: executeForEach(); break; @@ -2166,8 +2185,6 @@ void ByteCodeExecutor::execute( } } -/// Run the pattern matcher on the given root operation, collecting the matched -/// patterns in `matches`. void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl<MatchResult> &matches, PDLByteCodeMutableState &state) const { @@ -2181,7 +2198,8 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); - executor.execute(rewriter, &matches); + LogicalResult executeResult = executor.execute(rewriter, &matches); + assert(succeeded(executeResult) && "unexpected matcher execution failure"); // Order the found matches by benefit. std::stable_sort(matches.begin(), matches.end(), @@ -2190,9 +2208,13 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, }); } -/// Run the rewriter of the given pattern on the root operation `op`. -void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, - PDLByteCodeMutableState &state) const { +LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter, + const MatchResult &match, + PDLByteCodeMutableState &state) const { + auto *configSet = match.pattern->getConfigSet(); + if (configSet) + configSet->notifyRewriteBegin(rewriter); + // The arguments of the rewrite function are stored at the start of the // memory buffer. llvm::copy(match.values, state.memory.begin()); @@ -2204,5 +2226,24 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, state.allocatedValueRangeMemory, state.loopIndex, uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); - executor.execute(rewriter, /*matches=*/nullptr, match.location); + LogicalResult result = + executor.execute(rewriter, /*matches=*/nullptr, match.location); + + if (configSet) + configSet->notifyRewriteEnd(rewriter); + + // If the rewrite failed, check if the pattern rewriter can recover. If it + // can, we can signal to the pattern applicator to keep trying patterns. If it + // doesn't, we need to bail. Bailing here should be fine, given that we have + // no means to propagate such a failure to the user, and it also indicates a + // bug in the user code (i.e. failable rewrites should not be used with + // pattern rewriters that don't support it). + if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) { + LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting"); + llvm::report_fatal_error( + "Native PDL Rewrite failed, but the pattern " + "rewriter doesn't support recovery. Failable pattern rewrites should " + "not be used with pattern rewriters that do not support them."); + } + return result; } |