diff options
18 files changed, 555 insertions, 98 deletions
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td index 4e9ebcc..1e108c3 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp let description = [{ `pdl.apply_native_constraint` operations apply a native C++ constraint, that has been registered externally with the consumer of PDL, to a given set of - entities. + entities and optionally return a number of values. Example: ```mlir // Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`. pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) + // Apply constraint `with_result` to `root`. This constraint returns an attribute. + %attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute ``` }]; let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args, DefaultValuedAttr<BoolAttr, "false">:$isNegated); - let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict"; + let results = (outs Variadic<PDL_AnyType>:$results); + let assemblyFormat = [{ + $name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict + }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td index 48f625b..901acc0 100644 --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { let description = [{ `pdl_interp.apply_constraint` operations apply a generic constraint, that has been registered with the interpreter, with a given set of positional - values. On success, this operation branches to the true destination, + values. + The constraint function may return any number of results. + On success, this operation branches to the true destination, otherwise the false destination is taken. This behavior can be reversed by setting the attribute `isNegated` to true. @@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> { let arguments = (ins StrAttr:$name, Variadic<PDL_AnyType>:$args, DefaultValuedAttr<BoolAttr, "false">:$isNegated); + let results = (outs Variadic<PDL_AnyType>:$results); let assemblyFormat = [{ - $name `(` $args `:` type($args) `)` attr-dict `->` successors + $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict + `->` successors }]; } diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc index a215da8..66286ed 100644 --- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc +++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc @@ -318,8 +318,9 @@ protected: /// A generic PDL pattern constraint function. This function applies a /// constraint to a given set of opaque PDLValue entities. Returns success if /// the constraint successfully held, failure otherwise. -using PDLConstraintFunction = - std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>; +using PDLConstraintFunction = std::function<LogicalResult( + PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; + /// A native PDL rewrite function. This function performs a rewrite on the /// given set of values. Any results from this rewrite that should be passed /// back to PDL should be added to the provided result list. This method is only @@ -726,7 +727,7 @@ std::enable_if_t< PDLConstraintFunction> buildConstraintFn(ConstraintFnT &&constraintFn) { return [constraintFn = std::forward<ConstraintFnT>(constraintFn)]( - PatternRewriter &rewriter, + PatternRewriter &rewriter, PDLResultList &, ArrayRef<PDLValue> values) -> LogicalResult { auto argIndices = std::make_index_sequence< llvm::function_traits<ConstraintFnT>::num_args - 1>(); @@ -842,10 +843,13 @@ public: /// Register a constraint function with PDL. A constraint function may be /// specified in one of two ways: /// - /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)` + /// * `LogicalResult (PatternRewriter &, + /// PDLResultList &, + /// ArrayRef<PDLValue>)` /// /// In this overload the arguments of the constraint function are passed via - /// the low-level PDLValue form. + /// the low-level PDLValue form, and the results are manually appended to + /// the given result list. /// /// * `LogicalResult (PatternRewriter &, ValueTs... values)` /// @@ -960,8 +964,8 @@ public: } }; class PDLResultList {}; -using PDLConstraintFunction = - std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>; +using PDLConstraintFunction = std::function<LogicalResult( + PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; using PDLRewriteFunction = std::function<LogicalResult( PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index e911631..b00cd0d 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -50,7 +50,8 @@ private: /// Generate interpreter operations for the tree rooted at the given matcher /// node, in the specified region. - Block *generateMatcher(MatcherNode &node, Region ®ion); + Block *generateMatcher(MatcherNode &node, Region ®ion, + Block *block = nullptr); /// Get or create an access to the provided positional value in the current /// block. This operation may mutate the provided block pointer if nested @@ -148,6 +149,10 @@ private: /// A mapping between pattern operations and the corresponding configuration /// set. DenseMap<Operation *, PDLPatternConfigSet *> *configMap; + + /// A mapping from a constraint question to the ApplyConstraintOp + /// that implements it. + DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap; }; } // namespace @@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) { firstMatcherBlock->erase(); } -Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) { +Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion, + Block *block) { // Push a new scope for the values used by this matcher. - Block *block = ®ion.emplaceBlock(); + if (!block) + block = ®ion.emplaceBlock(); ValueMapScope scope(values); // If this is the return node, simply insert the corresponding interpreter @@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) { loc, cast<ArrayAttr>(rawTypeAttr)); break; } + case Predicates::ConstraintResultPos: { + // Due to the order of traversal, the ApplyConstraintOp has already been + // created and we can find it in constraintOpMap. + auto *constrResPos = cast<ConstraintPosition>(pos); + auto i = constraintOpMap.find(constrResPos->getQuestion()); + assert(i != constraintOpMap.end()); + value = i->second->getResult(constrResPos->getIndex()); + break; + } default: llvm_unreachable("Generating unknown Position getter"); break; @@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, args.push_back(getValueAt(currentBlock, position)); } - // Generate the matcher in the current (potentially nested) region - // and get the failure successor. - Block *success = generateMatcher(*boolNode->getSuccessNode(), *region); + // Generate a new block as success successor and get the failure successor. + Block *success = ®ion->emplaceBlock(); Block *failure = failureBlockStack.back(); - // Finally, create the predicate. + // Create the predicate. builder.setInsertionPointToEnd(currentBlock); Predicates::Kind kind = question->getKind(); switch (kind) { @@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock, } case Predicates::ConstraintQuestion: { auto *cstQuestion = cast<ConstraintQuestion>(question); - builder.create<pdl_interp::ApplyConstraintOp>( - loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success, - failure); + auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>( + loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args, + cstQuestion->getIsNegated(), success, failure); + + constraintOpMap.insert({cstQuestion, applyConstraintOp}); break; } default: llvm_unreachable("Generating unknown Predicate operation"); } + + // Generate the matcher in the current (potentially nested) region. + // This might use the results of the current predicate. + generateMatcher(*boolNode->getSuccessNode(), *region, success); } template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy> diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h index 2c9b63f..5ad2c47 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h @@ -47,6 +47,7 @@ enum Kind : unsigned { OperandPos, OperandGroupPos, AttributePos, + ConstraintResultPos, ResultPos, ResultGroupPos, TypePos, @@ -280,6 +281,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position, }; //===----------------------------------------------------------------------===// +// ConstraintPosition + +struct ConstraintQuestion; + +/// A position describing the result of a native constraint. It saves the +/// corresponding ConstraintQuestion and result index to enable referring +/// back to them +struct ConstraintPosition + : public PredicateBase<ConstraintPosition, Position, + std::pair<ConstraintQuestion *, unsigned>, + Predicates::ConstraintResultPos> { + using PredicateBase::PredicateBase; + + /// Returns the ConstraintQuestion to enable keeping track of the native + /// constraint this position stems from. + ConstraintQuestion *getQuestion() const { return key.first; } + + // Returns the result index of this position + unsigned getIndex() const { return key.second; } +}; + +//===----------------------------------------------------------------------===// // ResultPosition /// A position describing a result of an operation. @@ -447,11 +470,13 @@ struct AttributeQuestion : public PredicateBase<AttributeQuestion, Qualifier, void, Predicates::AttributeQuestion> {}; -/// Apply a parameterized constraint to multiple position values. +/// Apply a parameterized constraint to multiple position values and possibly +/// produce results. struct ConstraintQuestion - : public PredicateBase<ConstraintQuestion, Qualifier, - std::tuple<StringRef, ArrayRef<Position *>, bool>, - Predicates::ConstraintQuestion> { + : public PredicateBase< + ConstraintQuestion, Qualifier, + std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>, + Predicates::ConstraintQuestion> { using Base::Base; /// Return the name of the constraint. @@ -460,15 +485,19 @@ struct ConstraintQuestion /// Return the arguments of the constraint. ArrayRef<Position *> getArgs() const { return std::get<1>(key); } + /// Return the result types of the constraint. + ArrayRef<Type> getResultTypes() const { return std::get<2>(key); } + /// Return the negation status of the constraint. - bool getIsNegated() const { return std::get<2>(key); } + bool getIsNegated() const { return std::get<3>(key); } /// Construct an instance with the given storage allocator. static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, KeyTy key) { return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), alloc.copyInto(std::get<1>(key)), - std::get<2>(key)}); + alloc.copyInto(std::get<2>(key)), + std::get<3>(key)}); } /// Returns a hash suitable for the given keytype. @@ -526,6 +555,7 @@ public: // Register the types of Positions with the uniquer. registerParametricStorageType<AttributePosition>(); registerParametricStorageType<AttributeLiteralPosition>(); + registerParametricStorageType<ConstraintPosition>(); registerParametricStorageType<ForEachPosition>(); registerParametricStorageType<OperandPosition>(); registerParametricStorageType<OperandGroupPosition>(); @@ -588,6 +618,12 @@ public: return OperationPosition::get(uniquer, p); } + // Returns a position for a new value created by a constraint. + ConstraintPosition *getConstraintPosition(ConstraintQuestion *q, + unsigned index) { + return ConstraintPosition::get(uniquer, std::make_pair(q, index)); + } + /// Returns an attribute position for an attribute of the given operation. Position *getAttribute(OperationPosition *p, StringRef name) { return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); @@ -673,11 +709,11 @@ public: } /// Create a predicate that applies a generic constraint. - Predicate getConstraint(StringRef name, ArrayRef<Position *> pos, - bool isNegated) { - return { - ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)), - TrueAnswer::get(uniquer)}; + Predicate getConstraint(StringRef name, ArrayRef<Position *> args, + ArrayRef<Type> resultTypes, bool isNegated) { + return {ConstraintQuestion::get( + uniquer, std::make_tuple(name, args, resultTypes, isNegated)), + TrueAnswer::get(uniquer)}; } /// Create a predicate comparing a value with null. diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index a9c3b0a..f3d0e08 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include <queue> @@ -49,14 +50,15 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList, DenseMap<Value, Position *> &inputs, AttributePosition *pos) { assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type"); - pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp()); predList.emplace_back(pos, builder.getIsNotNull()); - // If the attribute has a type or value, add a constraint. - if (Value type = attr.getValueType()) - getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); - else if (Attribute value = attr.getValueAttr()) - predList.emplace_back(pos, builder.getAttributeConstraint(value)); + if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { + // If the attribute has a type or value, add a constraint. + if (Value type = attr.getValueType()) + getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); + else if (Attribute value = attr.getValueAttr()) + predList.emplace_back(pos, builder.getAttributeConstraint(value)); + } } /// Collect all of the predicates for the given operand position. @@ -272,8 +274,25 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op, // Push the constraint to the furthest position. Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), comparePosDepth); - PredicateBuilder::Predicate pred = - builder.getConstraint(op.getName(), allPositions, op.getIsNegated()); + ResultRange results = op.getResults(); + PredicateBuilder::Predicate pred = builder.getConstraint( + op.getName(), allPositions, SmallVector<Type>(results.getTypes()), + op.getIsNegated()); + + // For each result register a position so it can be used later + for (auto [i, result] : llvm::enumerate(results)) { + ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first); + ConstraintPosition *pos = builder.getConstraintPosition(q, i); + auto [it, inserted] = inputs.insert({result, pos}); + // If this is an input value that has been visited in the tree, add a + // constraint to ensure that both instances refer to the same value. + if (!inserted) { + auto minMaxPositions = + std::minmax<Position *>(pos, it->second, comparePosDepth); + predList.emplace_back(minMaxPositions.second, + builder.getEqualTo(minMaxPositions.first)); + } + } predList.emplace_back(pos, pred); } @@ -875,6 +894,49 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) { *root = std::make_unique<ExitNode>(); } +/// Sorts the range begin/end with the partial order given by cmp. +template <typename Iterator, typename Compare> +static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) { + while (begin != end) { + // Cannot compute sortBeforeOthers in the predicate of stable_partition + // because stable_partition will not keep the [begin, end) range intact + // while it runs. + llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers; + for (auto i = begin; i != end; ++i) { + if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); })) + sortBeforeOthers.insert(*i); + } + + auto const next = std::stable_partition(begin, end, [&](auto const &a) { + return sortBeforeOthers.contains(a); + }); + assert(next != begin && "not a partial ordering"); + begin = next; + } +} + +/// Returns true if 'b' depends on a result of 'a'. +static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) { + auto *cqa = dyn_cast<ConstraintQuestion>(a->question); + if (!cqa) + return false; + + auto positionDependsOnA = [&](Position *p) { + auto *cp = dyn_cast<ConstraintPosition>(p); + return cp && cp->getQuestion() == cqa; + }; + + if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) { + // Does any argument of b use a? + return llvm::any_of(cqb->getArgs(), positionDependsOnA); + } + if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) { + return positionDependsOnA(b->position) || + positionDependsOnA(equalTo->getValue()); + } + return positionDependsOnA(b->position); +} + /// Given a module containing PDL pattern operations, generate a matcher tree /// using the patterns within the given module and return the root matcher node. std::unique_ptr<MatcherNode> @@ -955,6 +1017,10 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, return *lhs < *rhs; }); + // Mostly keep the now established order, but also ensure that + // ConstraintQuestions come after the results they use. + stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn); + // Build the matchers for each of the pattern predicate lists. std::unique_ptr<MatcherNode> root; for (OrderedPredicateList &list : lists) diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp index d5f3467..428b19f 100644 --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) { LogicalResult ApplyNativeConstraintOp::verify() { if (getNumOperands() == 0) return emitOpError("expected at least one argument"); + if (llvm::any_of(getResults(), [](OpResult result) { + return isa<OperationType>(result.getType()); + })) { + return emitOpError( + "returning an operation from a constraint is not supported"); + } return success(); } diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index 6e6992d..559ce7f 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -769,11 +769,25 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) { void Generator::generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer) { - assert(constraintToMemIndex.count(op.getName()) && - "expected index for constraint function"); + // Constraints that should return a value have to be registered as rewrites. + // If a constraint and a rewrite of similar name are registered the + // constraint takes precedence writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]); writer.appendPDLValueList(op.getArgs()); writer.append(ByteCodeField(op.getIsNegated())); + ResultRange results = op.getResults(); + writer.append(ByteCodeField(results.size())); + for (Value result : results) { + // We record the expected kind of the result, so that we can provide extra + // verification of the native rewrite function and handle the failure case + // of constraints accordingly. + writer.appendPDLValueKind(result); + + // Range results also need to append the range storage index. + if (isa<pdl::RangeType>(result.getType())) + writer.append(getRangeStorageIndex(result)); + writer.append(result); + } writer.append(op.getSuccessors()); } void Generator::generate(pdl_interp::ApplyRewriteOp op, @@ -786,11 +800,9 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op, ResultRange results = op.getResults(); writer.append(ByteCodeField(results.size())); for (Value result : results) { - // In debug mode we also record the expected kind of the result, so that we + // We record the expected kind of the result, so that we // can provide extra verification of the native rewrite function. -#ifndef NDEBUG writer.appendPDLValueKind(result); -#endif // Range results also need to append the range storage index. if (isa<pdl::RangeType>(result.getType())) @@ -1076,6 +1088,28 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { // ByteCode Execution namespace { +/// This class is an instantiation of the PDLResultList that provides access to +/// the returned results. This API is not on `PDLResultList` to avoid +/// overexposing access to information specific solely to the ByteCode. +class ByteCodeRewriteResultList : public PDLResultList { +public: + ByteCodeRewriteResultList(unsigned maxNumResults) + : PDLResultList(maxNumResults) {} + + /// Return the list of PDL results. + MutableArrayRef<PDLValue> getResults() { return results; } + + /// Return the type ranges allocated by this list. + MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { + return allocatedTypeRanges; + } + + /// Return the value ranges allocated by this list. + MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { + return allocatedValueRanges; + } +}; + /// This class provides support for executing a bytecode stream. class ByteCodeExecutor { public: @@ -1152,6 +1186,9 @@ private: void executeSwitchResultCount(); void executeSwitchType(); void executeSwitchTypes(); + void processNativeFunResults(ByteCodeRewriteResultList &results, + unsigned numResults, + LogicalResult &rewriteResult); /// Pushes a code iterator to the stack. void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } @@ -1225,6 +1262,8 @@ private: return T::getFromOpaquePointer(pointer); } + void skip(size_t skipN) { curCodeIt += skipN; } + /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } /// Jump to a specific successor based on a destination index. @@ -1381,33 +1420,11 @@ private: ArrayRef<PDLConstraintFunction> constraintFunctions; ArrayRef<PDLRewriteFunction> rewriteFunctions; }; - -/// This class is an instantiation of the PDLResultList that provides access to -/// the returned results. This API is not on `PDLResultList` to avoid -/// overexposing access to information specific solely to the ByteCode. -class ByteCodeRewriteResultList : public PDLResultList { -public: - ByteCodeRewriteResultList(unsigned maxNumResults) - : PDLResultList(maxNumResults) {} - - /// Return the list of PDL results. - MutableArrayRef<PDLValue> getResults() { return results; } - - /// Return the type ranges allocated by this list. - MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() { - return allocatedTypeRanges; - } - - /// Return the value ranges allocated by this list. - MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() { - return allocatedValueRanges; - } -}; } // namespace void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); - const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; + ByteCodeField fun_idx = read(); SmallVector<PDLValue, 16> args; readList<PDLValue>(args); @@ -1422,8 +1439,29 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { llvm::dbgs() << " * isNegated: " << isNegated << "\n"; llvm::interleaveComma(args, llvm::dbgs()); }); - // Invoke the constraint and jump to the proper destination. - selectJump(isNegated != succeeded(constraintFn(rewriter, args))); + + ByteCodeField numResults = read(); + const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx]; + ByteCodeRewriteResultList results(numResults); + LogicalResult rewriteResult = constraintFn(rewriter, results, args); + ArrayRef<PDLValue> constraintResults = results.getResults(); + LLVM_DEBUG({ + if (succeeded(rewriteResult)) { + llvm::dbgs() << " * Constraint succeeded\n"; + llvm::dbgs() << " * Results: "; + llvm::interleaveComma(constraintResults, llvm::dbgs()); + llvm::dbgs() << "\n"; + } else { + llvm::dbgs() << " * Constraint failed\n"; + } + }); + assert((failed(rewriteResult) || constraintResults.size() == numResults) && + "native PDL rewrite function succeeded but returned " + "unexpected number of results"); + processNativeFunResults(results, numResults, rewriteResult); + + // Depending on the constraint jump to the proper destination. + selectJump(isNegated != succeeded(rewriteResult)); } LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { @@ -1445,16 +1483,39 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); - // Store the results in the bytecode memory. - for (PDLValue &result : results.getResults()) { - LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + processNativeFunResults(results, numResults, rewriteResult); -// In debug mode we also verify the expected kind of the result. -#ifndef NDEBUG - assert(result.getKind() == read<PDLValue::Kind>() && - "native PDL rewrite function returned an unexpected type of result"); -#endif + if (failed(rewriteResult)) { + LLVM_DEBUG(llvm::dbgs() << " - Failed"); + return failure(); + } + return success(); +} +void ByteCodeExecutor::processNativeFunResults( + ByteCodeRewriteResultList &results, unsigned numResults, + LogicalResult &rewriteResult) { + // Store the results in the bytecode memory or handle missing results on + // failure. + for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) { + PDLValue::Kind resultKind = read<PDLValue::Kind>(); + + // Skip the according number of values on the buffer on failure and exit + // early as there are no results to process. + if (failed(rewriteResult)) { + if (resultKind == PDLValue::Kind::TypeRange || + resultKind == PDLValue::Kind::ValueRange) { + skip(2); + } else { + skip(1); + } + return; + } + PDLValue result = results.getResults()[resultIdx]; + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + assert(result.getKind() == resultKind && + "native PDL rewrite function returned an unexpected type of " + "result"); // If the result is a range, we need to copy it over to the bytecodes // range memory. if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) { @@ -1476,13 +1537,6 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { allocatedTypeRangeMemory.push_back(std::move(it)); for (auto &it : results.getAllocatedValueRanges()) allocatedValueRangeMemory.push_back(std::move(it)); - - // Process the result of the rewrite. - if (failed(rewriteResult)) { - LLVM_DEBUG(llvm::dbgs() << " - Failed"); - return failure(); - } - return success(); } void ByteCodeExecutor::executeAreEqual() { diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 97ff8bd..206781e 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -1362,12 +1362,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl( if (failed(parseToken(Token::semicolon, "expected `;` after native declaration"))) return failure(); - // TODO: PDL should be able to support constraint results in certain - // situations, we should revise this. - if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) { - return emitError( - "native Constraints currently do not support returning results"); - } return T::createNative(ctx, name, arguments, results, optCodeStr, resultType); } diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir index 02bb831..92afb76 100644 --- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir +++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir @@ -79,6 +79,57 @@ module @constraints { // ----- +// CHECK-LABEL: module @constraint_with_result +module @constraint_with_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } +} + +// ----- + +// CHECK-LABEL: module @constraint_with_unused_result +module @constraint_with_unused_result { + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: module @constraint_with_result_multiple +module @constraint_with_result_multiple { + // check that native constraints work as expected even when multiple identical constraints are fused + + // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) + // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]] + // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr" + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute) + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } + pdl.pattern : benefit(1) { + %root = operation + %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute + rewrite %root with "rewriter"(%attr : !pdl.attribute) + } +} + +// ----- + // CHECK-LABEL: module @negated_constraint module @negated_constraint { // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation) diff --git a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir new file mode 100644 index 0000000..4511baf --- /dev/null +++ b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s + +// Ensuse that the dependency between add & less +// causes them to be in the correct order. +// CHECK-LABEL: matcher +// CHECK: apply_constraint "return_attr_constraint" +// CHECK: apply_constraint "use_attr_constraint" + +module { + pdl.pattern : benefit(1) { + %0 = attribute + %1 = types + %2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>) + %3 = attribute = 0 : i32 + %4 = attribute = 1 : i32 + %5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute + apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute) + rewrite %2 with "rewriter" + } +} + +// ----- + +// CHECK-LABEL: matcher +// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of +// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint" +// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]] + +pdl.pattern : benefit(1) { + %inputOp = operation + %result = result 0 of %inputOp + %attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute + %root = operation(%result : !pdl.value) {"attr" = %attr} + rewrite %root with "rewriter"(%attr : !pdl.attribute) +} + +// ----- + +// CHECK-LABEL: matcher +// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr" +// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0 +// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]] + +pdl.pattern : benefit(1) { + %attr = attribute = 10 + %value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value + %root = operation(%value : !pdl.value) + rewrite %root with "rewriter" +} + +// ----- + +// CHECK-LABEL: matcher +// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr" +// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of +// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]] + +pdl.pattern : benefit(1) { + %attr = attribute = 10 + %type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type + %root = operation -> (%type : !pdl.type) + rewrite %root with "rewriter" +} + +// ----- + +// CHECK-LABEL: matcher +// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr" +// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of +// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]] + +pdl.pattern : benefit(1) { + %attr = attribute = 10 + %types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type> + %root = operation -> (%types : !pdl.range<type>) + rewrite %root with "rewriter" +} diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir index 6e6da5c..20e40de 100644 --- a/mlir/test/Dialect/PDL/ops.mlir +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) { // ----- +pdl.pattern @apply_constraint_with_no_results : benefit(1) { + %root = operation + apply_native_constraint "NativeConstraint"(%root : !pdl.operation) + rewrite %root with "rewriter" +} + +// ----- + +pdl.pattern @apply_constraint_with_results : benefit(1) { + %root = operation + %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute + rewrite %root { + apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute) + } +} + +// ----- + pdl.pattern @attribute_with_dict : benefit(1) { %root = operation rewrite %root { diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index ae61c1a..f8e4f2e 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -109,6 +109,74 @@ module @ir attributes { test.apply_constraint_3 } { // ----- +// Test returning a type from a native constraint. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end + + ^pat: + %new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) { + %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_4 +// CHECK-NOT: "test.replaced_by_pattern" +// CHECK: "test.replaced_by_pattern"() : () -> f32 +module @ir attributes { test.apply_constraint_4 } { + "test.failure_op"() : () -> () + "test.success_op"() : () -> () +} + +// ----- + +// Test success and failure cases of native constraints with pdl.range results. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end + + ^pat: + %num_results = pdl_interp.create_attribute 2 : i32 + %types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) { + %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_5 +// CHECK-NOT: "test.replaced_by_pattern" +// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32) +module @ir attributes { test.apply_constraint_5 } { + "test.failure_op"() : () -> () + "test.success_op"() : () -> () +} + +// ----- //===----------------------------------------------------------------------===// // pdl_interp::ApplyRewriteOp diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 50caf8f9..b9424e0 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -887,7 +887,7 @@ public: #include "TestTransformDialectExtensionTypes.cpp.inc" >(); - auto verboseConstraint = [](PatternRewriter &rewriter, + auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &, ArrayRef<PDLValue> pdlValues) { for (const PDLValue &pdlValue : pdlValues) { if (Operation *op = pdlValue.dyn_cast<Operation *>()) { diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index daa1c37..56af3c1 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -30,6 +30,50 @@ static LogicalResult customMultiEntityVariadicConstraint( return success(); } +// Custom constraint that returns a value if the op is named test.success_op +static LogicalResult customValueResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef<PDLValue> args) { + auto *op = args[0].cast<Operation *>(); + if (op->getName().getStringRef() == "test.success_op") { + StringAttr customAttr = rewriter.getStringAttr("test.success"); + results.push_back(customAttr); + return success(); + } + return failure(); +} + +// Custom constraint that returns a type if the op is named test.success_op +static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef<PDLValue> args) { + auto *op = args[0].cast<Operation *>(); + if (op->getName().getStringRef() == "test.success_op") { + results.push_back(rewriter.getF32Type()); + return success(); + } + return failure(); +} + +// Custom constraint that returns a type range of variable length if the op is +// named test.success_op +static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef<PDLValue> args) { + auto *op = args[0].cast<Operation *>(); + int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt(); + + if (op->getName().getStringRef() == "test.success_op") { + SmallVector<Type> types; + for (int i = 0; i < numTypes; i++) { + types.push_back(rewriter.getF32Type()); + } + results.push_back(TypeRange(types)); + return success(); + } + return failure(); +} + // Custom creator invoked from PDL. static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { return rewriter.create(OperationState(op->getLoc(), "test.success")); @@ -102,6 +146,12 @@ struct TestPDLByteCodePass customMultiEntityConstraint); pdlPattern.registerConstraintFunction("multi_entity_var_constraint", customMultiEntityVariadicConstraint); + pdlPattern.registerConstraintFunction("op_constr_return_attr", + customValueResultConstraint); + pdlPattern.registerConstraintFunction("op_constr_return_type", + customTypeResultConstraint); + pdlPattern.registerConstraintFunction("op_constr_return_type_range", + customTypeRangeResultConstraint); pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate); diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll index 18877b4..48747d3 100644 --- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll @@ -158,8 +158,3 @@ Pattern { // CHECK: expected `;` after native declaration Constraint Foo() [{}] - -// ----- - -// CHECK: native Constraints currently do not support returning results -Constraint Foo() -> Op; diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll index 1c0a015..e2a52ff 100644 --- a/mlir/test/mlir-pdll/Parser/constraint.pdll +++ b/mlir/test/mlir-pdll/Parser/constraint.pdll @@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }]; // ----- +// Test that native constraints support returning results. + +// CHECK: Module +// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr> +Constraint Foo() -> Attr; + +// ----- + // CHECK: Module // CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value> // CHECK: `Inputs` diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py index 0d364f9..95cb25c 100644 --- a/mlir/test/python/dialects/pdl_ops.py +++ b/mlir/test/python/dialects/pdl_ops.py @@ -298,6 +298,6 @@ def test_apply_native_constraint(): pattern = PatternOp(1) with InsertionPoint(pattern.body): resultType = TypeOp() - ApplyNativeConstraintOp("typeConstraint", args=[resultType]) + ApplyNativeConstraintOp([], "typeConstraint", args=[resultType]) root = OperationOp(types=[resultType]) RewriteOp(root, name="rewrite") |