aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite/ByteCode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp144
1 files changed, 45 insertions, 99 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 559ce7f..6e6992d 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -769,25 +769,11 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::ApplyConstraintOp op,
ByteCodeWriter &writer) {
- // Constraints that should return a value have to be registered as rewrites.
- // If a constraint and a rewrite of similar name are registered the
- // constraint takes precedence
+ assert(constraintToMemIndex.count(op.getName()) &&
+ "expected index for constraint function");
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
writer.appendPDLValueList(op.getArgs());
writer.append(ByteCodeField(op.getIsNegated()));
- ResultRange results = op.getResults();
- writer.append(ByteCodeField(results.size()));
- for (Value result : results) {
- // We record the expected kind of the result, so that we can provide extra
- // verification of the native rewrite function and handle the failure case
- // of constraints accordingly.
- writer.appendPDLValueKind(result);
-
- // Range results also need to append the range storage index.
- if (isa<pdl::RangeType>(result.getType()))
- writer.append(getRangeStorageIndex(result));
- writer.append(result);
- }
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
@@ -800,9 +786,11 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
ResultRange results = op.getResults();
writer.append(ByteCodeField(results.size()));
for (Value result : results) {
- // We record the expected kind of the result, so that we
+ // In debug mode we also record the expected kind of the result, so that we
// can provide extra verification of the native rewrite function.
+#ifndef NDEBUG
writer.appendPDLValueKind(result);
+#endif
// Range results also need to append the range storage index.
if (isa<pdl::RangeType>(result.getType()))
@@ -1088,28 +1076,6 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
// ByteCode Execution
namespace {
-/// This class is an instantiation of the PDLResultList that provides access to
-/// the returned results. This API is not on `PDLResultList` to avoid
-/// overexposing access to information specific solely to the ByteCode.
-class ByteCodeRewriteResultList : public PDLResultList {
-public:
- ByteCodeRewriteResultList(unsigned maxNumResults)
- : PDLResultList(maxNumResults) {}
-
- /// Return the list of PDL results.
- MutableArrayRef<PDLValue> getResults() { return results; }
-
- /// Return the type ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
- return allocatedTypeRanges;
- }
-
- /// Return the value ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
- return allocatedValueRanges;
- }
-};
-
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
@@ -1186,9 +1152,6 @@ private:
void executeSwitchResultCount();
void executeSwitchType();
void executeSwitchTypes();
- void processNativeFunResults(ByteCodeRewriteResultList &results,
- unsigned numResults,
- LogicalResult &rewriteResult);
/// Pushes a code iterator to the stack.
void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
@@ -1262,8 +1225,6 @@ private:
return T::getFromOpaquePointer(pointer);
}
- void skip(size_t skipN) { curCodeIt += skipN; }
-
/// Jump to a specific successor based on a predicate value.
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
/// Jump to a specific successor based on a destination index.
@@ -1420,11 +1381,33 @@ private:
ArrayRef<PDLConstraintFunction> constraintFunctions;
ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
+
+/// This class is an instantiation of the PDLResultList that provides access to
+/// the returned results. This API is not on `PDLResultList` to avoid
+/// overexposing access to information specific solely to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+ ByteCodeRewriteResultList(unsigned maxNumResults)
+ : PDLResultList(maxNumResults) {}
+
+ /// Return the list of PDL results.
+ MutableArrayRef<PDLValue> getResults() { return results; }
+
+ /// Return the type ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+ return allocatedTypeRanges;
+ }
+
+ /// Return the value ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+ return allocatedValueRanges;
+ }
+};
} // namespace
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
- ByteCodeField fun_idx = read();
+ const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
@@ -1439,29 +1422,8 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
llvm::dbgs() << " * isNegated: " << isNegated << "\n";
llvm::interleaveComma(args, llvm::dbgs());
});
-
- ByteCodeField numResults = read();
- const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
- ByteCodeRewriteResultList results(numResults);
- LogicalResult rewriteResult = constraintFn(rewriter, results, args);
- ArrayRef<PDLValue> constraintResults = results.getResults();
- LLVM_DEBUG({
- if (succeeded(rewriteResult)) {
- llvm::dbgs() << " * Constraint succeeded\n";
- llvm::dbgs() << " * Results: ";
- llvm::interleaveComma(constraintResults, llvm::dbgs());
- llvm::dbgs() << "\n";
- } else {
- llvm::dbgs() << " * Constraint failed\n";
- }
- });
- assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
- "native PDL rewrite function succeeded but returned "
- "unexpected number of results");
- processNativeFunResults(results, numResults, rewriteResult);
-
- // Depending on the constraint jump to the proper destination.
- selectJump(isNegated != succeeded(rewriteResult));
+ // Invoke the constraint and jump to the proper destination.
+ selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
}
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1483,39 +1445,16 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
- processNativeFunResults(results, numResults, rewriteResult);
+ // Store the results in the bytecode memory.
+ for (PDLValue &result : results.getResults()) {
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
- if (failed(rewriteResult)) {
- LLVM_DEBUG(llvm::dbgs() << " - Failed");
- return failure();
- }
- return success();
-}
+// In debug mode we also verify the expected kind of the result.
+#ifndef NDEBUG
+ assert(result.getKind() == read<PDLValue::Kind>() &&
+ "native PDL rewrite function returned an unexpected type of result");
+#endif
-void ByteCodeExecutor::processNativeFunResults(
- ByteCodeRewriteResultList &results, unsigned numResults,
- LogicalResult &rewriteResult) {
- // Store the results in the bytecode memory or handle missing results on
- // failure.
- for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
- PDLValue::Kind resultKind = read<PDLValue::Kind>();
-
- // Skip the according number of values on the buffer on failure and exit
- // early as there are no results to process.
- if (failed(rewriteResult)) {
- if (resultKind == PDLValue::Kind::TypeRange ||
- resultKind == PDLValue::Kind::ValueRange) {
- skip(2);
- } else {
- skip(1);
- }
- return;
- }
- PDLValue result = results.getResults()[resultIdx];
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
- assert(result.getKind() == resultKind &&
- "native PDL rewrite function returned an unexpected type of "
- "result");
// If the result is a range, we need to copy it over to the bytecodes
// range memory.
if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
@@ -1537,6 +1476,13 @@ void ByteCodeExecutor::processNativeFunResults(
allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it));
+
+ // Process the result of the rewrite.
+ if (failed(rewriteResult)) {
+ LLVM_DEBUG(llvm::dbgs() << " - Failed");
+ return failure();
+ }
+ return success();
}
void ByteCodeExecutor::executeAreEqual() {