aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite/ByteCode.cpp
diff options
context:
space:
mode:
authorMichael Kruse <llvm-project@meinersbur.de>2022-05-06 02:43:49 -0500
committerMichael Kruse <llvm-project@meinersbur.de>2022-05-06 02:43:49 -0500
commitd3460d2a890ccb4ee84ffd05d4a722ff82b5170d (patch)
tree19c511c63d14da385bc907b4aa99276b8636dc1a /mlir/lib/Rewrite/ByteCode.cpp
parentfdb6ddcfeb62be7dbc502a4a4ed7c5be11c1c9b6 (diff)
parent9c1085c7e20bdd7c4a487f50313ebeeb2b6683b8 (diff)
downloadllvm-users/meinersbur/irbuilder-ompregion.zip
llvm-users/meinersbur/irbuilder-ompregion.tar.gz
llvm-users/meinersbur/irbuilder-ompregion.tar.bz2
Merge branch 'main' into irbuilder-ompregionusers/meinersbur/irbuilder-ompregion
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp63
1 files changed, 33 insertions, 30 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index c2dc41a..ad4c078 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -162,6 +162,10 @@ enum OpCode : ByteCodeField {
};
} // namespace
+/// A marker used to indicate if an operation should infer types.
+static constexpr ByteCodeField kInferTypesMarker =
+ std::numeric_limits<ByteCodeField>::max();
+
//===----------------------------------------------------------------------===//
// ByteCode Generation
//===----------------------------------------------------------------------===//
@@ -273,7 +277,6 @@ private:
void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
- void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
@@ -723,8 +726,7 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
LLVM_DEBUG({
// The following list must contain all the operations that do not
// produce any bytecode.
- if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
- pdl_interp::InferredTypesOp>(op))
+ if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
writer.appendInline(op->getLoc());
});
TypeSwitch<Operation *>(op)
@@ -742,11 +744,11 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
- pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
- pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
- pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
- pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
- pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
+ pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
+ pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
+ pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
+ pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
+ pdl_interp::SwitchResultCountOp>(
[&](auto interpOp) { this->generate(interpOp, writer); })
.Default([](Operation *) {
llvm_unreachable("unknown `pdl_interp` operation");
@@ -847,7 +849,13 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
writer.append(static_cast<ByteCodeField>(attributes.size()));
for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
writer.append(std::get<0>(it), std::get<1>(it));
- writer.appendPDLValueList(op.getInputResultTypes());
+
+ // Add the result types. If the operation has inferred results, we use a
+ // marker "size" value. Otherwise, we add the list of explicit result types.
+ if (op.getInferredResultTypes())
+ writer.append(kInferTypesMarker);
+ else
+ writer.appendPDLValueList(op.getInputResultTypes());
}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
@@ -955,12 +963,6 @@ void Generator::generate(pdl_interp::GetValueTypeOp op,
writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
}
}
-
-void Generator::generate(pdl_interp::InferredTypesOp op,
- ByteCodeWriter &writer) {
- // InferType maps to a null type as a marker for inferring result types.
- getMemIndex(op.getResult()) = getMemIndex(Type());
-}
void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
}
@@ -1526,30 +1528,31 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
state.addAttribute(name, attr);
}
- for (unsigned i = 0, e = read(); i != e; ++i) {
- if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
- state.types.push_back(read<Type>());
- continue;
- }
-
- // If we find a null range, this signals that the types are infered.
- if (TypeRange *resultTypes = read<TypeRange *>()) {
- state.types.append(resultTypes->begin(), resultTypes->end());
- continue;
- }
-
- // Handle the case where the operation has inferred types.
+ // Read in the result types. If the "size" is the sentinel value, this
+ // indicates that the result types should be inferred.
+ unsigned numResults = read();
+ if (numResults == kInferTypesMarker) {
InferTypeOpInterface::Concept *inferInterface =
state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
+ assert(inferInterface &&
+ "expected operation to provide InferTypeOpInterface");
// TODO: Handle failure.
- state.types.clear();
if (failed(inferInterface->inferReturnTypes(
state.getContext(), state.location, state.operands,
state.attributes.getDictionary(state.getContext()), state.regions,
state.types)))
return;
- break;
+ } else {
+ // Otherwise, this is a fixed number of results.
+ for (unsigned i = 0; i != numResults; ++i) {
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+ state.types.push_back(read<Type>());
+ } else {
+ TypeRange *resultTypes = read<TypeRange *>();
+ state.types.append(resultTypes->begin(), resultTypes->end());
+ }
+ }
}
Operation *resultOp = rewriter.create(state);