diff options
author | River Riddle <riddleriver@gmail.com> | 2022-09-09 16:31:24 -0700 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2022-11-08 01:57:57 -0800 |
commit | ce57789d8e5dc109dc9bd330232b31a22a80ad3a (patch) | |
tree | 74ec15ffb3bb119281a226e3573a100ca1c458ba /mlir/lib/Rewrite/ByteCode.cpp | |
parent | 8c66344ee9f67f76b3cb6b3345a46345a2d3975a (diff) | |
download | llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.zip llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.tar.gz llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.tar.bz2 |
[mlir:PDL] Add support for creating ranges in rewrites
This commit adds support for building a concatenated range from
a given set of elements, either single element or other ranges, within a
rewrite. We could conceptually extend this to support constraining
input ranges, but the logic there is quite a bit more complex so it is
left for later work when a need arises.
Differential Revision: https://reviews.llvm.org/D133719
Diffstat (limited to 'mlir/lib/Rewrite/ByteCode.cpp')
-rw-r--r-- | mlir/lib/Rewrite/ByteCode.cpp | 127 |
1 files changed, 102 insertions, 25 deletions
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 9cc51da..6b1dfb9 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -99,10 +99,14 @@ enum OpCode : ByteCodeField { CheckTypes, /// Continue to the next iteration of a loop. Continue, + /// Create a type range from a list of constant types. + CreateConstantTypeRange, /// Create an operation. CreateOperation, - /// Create a range of types. - CreateTypes, + /// Create a type range from a list of dynamic types. + CreateDynamicTypeRange, + /// Create a value range. + CreateDynamicValueRange, /// Erase an operation. EraseOp, /// Extract the op from a range at the specified index. @@ -265,6 +269,7 @@ private: void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); @@ -742,9 +747,9 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) { pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, - pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, - pdl_interp::CreateTypesOp, pdl_interp::EraseOp, - pdl_interp::ExtractOp, pdl_interp::FinalizeOp, + pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp, + pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, + pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp, pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, @@ -863,12 +868,24 @@ void Generator::generate(pdl_interp::CreateOperationOp op, else writer.appendPDLValueList(op.getInputResultTypes()); } +void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) { + // Append the correct opcode for the range type. + TypeSwitch<Type>(op.getType().getElementType()) + .Case( + [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); }) + .Case([&](pdl::ValueType) { + writer.append(OpCode::CreateDynamicValueRange); + }); + + writer.append(op.getResult(), getRangeStorageIndex(op.getResult())); + writer.appendPDLValueList(op->getOperands()); +} void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.getResult()) = getMemIndex(op.getValue()); } void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { - writer.append(OpCode::CreateTypes, op.getResult(), + writer.append(OpCode::CreateConstantTypeRange, op.getResult(), getRangeStorageIndex(op.getResult()), op.getValue()); } void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { @@ -1103,9 +1120,11 @@ private: void executeCheckResultCount(); void executeCheckTypes(); void executeContinue(); + void executeCreateConstantTypeRange(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); - void executeCreateTypes(); + template <typename T> + void executeDynamicCreateRange(StringRef type); void executeEraseOp(PatternRewriter &rewriter); template <typename T, typename Range, PDLValue::Kind kind> void executeExtract(); @@ -1172,8 +1191,18 @@ private: } /// Read a list of values from the bytecode buffer. The values may be encoded - /// as either Value or ValueRange elements. - void readValueList(SmallVectorImpl<Value> &list) { + /// either as a single element or a range of elements. + void readList(SmallVectorImpl<Type> &list) { + for (unsigned i = 0, e = read(); i != e; ++i) { + if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { + list.push_back(read<Type>()); + } else { + TypeRange *values = read<TypeRange *>(); + list.append(values->begin(), values->end()); + } + } + } + void readList(SmallVectorImpl<Value> &list) { for (unsigned i = 0, e = read(); i != e; ++i) { if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { list.push_back(read<Value>()); @@ -1292,6 +1321,39 @@ private: return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); } + /// Assign the given range to the given memory index. This allocates a new + /// range object if necessary. + template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>> + void assignRangeToMemory(RangeT &&range, unsigned memIndex, + unsigned rangeIndex) { + // Utility functor used to type-erase the assignment. + auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) { + // If the input range is empty, we don't need to allocate anything. + if (range.empty()) { + rangeMemory[rangeIndex] = {}; + } else { + // Allocate a buffer for this type range. + llvm::OwningArrayRef<T> storage(llvm::size(range)); + llvm::copy(range, storage.begin()); + + // Assign this to the range slot and use the range as the value for the + // memory index. + allocatedRangeMemory.emplace_back(std::move(storage)); + rangeMemory[rangeIndex] = allocatedRangeMemory.back(); + } + memory[memIndex] = &rangeMemory[rangeIndex]; + }; + + // Dispatch based on the concrete range type. + if constexpr (std::is_same_v<T, Type>) { + return assignRange(allocatedTypeRangeMemory, typeRangeMemory); + } else if constexpr (std::is_same_v<T, Value>) { + return assignRange(allocatedValueRangeMemory, valueRangeMemory); + } else { + llvm_unreachable("unhandled range type"); + } + } + /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; @@ -1514,23 +1576,15 @@ void ByteCodeExecutor::executeContinue() { popCodeIt(); } -void ByteCodeExecutor::executeCreateTypes() { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); +void ByteCodeExecutor::executeCreateConstantTypeRange() { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); unsigned memIndex = read(); unsigned rangeIndex = read(); ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); - - // Allocate a buffer for this type range. - llvm::OwningArrayRef<Type> storage(typesAttr.size()); - llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); - allocatedTypeRangeMemory.emplace_back(std::move(storage)); - - // Assign this to the range slot and use the range as the value for the - // memory index. - typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); - memory[memIndex] = &typeRangeMemory[rangeIndex]; + assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex, + rangeIndex); } void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, @@ -1539,7 +1593,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, unsigned memIndex = read(); OperationState state(mainRewriteLoc, read<OperationName>()); - readValueList(state.operands); + readList(state.operands); for (unsigned i = 0, e = read(); i != e; ++i) { StringAttr name = read<StringAttr>(); if (Attribute attr = read<Attribute>()) @@ -1587,6 +1641,23 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, }); } +template <typename T> +void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + SmallVector<T> values; + readList(values); + + LLVM_DEBUG({ + llvm::dbgs() << "\n * " << type << "s: "; + llvm::interleaveComma(values, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + assignRangeToMemory(values, memIndex, rangeIndex); +} + void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); Operation *op = read<Operation *>(); @@ -1949,7 +2020,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); Operation *op = read<Operation *>(); SmallVector<Value, 16> args; - readValueList(args); + readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Operation: " << *op << "\n" @@ -2076,11 +2147,17 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, case Continue: executeContinue(); break; + case CreateConstantTypeRange: + executeCreateConstantTypeRange(); + break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; - case CreateTypes: - executeCreateTypes(); + case CreateDynamicTypeRange: + executeDynamicCreateRange<Type>("Type"); + break; + case CreateDynamicValueRange: + executeDynamicCreateRange<Value>("Value"); break; case EraseOp: executeEraseOp(rewriter); |