diff options
author | River Riddle <riddleriver@gmail.com> | 2022-09-09 16:31:24 -0700 |
---|---|---|
committer | River Riddle <riddleriver@gmail.com> | 2022-11-08 01:57:57 -0800 |
commit | ce57789d8e5dc109dc9bd330232b31a22a80ad3a (patch) | |
tree | 74ec15ffb3bb119281a226e3573a100ca1c458ba | |
parent | 8c66344ee9f67f76b3cb6b3345a46345a2d3975a (diff) | |
download | llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.zip llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.tar.gz llvm-ce57789d8e5dc109dc9bd330232b31a22a80ad3a.tar.bz2 |
[mlir:PDL] Add support for creating ranges in rewrites
This commit adds support for building a concatenated range from
a given set of elements, either single element or other ranges, within a
rewrite. We could conceptually extend this to support constraining
input ranges, but the logic there is quite a bit more complex so it is
left for later work when a need arises.
Differential Revision: https://reviews.llvm.org/D133719
-rw-r--r-- | mlir/include/mlir/Dialect/PDL/IR/PDLOps.td | 42 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h | 5 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td | 37 | ||||
-rw-r--r-- | mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp | 17 | ||||
-rw-r--r-- | mlir/lib/Dialect/PDL/IR/PDL.cpp | 33 | ||||
-rw-r--r-- | mlir/lib/Dialect/PDL/IR/PDLTypes.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp | 34 | ||||
-rw-r--r-- | mlir/lib/Rewrite/ByteCode.cpp | 127 | ||||
-rw-r--r-- | mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir | 17 | ||||
-rw-r--r-- | mlir/test/Dialect/PDL/invalid.mlir | 17 | ||||
-rw-r--r-- | mlir/test/Dialect/PDLInterp/invalid.mlir | 14 | ||||
-rw-r--r-- | mlir/test/Rewrite/pdl-bytecode.mlir | 42 |
12 files changed, 363 insertions, 28 deletions
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index fbe991a..c85687e 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -437,6 +437,48 @@ def PDL_PatternOp : PDL_Op<"pattern", [ } //===----------------------------------------------------------------------===// +// pdl::RangeOp +//===----------------------------------------------------------------------===// + +def PDL_RangeOp : PDL_Op<"range", [Pure, HasParent<"pdl::RewriteOp">]> { + let summary = "Construct a range of pdl entities"; + let description = [{ + `pdl.range` operations construct a range from a given set of PDL entities, + which all share the same underlying element type. For example, a + `!pdl.range<value>` may be constructed from a list of `!pdl.value` + or `!pdl.range<value>` entities. + + Example: + + ```mlir + // Construct a range of values. + %valueRange = pdl.range %inputValue, %inputRange : !pdl.value, !pdl.range<value> + + // Construct a range of types. + %typeRange = pdl.range %inputType, %inputRange : !pdl.type, !pdl.range<type> + + // Construct an empty range of types. + %valueRange = pdl.range : !pdl.range<type> + ``` + + TODO: Range construction is currently limited to rewrites, but it could + be extended to constraints under certain circustances; i.e., if we can + determine how to extract the underlying elements. If we can't, e.g. if + there are multiple sub ranges used for construction, we won't be able + to determine their sizes during constraint time. + }]; + + let arguments = (ins Variadic<PDL_AnyType>:$arguments); + let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result); + let assemblyFormat = [{ + ($arguments^ `:` type($arguments))? + custom<RangeType>(ref(type($arguments)), type($result)) + attr-dict + }]; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h index 8cbe31f..a342dcc 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h @@ -28,6 +28,11 @@ public: static bool classof(Type type); }; + +/// If the given type is a range, return its element type, otherwise return +/// the type itself. +Type getRangeElementTypeOrSelf(Type type); + } // namespace pdl } // namespace mlir diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 659bfbc..96d631b 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -992,6 +992,43 @@ def PDLInterp_IsNotNullOp let assemblyFormat = "$value `:` type($value) attr-dict `->` successors"; } + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateRangeOp +//===----------------------------------------------------------------------===// + +def PDLInterp_CreateRangeOp : PDLInterp_Op<"create_range", [Pure]> { + let summary = "Construct a range of PDL entities"; + let description = [{ + `pdl_interp.create_range` operations construct a range from a given set of PDL + entities, which all share the same underlying element type. For example, a + `!pdl.range<value>` may be constructed from a list of `!pdl.value` + or `!pdl.range<value>` entities. + + Example: + + ```mlir + // Construct a range of values. + %valueRange = pdl_interp.create_range %inputValue, %inputRange : !pdl.value, !pdl.range<value> + + // Construct a range of types. + %typeRange = pdl_interp.create_range %inputType, %inputRange : !pdl.type, !pdl.range<type> + + // Construct an empty range of types. + %valueRange = pdl_interp.create_range : !pdl.range<type> + ``` + }]; + + let arguments = (ins Variadic<PDL_AnyType>:$arguments); + let results = (outs PDL_RangeOf<AnyTypeOf<[PDL_Type, PDL_Value]>>:$result); + let assemblyFormat = [{ + ($arguments^ `:` type($arguments))? + custom<RangeType>(ref(type($arguments)), type($result)) + attr-dict + }]; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // pdl_interp::RecordMatchOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index 987e7a3..fdc95ab 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -89,6 +89,9 @@ private: void generateRewriter(pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues, function_ref<Value(Value)> mapRewriteValue); + void generateRewriter(pdl::RangeOp rangeOp, + DenseMap<Value, Value> &rewriteValues, + function_ref<Value(Value)> mapRewriteValue); void generateRewriter(pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, function_ref<Value(Value)> mapRewriteValue); @@ -668,8 +671,8 @@ SymbolRefAttr PatternLowering::generateRewriter( for (Operation &rewriteOp : *rewriter.getBody()) { llvm::TypeSwitch<Operation *>(&rewriteOp) .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp, - pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp, - pdl::TypeOp, pdl::TypesOp>([&](auto op) { + pdl::OperationOp, pdl::RangeOp, pdl::ReplaceOp, pdl::ResultOp, + pdl::ResultsOp, pdl::TypeOp, pdl::TypesOp>([&](auto op) { this->generateRewriter(op, rewriteValues, mapRewriteValue); }); } @@ -776,6 +779,16 @@ void PatternLowering::generateRewriter( } void PatternLowering::generateRewriter( + pdl::RangeOp rangeOp, DenseMap<Value, Value> &rewriteValues, + function_ref<Value(Value)> mapRewriteValue) { + SmallVector<Value, 4> replOperands; + for (Value operand : rangeOp.getArguments()) + replOperands.push_back(mapRewriteValue(operand)); + rewriteValues[rangeOp] = builder.create<pdl_interp::CreateRangeOp>( + rangeOp.getLoc(), rangeOp.getType(), replOperands); +} + +void PatternLowering::generateRewriter( pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues, function_ref<Value(Value)> mapRewriteValue) { SmallVector<Value, 4> replOperands; diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index b96f34b..e33ba71 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -398,6 +398,39 @@ StringRef PatternOp::getDefaultDialect() { } //===----------------------------------------------------------------------===// +// pdl::RangeOp +//===----------------------------------------------------------------------===// + +static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, + Type &resultType) { + // If arguments were provided, infer the result type from the argument list. + if (!argumentTypes.empty()) { + resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0])); + return success(); + } + // Otherwise, parse the type as a trailing type. + return p.parseColonType(resultType); +} + +static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes, + Type resultType) { + if (argumentTypes.empty()) + p << ": " << resultType; +} + +LogicalResult RangeOp::verify() { + Type elementType = getType().getElementType(); + for (Type operandType : getOperandTypes()) { + Type operandElementType = getRangeElementTypeOrSelf(operandType); + if (operandElementType != elementType) { + return emitOpError("expected operand to have element type ") + << elementType << ", but got " << operandElementType; + } + } + return success(); +} + +//===----------------------------------------------------------------------===// // pdl::ReplaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp index f4dbbfb..49eee1a 100644 --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -59,6 +59,12 @@ bool PDLType::classof(Type type) { return llvm::isa<PDLDialect>(type.getDialect()); } +Type pdl::getRangeElementTypeOrSelf(Type type) { + if (auto rangeType = type.dyn_cast<RangeType>()) + return rangeType.getElementType(); + return type; +} + //===----------------------------------------------------------------------===// // RangeType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp index 01670e3..e8a61ef 100644 --- a/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp +++ b/mlir/lib/Dialect/PDLInterp/IR/PDLInterp.cpp @@ -238,6 +238,40 @@ static Type getGetValueTypeOpValueType(Type type) { } //===----------------------------------------------------------------------===// +// pdl::CreateRangeOp +//===----------------------------------------------------------------------===// + +static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, + Type &resultType) { + // If arguments were provided, infer the result type from the argument list. + if (!argumentTypes.empty()) { + resultType = + pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0])); + return success(); + } + // Otherwise, parse the type as a trailing type. + return p.parseColonType(resultType); +} + +static void printRangeType(OpAsmPrinter &p, CreateRangeOp op, + TypeRange argumentTypes, Type resultType) { + if (argumentTypes.empty()) + p << ": " << resultType; +} + +LogicalResult CreateRangeOp::verify() { + Type elementType = getType().getElementType(); + for (Type operandType : getOperandTypes()) { + Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType); + if (operandElementType != elementType) { + return emitOpError("expected operand to have element type ") + << elementType << ", but got " << operandElementType; + } + } + return success(); +} + +//===----------------------------------------------------------------------===// // pdl_interp::SwitchAttributeOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 9cc51da..6b1dfb9 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -99,10 +99,14 @@ enum OpCode : ByteCodeField { CheckTypes, /// Continue to the next iteration of a loop. Continue, + /// Create a type range from a list of constant types. + CreateConstantTypeRange, /// Create an operation. CreateOperation, - /// Create a range of types. - CreateTypes, + /// Create a type range from a list of dynamic types. + CreateDynamicTypeRange, + /// Create a value range. + CreateDynamicValueRange, /// Erase an operation. EraseOp, /// Extract the op from a range at the specified index. @@ -265,6 +269,7 @@ private: void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); @@ -742,9 +747,9 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) { pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, - pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, - pdl_interp::CreateTypesOp, pdl_interp::EraseOp, - pdl_interp::ExtractOp, pdl_interp::FinalizeOp, + pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp, + pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, + pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp, pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, @@ -863,12 +868,24 @@ void Generator::generate(pdl_interp::CreateOperationOp op, else writer.appendPDLValueList(op.getInputResultTypes()); } +void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) { + // Append the correct opcode for the range type. + TypeSwitch<Type>(op.getType().getElementType()) + .Case( + [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); }) + .Case([&](pdl::ValueType) { + writer.append(OpCode::CreateDynamicValueRange); + }); + + writer.append(op.getResult(), getRangeStorageIndex(op.getResult())); + writer.appendPDLValueList(op->getOperands()); +} void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.getResult()) = getMemIndex(op.getValue()); } void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { - writer.append(OpCode::CreateTypes, op.getResult(), + writer.append(OpCode::CreateConstantTypeRange, op.getResult(), getRangeStorageIndex(op.getResult()), op.getValue()); } void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { @@ -1103,9 +1120,11 @@ private: void executeCheckResultCount(); void executeCheckTypes(); void executeContinue(); + void executeCreateConstantTypeRange(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); - void executeCreateTypes(); + template <typename T> + void executeDynamicCreateRange(StringRef type); void executeEraseOp(PatternRewriter &rewriter); template <typename T, typename Range, PDLValue::Kind kind> void executeExtract(); @@ -1172,8 +1191,18 @@ private: } /// Read a list of values from the bytecode buffer. The values may be encoded - /// as either Value or ValueRange elements. - void readValueList(SmallVectorImpl<Value> &list) { + /// either as a single element or a range of elements. + void readList(SmallVectorImpl<Type> &list) { + for (unsigned i = 0, e = read(); i != e; ++i) { + if (read<PDLValue::Kind>() == PDLValue::Kind::Type) { + list.push_back(read<Type>()); + } else { + TypeRange *values = read<TypeRange *>(); + list.append(values->begin(), values->end()); + } + } + } + void readList(SmallVectorImpl<Value> &list) { for (unsigned i = 0, e = read(); i != e; ++i) { if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { list.push_back(read<Value>()); @@ -1292,6 +1321,39 @@ private: return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>()); } + /// Assign the given range to the given memory index. This allocates a new + /// range object if necessary. + template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>> + void assignRangeToMemory(RangeT &&range, unsigned memIndex, + unsigned rangeIndex) { + // Utility functor used to type-erase the assignment. + auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) { + // If the input range is empty, we don't need to allocate anything. + if (range.empty()) { + rangeMemory[rangeIndex] = {}; + } else { + // Allocate a buffer for this type range. + llvm::OwningArrayRef<T> storage(llvm::size(range)); + llvm::copy(range, storage.begin()); + + // Assign this to the range slot and use the range as the value for the + // memory index. + allocatedRangeMemory.emplace_back(std::move(storage)); + rangeMemory[rangeIndex] = allocatedRangeMemory.back(); + } + memory[memIndex] = &rangeMemory[rangeIndex]; + }; + + // Dispatch based on the concrete range type. + if constexpr (std::is_same_v<T, Type>) { + return assignRange(allocatedTypeRangeMemory, typeRangeMemory); + } else if constexpr (std::is_same_v<T, Value>) { + return assignRange(allocatedValueRangeMemory, valueRangeMemory); + } else { + llvm_unreachable("unhandled range type"); + } + } + /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; @@ -1514,23 +1576,15 @@ void ByteCodeExecutor::executeContinue() { popCodeIt(); } -void ByteCodeExecutor::executeCreateTypes() { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); +void ByteCodeExecutor::executeCreateConstantTypeRange() { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); unsigned memIndex = read(); unsigned rangeIndex = read(); ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>(); LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); - - // Allocate a buffer for this type range. - llvm::OwningArrayRef<Type> storage(typesAttr.size()); - llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin()); - allocatedTypeRangeMemory.emplace_back(std::move(storage)); - - // Assign this to the range slot and use the range as the value for the - // memory index. - typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); - memory[memIndex] = &typeRangeMemory[rangeIndex]; + assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex, + rangeIndex); } void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, @@ -1539,7 +1593,7 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, unsigned memIndex = read(); OperationState state(mainRewriteLoc, read<OperationName>()); - readValueList(state.operands); + readList(state.operands); for (unsigned i = 0, e = read(); i != e; ++i) { StringAttr name = read<StringAttr>(); if (Attribute attr = read<Attribute>()) @@ -1587,6 +1641,23 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, }); } +template <typename T> +void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + SmallVector<T> values; + readList(values); + + LLVM_DEBUG({ + llvm::dbgs() << "\n * " << type << "s: "; + llvm::interleaveComma(values, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + assignRangeToMemory(values, memIndex, rangeIndex); +} + void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); Operation *op = read<Operation *>(); @@ -1949,7 +2020,7 @@ void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); Operation *op = read<Operation *>(); SmallVector<Value, 16> args; - readValueList(args); + readList(args); LLVM_DEBUG({ llvm::dbgs() << " * Operation: " << *op << "\n" @@ -2076,11 +2147,17 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, case Continue: executeContinue(); break; + case CreateConstantTypeRange: + executeCreateConstantTypeRange(); + break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; - case CreateTypes: - executeCreateTypes(); + case CreateDynamicTypeRange: + executeDynamicCreateRange<Type>("Type"); + break; + case CreateDynamicValueRange: + executeDynamicCreateRange<Value>("Value"); break; case EraseOp: executeEraseOp(rewriter); diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir index d6e8f4a..e5a84d6 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-rewriter.mlir @@ -243,3 +243,20 @@ module @unbound_rewrite_op { } // ----- + +// CHECK-LABEL: module @range_op +module @range_op { + // CHECK: module @rewriters + // CHECK: func @pdl_generated_rewriter(%[[OPERAND:.*]]: !pdl.value) + // CHECK: %[[RANGE1:.*]] = pdl_interp.create_range : !pdl.range<value> + // CHECK: %[[RANGE2:.*]] = pdl_interp.create_range %[[OPERAND]], %[[RANGE1]] : !pdl.value, !pdl.range<value> + // CHECK: pdl_interp.finalize + pdl.pattern : benefit(1) { + %operand = pdl.operand + %root = operation "foo.op"(%operand : !pdl.value) + rewrite %root { + %emptyRange = pdl.range : !pdl.range<value> + %range = pdl.range %operand, %emptyRange : !pdl.value, !pdl.range<value> + } + } +} diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir index 61c0aae..522e9fb 100644 --- a/mlir/test/Dialect/PDL/invalid.mlir +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -238,6 +238,23 @@ pdl.pattern : benefit(1) { // ----- //===----------------------------------------------------------------------===// +// pdl::RangeOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %operand = pdl.operand + %resultType = pdl.type + %root = pdl.operation "baz.op"(%operand : !pdl.value) -> (%resultType : !pdl.type) + + rewrite %root { + // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}} + %range = pdl.range %operand, %resultType : !pdl.value, !pdl.type + } +} + +// ----- + +//===----------------------------------------------------------------------===// // pdl::ResultsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/PDLInterp/invalid.mlir b/mlir/test/Dialect/PDLInterp/invalid.mlir index f194d32..0457a15 100644 --- a/mlir/test/Dialect/PDLInterp/invalid.mlir +++ b/mlir/test/Dialect/PDLInterp/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics //===----------------------------------------------------------------------===// -// pdl::CreateOperationOp +// pdl_interp::CreateOperationOp //===----------------------------------------------------------------------===// pdl_interp.func @rewriter() { @@ -23,3 +23,15 @@ pdl_interp.func @rewriter() { } : (!pdl.type) -> (!pdl.operation) pdl_interp.finalize } + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateRangeOp +//===----------------------------------------------------------------------===// + +pdl_interp.func @rewriter(%value: !pdl.value, %type: !pdl.type) { + // expected-error @below {{expected operand to have element type '!pdl.value', but got '!pdl.type'}} + %range = pdl_interp.create_range %value, %type : !pdl.value, !pdl.type + pdl_interp.finalize +} diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 20e2490..565874f 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -569,6 +569,48 @@ module @ir attributes { test.create_op_infer_results } { // ----- //===----------------------------------------------------------------------===// +// pdl_interp::CreateRangeOp +//===----------------------------------------------------------------------===// + +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operand_count of %root is 2 -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root: !pdl.operation) { + %rootOperand = pdl_interp.get_operand 0 of %root + %rootOperands = pdl_interp.get_operands of %root : !pdl.range<value> + %operandRange = pdl_interp.create_range %rootOperand, %rootOperands : !pdl.value, !pdl.range<value> + + %operandType = pdl_interp.get_value_type of %rootOperand : !pdl.type + %operandTypes = pdl_interp.get_value_type of %rootOperands : !pdl.range<type> + %typeRange = pdl_interp.create_range %operandType, %operandTypes : !pdl.type, !pdl.range<type> + + %op = pdl_interp.create_operation "test.success"(%operandRange : !pdl.range<value>) -> (%typeRange : !pdl.range<type>) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.create_range_1 +// CHECK: %[[INPUTS:.*]]:2 = "test.input"() +// CHECK: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#0, %[[INPUTS]]#1) : (i32, i32, i32) -> (i32, i32, i32) +module @ir attributes { test.create_range_1 } { + %values:2 = "test.input"() : () -> (i32, i32) + "test.op"(%values#0, %values#1) : (i32, i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// // pdl_interp::CreateTypeOp //===----------------------------------------------------------------------===// |