aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/PDL/IR/PDLOps.td9
-rw-r--r--mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td8
-rw-r--r--mlir/include/mlir/IR/PDLPatternMatch.h.inc18
-rw-r--r--mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp41
-rw-r--r--mlir/lib/Conversion/PDLToPDLInterp/Predicate.h58
-rw-r--r--mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp82
-rw-r--r--mlir/lib/Dialect/PDL/IR/PDL.cpp6
-rw-r--r--mlir/lib/Rewrite/ByteCode.cpp144
-rw-r--r--mlir/lib/Tools/PDLL/Parser/Parser.cpp6
-rw-r--r--mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir51
-rw-r--r--mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir77
-rw-r--r--mlir/test/Dialect/PDL/ops.mlir18
-rw-r--r--mlir/test/Rewrite/pdl-bytecode.mlir68
-rw-r--r--mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp2
-rw-r--r--mlir/test/lib/Rewrite/TestPDLByteCode.cpp50
-rw-r--r--mlir/test/mlir-pdll/Parser/constraint-failure.pdll5
-rw-r--r--mlir/test/mlir-pdll/Parser/constraint.pdll8
-rw-r--r--mlir/test/python/dialects/pdl_ops.py2
18 files changed, 555 insertions, 98 deletions
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
index 4e9ebcc..1e108c3 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td
@@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
let description = [{
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
has been registered externally with the consumer of PDL, to a given set of
- entities.
+ entities and optionally return a number of values.
Example:
```mlir
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
+ // Apply constraint `with_result` to `root`. This constraint returns an attribute.
+ %attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
```
}];
let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
- let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
+ let results = (outs Variadic<PDL_AnyType>:$results);
+ let assemblyFormat = [{
+ $name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
+ }];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index 48f625b..901acc0 100644
--- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let description = [{
`pdl_interp.apply_constraint` operations apply a generic constraint, that
has been registered with the interpreter, with a given set of positional
- values. On success, this operation branches to the true destination,
+ values.
+ The constraint function may return any number of results.
+ On success, this operation branches to the true destination,
otherwise the false destination is taken. This behavior can be reversed
by setting the attribute `isNegated` to true.
@@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
let arguments = (ins StrAttr:$name,
Variadic<PDL_AnyType>:$args,
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
+ let results = (outs Variadic<PDL_AnyType>:$results);
let assemblyFormat = [{
- $name `(` $args `:` type($args) `)` attr-dict `->` successors
+ $name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
+ `->` successors
}];
}
diff --git a/mlir/include/mlir/IR/PDLPatternMatch.h.inc b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
index a215da8..66286ed 100644
--- a/mlir/include/mlir/IR/PDLPatternMatch.h.inc
+++ b/mlir/include/mlir/IR/PDLPatternMatch.h.inc
@@ -318,8 +318,9 @@ protected:
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise.
-using PDLConstraintFunction =
- std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction = std::function<LogicalResult(
+ PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
+
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
@@ -726,7 +727,7 @@ std::enable_if_t<
PDLConstraintFunction>
buildConstraintFn(ConstraintFnT &&constraintFn) {
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
- PatternRewriter &rewriter,
+ PatternRewriter &rewriter, PDLResultList &,
ArrayRef<PDLValue> values) -> LogicalResult {
auto argIndices = std::make_index_sequence<
llvm::function_traits<ConstraintFnT>::num_args - 1>();
@@ -842,10 +843,13 @@ public:
/// Register a constraint function with PDL. A constraint function may be
/// specified in one of two ways:
///
- /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
+ /// * `LogicalResult (PatternRewriter &,
+ /// PDLResultList &,
+ /// ArrayRef<PDLValue>)`
///
/// In this overload the arguments of the constraint function are passed via
- /// the low-level PDLValue form.
+ /// the low-level PDLValue form, and the results are manually appended to
+ /// the given result list.
///
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
///
@@ -960,8 +964,8 @@ public:
}
};
class PDLResultList {};
-using PDLConstraintFunction =
- std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
+using PDLConstraintFunction = std::function<LogicalResult(
+ PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index e911631..b00cd0d 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -50,7 +50,8 @@ private:
/// Generate interpreter operations for the tree rooted at the given matcher
/// node, in the specified region.
- Block *generateMatcher(MatcherNode &node, Region &region);
+ Block *generateMatcher(MatcherNode &node, Region &region,
+ Block *block = nullptr);
/// Get or create an access to the provided positional value in the current
/// block. This operation may mutate the provided block pointer if nested
@@ -148,6 +149,10 @@ private:
/// A mapping between pattern operations and the corresponding configuration
/// set.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
+
+ /// A mapping from a constraint question to the ApplyConstraintOp
+ /// that implements it.
+ DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
};
} // namespace
@@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
firstMatcherBlock->erase();
}
-Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region) {
+Block *PatternLowering::generateMatcher(MatcherNode &node, Region &region,
+ Block *block) {
// Push a new scope for the values used by this matcher.
- Block *block = &region.emplaceBlock();
+ if (!block)
+ block = &region.emplaceBlock();
ValueMapScope scope(values);
// If this is the return node, simply insert the corresponding interpreter
@@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *&currentBlock, Position *pos) {
loc, cast<ArrayAttr>(rawTypeAttr));
break;
}
+ case Predicates::ConstraintResultPos: {
+ // Due to the order of traversal, the ApplyConstraintOp has already been
+ // created and we can find it in constraintOpMap.
+ auto *constrResPos = cast<ConstraintPosition>(pos);
+ auto i = constraintOpMap.find(constrResPos->getQuestion());
+ assert(i != constraintOpMap.end());
+ value = i->second->getResult(constrResPos->getIndex());
+ break;
+ }
default:
llvm_unreachable("Generating unknown Position getter");
break;
@@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
args.push_back(getValueAt(currentBlock, position));
}
- // Generate the matcher in the current (potentially nested) region
- // and get the failure successor.
- Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
+ // Generate a new block as success successor and get the failure successor.
+ Block *success = &region->emplaceBlock();
Block *failure = failureBlockStack.back();
- // Finally, create the predicate.
+ // Create the predicate.
builder.setInsertionPointToEnd(currentBlock);
Predicates::Kind kind = question->getKind();
switch (kind) {
@@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *&currentBlock,
}
case Predicates::ConstraintQuestion: {
auto *cstQuestion = cast<ConstraintQuestion>(question);
- builder.create<pdl_interp::ApplyConstraintOp>(
- loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
- failure);
+ auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
+ loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
+ cstQuestion->getIsNegated(), success, failure);
+
+ constraintOpMap.insert({cstQuestion, applyConstraintOp});
break;
}
default:
llvm_unreachable("Generating unknown Predicate operation");
}
+
+ // Generate the matcher in the current (potentially nested) region.
+ // This might use the results of the current predicate.
+ generateMatcher(*boolNode->getSuccessNode(), *region, success);
}
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
index 2c9b63f..5ad2c47 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
+++ b/mlir/lib/Conversion/PDLToPDLInterp/Predicate.h
@@ -47,6 +47,7 @@ enum Kind : unsigned {
OperandPos,
OperandGroupPos,
AttributePos,
+ ConstraintResultPos,
ResultPos,
ResultGroupPos,
TypePos,
@@ -280,6 +281,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
};
//===----------------------------------------------------------------------===//
+// ConstraintPosition
+
+struct ConstraintQuestion;
+
+/// A position describing the result of a native constraint. It saves the
+/// corresponding ConstraintQuestion and result index to enable referring
+/// back to them
+struct ConstraintPosition
+ : public PredicateBase<ConstraintPosition, Position,
+ std::pair<ConstraintQuestion *, unsigned>,
+ Predicates::ConstraintResultPos> {
+ using PredicateBase::PredicateBase;
+
+ /// Returns the ConstraintQuestion to enable keeping track of the native
+ /// constraint this position stems from.
+ ConstraintQuestion *getQuestion() const { return key.first; }
+
+ // Returns the result index of this position
+ unsigned getIndex() const { return key.second; }
+};
+
+//===----------------------------------------------------------------------===//
// ResultPosition
/// A position describing a result of an operation.
@@ -447,11 +470,13 @@ struct AttributeQuestion
: public PredicateBase<AttributeQuestion, Qualifier, void,
Predicates::AttributeQuestion> {};
-/// Apply a parameterized constraint to multiple position values.
+/// Apply a parameterized constraint to multiple position values and possibly
+/// produce results.
struct ConstraintQuestion
- : public PredicateBase<ConstraintQuestion, Qualifier,
- std::tuple<StringRef, ArrayRef<Position *>, bool>,
- Predicates::ConstraintQuestion> {
+ : public PredicateBase<
+ ConstraintQuestion, Qualifier,
+ std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
+ Predicates::ConstraintQuestion> {
using Base::Base;
/// Return the name of the constraint.
@@ -460,15 +485,19 @@ struct ConstraintQuestion
/// Return the arguments of the constraint.
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
+ /// Return the result types of the constraint.
+ ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
+
/// Return the negation status of the constraint.
- bool getIsNegated() const { return std::get<2>(key); }
+ bool getIsNegated() const { return std::get<3>(key); }
/// Construct an instance with the given storage allocator.
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
KeyTy key) {
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
alloc.copyInto(std::get<1>(key)),
- std::get<2>(key)});
+ alloc.copyInto(std::get<2>(key)),
+ std::get<3>(key)});
}
/// Returns a hash suitable for the given keytype.
@@ -526,6 +555,7 @@ public:
// Register the types of Positions with the uniquer.
registerParametricStorageType<AttributePosition>();
registerParametricStorageType<AttributeLiteralPosition>();
+ registerParametricStorageType<ConstraintPosition>();
registerParametricStorageType<ForEachPosition>();
registerParametricStorageType<OperandPosition>();
registerParametricStorageType<OperandGroupPosition>();
@@ -588,6 +618,12 @@ public:
return OperationPosition::get(uniquer, p);
}
+ // Returns a position for a new value created by a constraint.
+ ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
+ unsigned index) {
+ return ConstraintPosition::get(uniquer, std::make_pair(q, index));
+ }
+
/// Returns an attribute position for an attribute of the given operation.
Position *getAttribute(OperationPosition *p, StringRef name) {
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
@@ -673,11 +709,11 @@ public:
}
/// Create a predicate that applies a generic constraint.
- Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
- bool isNegated) {
- return {
- ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
- TrueAnswer::get(uniquer)};
+ Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
+ ArrayRef<Type> resultTypes, bool isNegated) {
+ return {ConstraintQuestion::get(
+ uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
+ TrueAnswer::get(uniquer)};
}
/// Create a predicate comparing a value with null.
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
index a9c3b0a..f3d0e08 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <queue>
@@ -49,14 +50,15 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
DenseMap<Value, Position *> &inputs,
AttributePosition *pos) {
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
- pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
predList.emplace_back(pos, builder.getIsNotNull());
- // If the attribute has a type or value, add a constraint.
- if (Value type = attr.getValueType())
- getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
- else if (Attribute value = attr.getValueAttr())
- predList.emplace_back(pos, builder.getAttributeConstraint(value));
+ if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
+ // If the attribute has a type or value, add a constraint.
+ if (Value type = attr.getValueType())
+ getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
+ else if (Attribute value = attr.getValueAttr())
+ predList.emplace_back(pos, builder.getAttributeConstraint(value));
+ }
}
/// Collect all of the predicates for the given operand position.
@@ -272,8 +274,25 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
// Push the constraint to the furthest position.
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
comparePosDepth);
- PredicateBuilder::Predicate pred =
- builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
+ ResultRange results = op.getResults();
+ PredicateBuilder::Predicate pred = builder.getConstraint(
+ op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
+ op.getIsNegated());
+
+ // For each result register a position so it can be used later
+ for (auto [i, result] : llvm::enumerate(results)) {
+ ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
+ ConstraintPosition *pos = builder.getConstraintPosition(q, i);
+ auto [it, inserted] = inputs.insert({result, pos});
+ // If this is an input value that has been visited in the tree, add a
+ // constraint to ensure that both instances refer to the same value.
+ if (!inserted) {
+ auto minMaxPositions =
+ std::minmax<Position *>(pos, it->second, comparePosDepth);
+ predList.emplace_back(minMaxPositions.second,
+ builder.getEqualTo(minMaxPositions.first));
+ }
+ }
predList.emplace_back(pos, pred);
}
@@ -875,6 +894,49 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
*root = std::make_unique<ExitNode>();
}
+/// Sorts the range begin/end with the partial order given by cmp.
+template <typename Iterator, typename Compare>
+static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
+ while (begin != end) {
+ // Cannot compute sortBeforeOthers in the predicate of stable_partition
+ // because stable_partition will not keep the [begin, end) range intact
+ // while it runs.
+ llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
+ for (auto i = begin; i != end; ++i) {
+ if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
+ sortBeforeOthers.insert(*i);
+ }
+
+ auto const next = std::stable_partition(begin, end, [&](auto const &a) {
+ return sortBeforeOthers.contains(a);
+ });
+ assert(next != begin && "not a partial ordering");
+ begin = next;
+ }
+}
+
+/// Returns true if 'b' depends on a result of 'a'.
+static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
+ auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
+ if (!cqa)
+ return false;
+
+ auto positionDependsOnA = [&](Position *p) {
+ auto *cp = dyn_cast<ConstraintPosition>(p);
+ return cp && cp->getQuestion() == cqa;
+ };
+
+ if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
+ // Does any argument of b use a?
+ return llvm::any_of(cqb->getArgs(), positionDependsOnA);
+ }
+ if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
+ return positionDependsOnA(b->position) ||
+ positionDependsOnA(equalTo->getValue());
+ }
+ return positionDependsOnA(b->position);
+}
+
/// Given a module containing PDL pattern operations, generate a matcher tree
/// using the patterns within the given module and return the root matcher node.
std::unique_ptr<MatcherNode>
@@ -955,6 +1017,10 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
return *lhs < *rhs;
});
+ // Mostly keep the now established order, but also ensure that
+ // ConstraintQuestions come after the results they use.
+ stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
+
// Build the matchers for each of the pattern predicate lists.
std::unique_ptr<MatcherNode> root;
for (OrderedPredicateList &list : lists)
diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp
index d5f3467..428b19f 100644
--- a/mlir/lib/Dialect/PDL/IR/PDL.cpp
+++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp
@@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
LogicalResult ApplyNativeConstraintOp::verify() {
if (getNumOperands() == 0)
return emitOpError("expected at least one argument");
+ if (llvm::any_of(getResults(), [](OpResult result) {
+ return isa<OperationType>(result.getType());
+ })) {
+ return emitOpError(
+ "returning an operation from a constraint is not supported");
+ }
return success();
}
diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp
index 6e6992d..559ce7f 100644
--- a/mlir/lib/Rewrite/ByteCode.cpp
+++ b/mlir/lib/Rewrite/ByteCode.cpp
@@ -769,11 +769,25 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::ApplyConstraintOp op,
ByteCodeWriter &writer) {
- assert(constraintToMemIndex.count(op.getName()) &&
- "expected index for constraint function");
+ // Constraints that should return a value have to be registered as rewrites.
+ // If a constraint and a rewrite of similar name are registered the
+ // constraint takes precedence
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
writer.appendPDLValueList(op.getArgs());
writer.append(ByteCodeField(op.getIsNegated()));
+ ResultRange results = op.getResults();
+ writer.append(ByteCodeField(results.size()));
+ for (Value result : results) {
+ // We record the expected kind of the result, so that we can provide extra
+ // verification of the native rewrite function and handle the failure case
+ // of constraints accordingly.
+ writer.appendPDLValueKind(result);
+
+ // Range results also need to append the range storage index.
+ if (isa<pdl::RangeType>(result.getType()))
+ writer.append(getRangeStorageIndex(result));
+ writer.append(result);
+ }
writer.append(op.getSuccessors());
}
void Generator::generate(pdl_interp::ApplyRewriteOp op,
@@ -786,11 +800,9 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
ResultRange results = op.getResults();
writer.append(ByteCodeField(results.size()));
for (Value result : results) {
- // In debug mode we also record the expected kind of the result, so that we
+ // We record the expected kind of the result, so that we
// can provide extra verification of the native rewrite function.
-#ifndef NDEBUG
writer.appendPDLValueKind(result);
-#endif
// Range results also need to append the range storage index.
if (isa<pdl::RangeType>(result.getType()))
@@ -1076,6 +1088,28 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
// ByteCode Execution
namespace {
+/// This class is an instantiation of the PDLResultList that provides access to
+/// the returned results. This API is not on `PDLResultList` to avoid
+/// overexposing access to information specific solely to the ByteCode.
+class ByteCodeRewriteResultList : public PDLResultList {
+public:
+ ByteCodeRewriteResultList(unsigned maxNumResults)
+ : PDLResultList(maxNumResults) {}
+
+ /// Return the list of PDL results.
+ MutableArrayRef<PDLValue> getResults() { return results; }
+
+ /// Return the type ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
+ return allocatedTypeRanges;
+ }
+
+ /// Return the value ranges allocated by this list.
+ MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
+ return allocatedValueRanges;
+ }
+};
+
/// This class provides support for executing a bytecode stream.
class ByteCodeExecutor {
public:
@@ -1152,6 +1186,9 @@ private:
void executeSwitchResultCount();
void executeSwitchType();
void executeSwitchTypes();
+ void processNativeFunResults(ByteCodeRewriteResultList &results,
+ unsigned numResults,
+ LogicalResult &rewriteResult);
/// Pushes a code iterator to the stack.
void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
@@ -1225,6 +1262,8 @@ private:
return T::getFromOpaquePointer(pointer);
}
+ void skip(size_t skipN) { curCodeIt += skipN; }
+
/// Jump to a specific successor based on a predicate value.
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
/// Jump to a specific successor based on a destination index.
@@ -1381,33 +1420,11 @@ private:
ArrayRef<PDLConstraintFunction> constraintFunctions;
ArrayRef<PDLRewriteFunction> rewriteFunctions;
};
-
-/// This class is an instantiation of the PDLResultList that provides access to
-/// the returned results. This API is not on `PDLResultList` to avoid
-/// overexposing access to information specific solely to the ByteCode.
-class ByteCodeRewriteResultList : public PDLResultList {
-public:
- ByteCodeRewriteResultList(unsigned maxNumResults)
- : PDLResultList(maxNumResults) {}
-
- /// Return the list of PDL results.
- MutableArrayRef<PDLValue> getResults() { return results; }
-
- /// Return the type ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
- return allocatedTypeRanges;
- }
-
- /// Return the value ranges allocated by this list.
- MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
- return allocatedValueRanges;
- }
-};
} // namespace
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
- const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
+ ByteCodeField fun_idx = read();
SmallVector<PDLValue, 16> args;
readList<PDLValue>(args);
@@ -1422,8 +1439,29 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
llvm::dbgs() << " * isNegated: " << isNegated << "\n";
llvm::interleaveComma(args, llvm::dbgs());
});
- // Invoke the constraint and jump to the proper destination.
- selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
+
+ ByteCodeField numResults = read();
+ const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
+ ByteCodeRewriteResultList results(numResults);
+ LogicalResult rewriteResult = constraintFn(rewriter, results, args);
+ ArrayRef<PDLValue> constraintResults = results.getResults();
+ LLVM_DEBUG({
+ if (succeeded(rewriteResult)) {
+ llvm::dbgs() << " * Constraint succeeded\n";
+ llvm::dbgs() << " * Results: ";
+ llvm::interleaveComma(constraintResults, llvm::dbgs());
+ llvm::dbgs() << "\n";
+ } else {
+ llvm::dbgs() << " * Constraint failed\n";
+ }
+ });
+ assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
+ "native PDL rewrite function succeeded but returned "
+ "unexpected number of results");
+ processNativeFunResults(results, numResults, rewriteResult);
+
+ // Depending on the constraint jump to the proper destination.
+ selectJump(isNegated != succeeded(rewriteResult));
}
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
@@ -1445,16 +1483,39 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results");
- // Store the results in the bytecode memory.
- for (PDLValue &result : results.getResults()) {
- LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
+ processNativeFunResults(results, numResults, rewriteResult);
-// In debug mode we also verify the expected kind of the result.
-#ifndef NDEBUG
- assert(result.getKind() == read<PDLValue::Kind>() &&
- "native PDL rewrite function returned an unexpected type of result");
-#endif
+ if (failed(rewriteResult)) {
+ LLVM_DEBUG(llvm::dbgs() << " - Failed");
+ return failure();
+ }
+ return success();
+}
+void ByteCodeExecutor::processNativeFunResults(
+ ByteCodeRewriteResultList &results, unsigned numResults,
+ LogicalResult &rewriteResult) {
+ // Store the results in the bytecode memory or handle missing results on
+ // failure.
+ for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
+ PDLValue::Kind resultKind = read<PDLValue::Kind>();
+
+ // Skip the according number of values on the buffer on failure and exit
+ // early as there are no results to process.
+ if (failed(rewriteResult)) {
+ if (resultKind == PDLValue::Kind::TypeRange ||
+ resultKind == PDLValue::Kind::ValueRange) {
+ skip(2);
+ } else {
+ skip(1);
+ }
+ return;
+ }
+ PDLValue result = results.getResults()[resultIdx];
+ LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
+ assert(result.getKind() == resultKind &&
+ "native PDL rewrite function returned an unexpected type of "
+ "result");
// If the result is a range, we need to copy it over to the bytecodes
// range memory.
if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
@@ -1476,13 +1537,6 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it));
-
- // Process the result of the rewrite.
- if (failed(rewriteResult)) {
- LLVM_DEBUG(llvm::dbgs() << " - Failed");
- return failure();
- }
- return success();
}
void ByteCodeExecutor::executeAreEqual() {
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 97ff8bd..206781e 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -1362,12 +1362,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
if (failed(parseToken(Token::semicolon,
"expected `;` after native declaration")))
return failure();
- // TODO: PDL should be able to support constraint results in certain
- // situations, we should revise this.
- if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
- return emitError(
- "native Constraints currently do not support returning results");
- }
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
}
diff --git a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
index 02bb831..92afb76 100644
--- a/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
+++ b/mlir/test/Conversion/PDLToPDLInterp/pdl-to-pdl-interp-matcher.mlir
@@ -79,6 +79,57 @@ module @constraints {
// -----
+// CHECK-LABEL: module @constraint_with_result
+module @constraint_with_result {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+ // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
+ pdl.pattern : benefit(1) {
+ %root = operation
+ %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+ rewrite %root with "rewriter"(%attr : !pdl.attribute)
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @constraint_with_unused_result
+module @constraint_with_unused_result {
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+ // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
+ pdl.pattern : benefit(1) {
+ %root = operation
+ %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+ rewrite %root with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: module @constraint_with_result_multiple
+module @constraint_with_result_multiple {
+ // check that native constraints work as expected even when multiple identical constraints are fused
+
+ // CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
+ // CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
+ // CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
+ // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
+ // CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
+ pdl.pattern : benefit(1) {
+ %root = operation
+ %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+ rewrite %root with "rewriter"(%attr : !pdl.attribute)
+ }
+ pdl.pattern : benefit(1) {
+ %root = operation
+ %attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
+ rewrite %root with "rewriter"(%attr : !pdl.attribute)
+ }
+}
+
+// -----
+
// CHECK-LABEL: module @negated_constraint
module @negated_constraint {
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
diff --git a/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
new file mode 100644
index 0000000..4511baf
--- /dev/null
+++ b/mlir/test/Conversion/PDLToPDLInterp/use-constraint-result.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
+
+// Ensuse that the dependency between add & less
+// causes them to be in the correct order.
+// CHECK-LABEL: matcher
+// CHECK: apply_constraint "return_attr_constraint"
+// CHECK: apply_constraint "use_attr_constraint"
+
+module {
+ pdl.pattern : benefit(1) {
+ %0 = attribute
+ %1 = types
+ %2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
+ %3 = attribute = 0 : i32
+ %4 = attribute = 1 : i32
+ %5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
+ apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute)
+ rewrite %2 with "rewriter"
+ }
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint"
+// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+ %inputOp = operation
+ %result = result 0 of %inputOp
+ %attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute
+ %root = operation(%result : !pdl.value) {"attr" = %attr}
+ rewrite %root with "rewriter"(%attr : !pdl.attribute)
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr"
+// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0
+// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+ %attr = attribute = 10
+ %value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value
+ %root = operation(%value : !pdl.value)
+ rewrite %root with "rewriter"
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr"
+// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
+// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+ %attr = attribute = 10
+ %type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type
+ %root = operation -> (%type : !pdl.type)
+ rewrite %root with "rewriter"
+}
+
+// -----
+
+// CHECK-LABEL: matcher
+// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr"
+// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
+// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
+
+pdl.pattern : benefit(1) {
+ %attr = attribute = 10
+ %types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type>
+ %root = operation -> (%types : !pdl.range<type>)
+ rewrite %root with "rewriter"
+}
diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir
index 6e6da5c..20e40de 100644
--- a/mlir/test/Dialect/PDL/ops.mlir
+++ b/mlir/test/Dialect/PDL/ops.mlir
@@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
// -----
+pdl.pattern @apply_constraint_with_no_results : benefit(1) {
+ %root = operation
+ apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
+ rewrite %root with "rewriter"
+}
+
+// -----
+
+pdl.pattern @apply_constraint_with_results : benefit(1) {
+ %root = operation
+ %attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
+ rewrite %root {
+ apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
+ }
+}
+
+// -----
+
pdl.pattern @attribute_with_dict : benefit(1) {
%root = operation
rewrite %root {
diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir
index ae61c1a..f8e4f2e 100644
--- a/mlir/test/Rewrite/pdl-bytecode.mlir
+++ b/mlir/test/Rewrite/pdl-bytecode.mlir
@@ -109,6 +109,74 @@ module @ir attributes { test.apply_constraint_3 } {
// -----
+// Test returning a type from a native constraint.
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+ ^pat:
+ %new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
+
+ ^pat2:
+ pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
+ %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_constraint_4
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.replaced_by_pattern"() : () -> f32
+module @ir attributes { test.apply_constraint_4 } {
+ "test.failure_op"() : () -> ()
+ "test.success_op"() : () -> ()
+}
+
+// -----
+
+// Test success and failure cases of native constraints with pdl.range results.
+module @patterns {
+ pdl_interp.func @matcher(%root : !pdl.operation) {
+ pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
+
+ ^pat:
+ %num_results = pdl_interp.create_attribute 2 : i32
+ %types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
+
+ ^pat1:
+ pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
+
+ ^end:
+ pdl_interp.finalize
+ }
+
+ module @rewriters {
+ pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
+ %op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
+ pdl_interp.erase %root
+ pdl_interp.finalize
+ }
+ }
+}
+
+// CHECK-LABEL: test.apply_constraint_5
+// CHECK-NOT: "test.replaced_by_pattern"
+// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
+module @ir attributes { test.apply_constraint_5 } {
+ "test.failure_op"() : () -> ()
+ "test.success_op"() : () -> ()
+}
+
+// -----
//===----------------------------------------------------------------------===//
// pdl_interp::ApplyRewriteOp
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index 50caf8f9..b9424e0 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -887,7 +887,7 @@ public:
#include "TestTransformDialectExtensionTypes.cpp.inc"
>();
- auto verboseConstraint = [](PatternRewriter &rewriter,
+ auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
ArrayRef<PDLValue> pdlValues) {
for (const PDLValue &pdlValue : pdlValues) {
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
index daa1c37..56af3c1 100644
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -30,6 +30,50 @@ static LogicalResult customMultiEntityVariadicConstraint(
return success();
}
+// Custom constraint that returns a value if the op is named test.success_op
+static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
+ PDLResultList &results,
+ ArrayRef<PDLValue> args) {
+ auto *op = args[0].cast<Operation *>();
+ if (op->getName().getStringRef() == "test.success_op") {
+ StringAttr customAttr = rewriter.getStringAttr("test.success");
+ results.push_back(customAttr);
+ return success();
+ }
+ return failure();
+}
+
+// Custom constraint that returns a type if the op is named test.success_op
+static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
+ PDLResultList &results,
+ ArrayRef<PDLValue> args) {
+ auto *op = args[0].cast<Operation *>();
+ if (op->getName().getStringRef() == "test.success_op") {
+ results.push_back(rewriter.getF32Type());
+ return success();
+ }
+ return failure();
+}
+
+// Custom constraint that returns a type range of variable length if the op is
+// named test.success_op
+static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
+ PDLResultList &results,
+ ArrayRef<PDLValue> args) {
+ auto *op = args[0].cast<Operation *>();
+ int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
+
+ if (op->getName().getStringRef() == "test.success_op") {
+ SmallVector<Type> types;
+ for (int i = 0; i < numTypes; i++) {
+ types.push_back(rewriter.getF32Type());
+ }
+ results.push_back(TypeRange(types));
+ return success();
+ }
+ return failure();
+}
+
// Custom creator invoked from PDL.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
return rewriter.create(OperationState(op->getLoc(), "test.success"));
@@ -102,6 +146,12 @@ struct TestPDLByteCodePass
customMultiEntityConstraint);
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
customMultiEntityVariadicConstraint);
+ pdlPattern.registerConstraintFunction("op_constr_return_attr",
+ customValueResultConstraint);
+ pdlPattern.registerConstraintFunction("op_constr_return_type",
+ customTypeResultConstraint);
+ pdlPattern.registerConstraintFunction("op_constr_return_type_range",
+ customTypeRangeResultConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);
diff --git a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
index 18877b4..48747d3 100644
--- a/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint-failure.pdll
@@ -158,8 +158,3 @@ Pattern {
// CHECK: expected `;` after native declaration
Constraint Foo() [{}]
-
-// -----
-
-// CHECK: native Constraints currently do not support returning results
-Constraint Foo() -> Op;
diff --git a/mlir/test/mlir-pdll/Parser/constraint.pdll b/mlir/test/mlir-pdll/Parser/constraint.pdll
index 1c0a015..e2a52ff 100644
--- a/mlir/test/mlir-pdll/Parser/constraint.pdll
+++ b/mlir/test/mlir-pdll/Parser/constraint.pdll
@@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }];
// -----
+// Test that native constraints support returning results.
+
+// CHECK: Module
+// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
+Constraint Foo() -> Attr;
+
+// -----
+
// CHECK: Module
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
// CHECK: `Inputs`
diff --git a/mlir/test/python/dialects/pdl_ops.py b/mlir/test/python/dialects/pdl_ops.py
index 0d364f9..95cb25c 100644
--- a/mlir/test/python/dialects/pdl_ops.py
+++ b/mlir/test/python/dialects/pdl_ops.py
@@ -298,6 +298,6 @@ def test_apply_native_constraint():
pattern = PatternOp(1)
with InsertionPoint(pattern.body):
resultType = TypeOp()
- ApplyNativeConstraintOp("typeConstraint", args=[resultType])
+ ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
root = OperationOp(types=[resultType])
RewriteOp(root, name="rewrite")