aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Rewrite/ByteCode.cpp
diff options
context:
space:
mode:
authorRiver Riddle <riddleriver@gmail.com>2022-09-09 16:31:24 -0700
committerRiver Riddle <riddleriver@gmail.com>2022-11-08 01:57:57 -0800
commitce57789d8e5dc109dc9bd330232b31a22a80ad3a (patch)
tree74ec15ffb3bb119281a226e3573a100ca1c458ba /mlir/lib/Rewrite/ByteCode.cpp
parent8c66344ee9f67f76b3cb6b3345a46345a2d3975a (diff)
downloadllvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.zip
llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.tar.gz
llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.tar.bz2
[mlir:PDL] Add support for creating ranges in rewrites
This commit adds support for building a concatenated range from a given set of elements, either single element or other ranges, within a rewrite. We could conceptually extend this to support constraining input ranges, but the logic there is quite a bit more complex so it is left for later work when a need arises. Differential Revision: https://reviews.llvm.org/D133719
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp127
1 files changed, 102 insertions, 25 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 9cc51da..6b1dfb9 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -99,10 +99,14 @@ enum OpCode : ByteCodeField {
CheckTypes,
/// Continue to the next iteration of a loop.
Continue,
+ /// Create a type range from a list of constant types.
+ CreateConstantTypeRange,
/// Create an operation.
CreateOperation,
- /// Create a range of types.
- CreateTypes,
+ /// Create a type range from a list of dynamic types.
+ CreateDynamicTypeRange,
+ /// Create a value range.
+ CreateDynamicValueRange,
/// Erase an operation.
EraseOp,
/// Extract the op from a range at the specified index.
@@ -265,6 +269,7 @@ private:
void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
+ void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
@@ -742,9 +747,9 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
- pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
- pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
- pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
+ pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
+ pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
+ pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
@@ -863,12 +868,24 @@ void Generator::generate(pdl_interp::CreateOperationOp op,
else
writer.appendPDLValueList(op.getInputResultTypes());
}
+void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
+ // Append the correct opcode for the range type.
+ TypeSwitch<Type>(op.getType().getElementType())
+ .Case(
+ [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
+ .Case([&](pdl::ValueType) {
+ writer.append(OpCode::CreateDynamicValueRange);
+ });
+
+ writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
+ writer.appendPDLValueList(op->getOperands());
+}
void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Simply repoint the memory index of the result to the constant.
getMemIndex(op.getResult()) = getMemIndex(op.getValue());
}
void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
- writer.append(OpCode::CreateTypes, op.getResult(),
+ writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
getRangeStorageIndex(op.getResult()), op.getValue());
}
void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
@@ -1103,9 +1120,11 @@ private:
void executeCheckResultCount();
void executeCheckTypes();
void executeContinue();
+ void executeCreateConstantTypeRange();
void executeCreateOperation(PatternRewriter &rewriter,
Location mainRewriteLoc);
- void executeCreateTypes();
+ template <typename T>
+ void executeDynamicCreateRange(StringRef type);
void executeEraseOp(PatternRewriter &rewriter);
template <typename T, typename Range, PDLValue::Kind kind>
void executeExtract();
@@ -1172,8 +1191,18 @@ private:
}
/// Read a list of values from the bytecode buffer. The values may be encoded
- /// as either Value or ValueRange elements.
- void readValueList(SmallVectorImpl<Value> &list) {
+ /// either as a single element or a range of elements.
+ void readList(SmallVectorImpl<Type> &list) {
+ for (unsigned i = 0, e = read(); i != e; ++i) {
+ if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
+ list.push_back(read<Type>());
+ } else {
+ TypeRange *values = read<TypeRange *>();
+ list.append(values->begin(), values->end());
+ }
+ }
+ }
+ void readList(SmallVectorImpl<Value> &list) {
for (unsigned i = 0, e = read(); i != e; ++i) {
if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
list.push_back(read<Value>());
@@ -1292,6 +1321,39 @@ private:
return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
}
+ /// Assign the given range to the given memory index. This allocates a new
+ /// range object if necessary.
+ template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
+ void assignRangeToMemory(RangeT &&range, unsigned memIndex,
+ unsigned rangeIndex) {
+ // Utility functor used to type-erase the assignment.
+ auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
+ // If the input range is empty, we don't need to allocate anything.
+ if (range.empty()) {
+ rangeMemory[rangeIndex] = {};
+ } else {
+ // Allocate a buffer for this type range.
+ llvm::OwningArrayRef<T> storage(llvm::size(range));
+ llvm::copy(range, storage.begin());
+
+ // Assign this to the range slot and use the range as the value for the
+ // memory index.
+ allocatedRangeMemory.emplace_back(std::move(storage));
+ rangeMemory[rangeIndex] = allocatedRangeMemory.back();
+ }
+ memory[memIndex] = &rangeMemory[rangeIndex];
+ };
+
+ // Dispatch based on the concrete range type.
+ if constexpr (std::is_same_v<T, Type>) {
+ return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
+ } else if constexpr (std::is_same_v<T, Value>) {
+ return assignRange(allocatedValueRangeMemory, valueRangeMemory);
+ } else {
+ llvm_unreachable("unhandled range type");
+ }
+ }
+
/// The underlying bytecode buffer.
const ByteCodeField *curCodeIt;
@@ -1514,23 +1576,15 @@ void ByteCodeExecutor::executeContinue() {
popCodeIt();
}
-void ByteCodeExecutor::executeCreateTypes() {
- LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
+void ByteCodeExecutor::executeCreateConstantTypeRange() {
+ LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
unsigned memIndex = read();
unsigned rangeIndex = read();
ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
-
- // Allocate a buffer for this type range.
- llvm::OwningArrayRef<Type> storage(typesAttr.size());
- llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
- allocatedTypeRangeMemory.emplace_back(std::move(storage));
-
- // Assign this to the range slot and use the range as the value for the
- // memory index.
- typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
- memory[memIndex] = &typeRangeMemory[rangeIndex];
+ assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
+ rangeIndex);
}
void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
@@ -1539,7 +1593,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
unsigned memIndex = read();
OperationState state(mainRewriteLoc, read<OperationName>());
- readValueList(state.operands);
+ readList(state.operands);
for (unsigned i = 0, e = read(); i != e; ++i) {
StringAttr name = read<StringAttr>();
if (Attribute attr = read<Attribute>())
@@ -1587,6 +1641,23 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
});
}
+template <typename T>
+void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
+ LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
+ unsigned memIndex = read();
+ unsigned rangeIndex = read();
+ SmallVector<T> values;
+ readList(values);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n * " << type << "s: ";
+ llvm::interleaveComma(values, llvm::dbgs());
+ llvm::dbgs() << "\n";
+ });
+
+ assignRangeToMemory(values, memIndex, rangeIndex);
+}
+
void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
Operation *op = read<Operation *>();
@@ -1949,7 +2020,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
Operation *op = read<Operation *>();
SmallVector<Value, 16> args;
- readValueList(args);
+ readList(args);
LLVM_DEBUG({
llvm::dbgs() << " * Operation: " << *op << "\n"
@@ -2076,11 +2147,17 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter,
case Continue:
executeContinue();
break;
+ case CreateConstantTypeRange:
+ executeCreateConstantTypeRange();
+ break;
case CreateOperation:
executeCreateOperation(rewriter, *mainRewriteLoc);
break;
- case CreateTypes:
- executeCreateTypes();
+ case CreateDynamicTypeRange:
+ executeDynamicCreateRange<Type>("Type");
+ break;
+ case CreateDynamicValueRange:
+ executeDynamicCreateRange<Value>("Value");
break;
case EraseOp:
executeEraseOp(rewriter);