aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib')
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp72
-rw-r--r--mlir/test/lib/Dialect/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/OpenACC/CMakeLists.txt16
-rw-r--r--mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp23
-rw-r--r--mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp305
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td13
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp10
7 files changed, 437 insertions, 3 deletions
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index d57b41c..eb0d980 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//
#include "TestDenseDataFlowAnalysis.h"
-#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"
@@ -23,12 +22,15 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/TypeID.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::dataflow;
using namespace mlir::dataflow::test;
+#define DEBUG_TYPE "test-next-access"
+
namespace {
class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
@@ -72,6 +74,7 @@ public:
// means "we don't know what the next access is" rather than "there is no next
// access". But it's unclear how to differentiate the two cases...
void setToExitState(NextAccess *lattice) override {
+ LDBG() << "setToExitState: setting lattice to unknown state";
propagateIfChanged(lattice, lattice->setKnownToUnknown());
}
@@ -87,16 +90,23 @@ public:
LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
const NextAccess &after,
NextAccess *before) {
+ LDBG() << "visitOperation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ LDBG() << " after state: " << after;
+ LDBG() << " before state: " << *before;
+
auto memory = dyn_cast<MemoryEffectOpInterface>(op);
// If we can't reason about the memory effects, conservatively assume we can't
// say anything about the next access.
if (!memory) {
+ LDBG() << " No memory effect interface, setting to exit state";
setToExitState(before);
return success();
}
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
+ LDBG() << " Found " << effects.size() << " memory effects";
// First, check if all underlying values are already known. Otherwise, avoid
// propagating and stay in the "undefined" state to avoid incorrectly
@@ -110,6 +120,7 @@ LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
// Effects with unspecified value are treated conservatively and we cannot
// assume anything about the next access.
if (!value) {
+ LDBG() << " Effect has unspecified value, setting to exit state";
setToExitState(before);
return success();
}
@@ -124,38 +135,63 @@ LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
});
// If the underlying value is not known yet, don't propagate.
- if (!underlyingValue)
+ if (!underlyingValue) {
+ LDBG() << " Underlying value not known for " << value
+ << ", skipping propagation";
return success();
+ }
+ LDBG() << " Found underlying value " << *underlyingValue << " for "
+ << value;
underlyingValues.push_back(*underlyingValue);
}
// Update the state if all underlying values are known.
+ LDBG() << " All underlying values known, updating state";
ChangeResult result = before->meet(after);
for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) {
// If the underlying value is known to be unknown, set to fixpoint.
if (!value) {
+ LDBG() << " Underlying value is unknown, setting to exit state";
setToExitState(before);
return success();
}
+ LDBG() << " Setting next access for value " << value << " to operation "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
result |= before->set(value, op);
}
+ LDBG() << " Final result: "
+ << (result == ChangeResult::Change ? "changed" : "no change");
propagateIfChanged(before, result);
return success();
}
void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) {
+ LDBG() << "buildOperationEquivalentLatticeAnchor: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
if (isMemoryEffectFree(op)) {
+ LDBG() << " Operation is memory effect free, unioning lattice anchors";
unionLatticeAnchors<NextAccess>(getProgramPointBefore(op),
getProgramPointAfter(op));
+ } else {
+ LDBG() << " Operation has memory effects, not unioning lattice anchors";
}
}
void NextAccessAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
NextAccess *before) {
+ LDBG() << "visitCallControlFlowTransfer: "
+ << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " action: "
+ << (action == CallControlFlowAction::ExternalCallee ? "ExternalCallee"
+ : action == CallControlFlowAction::EnterCallee ? "EnterCallee"
+ : "ExitCallee");
+ LDBG() << " assumeFuncReads: " << assumeFuncReads;
+
if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) {
+ LDBG() << " Handling external callee with assumed function reads";
SmallVector<Value> underlyingValues;
underlyingValues.reserve(call->getNumOperands());
for (Value operand : call.getArgOperands()) {
@@ -165,15 +201,26 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
return getOrCreateFor<UnderlyingValueLattice>(
getProgramPointBefore(call.getOperation()), value);
});
- if (!underlyingValue)
+ if (!underlyingValue) {
+ LDBG() << " Underlying value not known for operand " << operand
+ << ", returning";
return;
+ }
+ LDBG() << " Found underlying value " << *underlyingValue
+ << " for operand " << operand;
underlyingValues.push_back(*underlyingValue);
}
+ LDBG() << " Setting next access for " << underlyingValues.size()
+ << " operands";
ChangeResult result = before->meet(after);
for (Value operand : underlyingValues) {
+ LDBG() << " Setting next access for operand " << operand << " to call "
+ << call;
result |= before->set(operand, call);
}
+ LDBG() << " Call control flow result: "
+ << (result == ChangeResult::Change ? "changed" : "no change");
return propagateIfChanged(before, result);
}
auto testCallAndStore =
@@ -182,8 +229,10 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
testCallAndStore.getStoreBeforeCall()) ||
(action == CallControlFlowAction::ExitCallee &&
!testCallAndStore.getStoreBeforeCall()))) {
+ LDBG() << " Handling TestCallAndStoreOp with special logic";
(void)visitOperation(call, after, before);
} else {
+ LDBG() << " Using default call control flow transfer logic";
AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
call, action, after, before);
}
@@ -192,6 +241,11 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
+ LDBG() << "visitRegionBranchControlFlowTransfer: "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " regionFrom: " << (regionFrom.isParent() ? "parent" : "region");
+ LDBG() << " regionTo: " << (regionTo.isParent() ? "parent" : "region");
+
auto testStoreWithARegion =
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
@@ -199,9 +253,11 @@ void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
(regionFrom.isParent() &&
testStoreWithARegion.getStoreBeforeRegion()))) {
+ LDBG() << " Handling TestStoreWithARegion with special logic";
(void)visitOperation(branch, static_cast<const NextAccess &>(after),
static_cast<NextAccess *>(before));
} else {
+ LDBG() << " Using default region branch control flow transfer logic";
propagateIfChanged(before, before->meet(after));
}
}
@@ -278,6 +334,11 @@ struct TestNextAccessPass
void runOnOperation() override {
Operation *op = getOperation();
+ LDBG() << "runOnOperation: Starting test-next-access pass on "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ LDBG() << " interprocedural: " << interprocedural;
+ LDBG() << " assumeFuncReads: " << assumeFuncReads;
+
SymbolTableCollection symbolTable;
auto config = DataFlowConfig().setInterprocedural(interprocedural);
@@ -285,15 +346,20 @@ struct TestNextAccessPass
loadBaselineAnalyses(solver);
solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads);
solver.load<UnderlyingValueAnalysis>();
+ LDBG() << " Initializing and running dataflow solver";
if (failed(solver.initializeAndRun(op))) {
emitError(op->getLoc(), "dataflow solver failed");
return signalPassFailure();
}
+ LDBG() << " Dataflow solver completed successfully";
+ LDBG() << " Walking operations to set next access attributes";
op->walk([&](Operation *op) {
auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
if (!tag)
return;
+ LDBG() << " Processing tagged operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
const NextAccess *nextAccess =
solver.lookupState<NextAccess>(solver.getProgramPointAfter(op));
op->setAttr(kNextAccessAttrName,
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 3b7bd9b..e31140a 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(Math)
add_subdirectory(MemRef)
add_subdirectory(Shard)
add_subdirectory(NVGPU)
+add_subdirectory(OpenACC)
add_subdirectory(SCF)
add_subdirectory(Shape)
add_subdirectory(SPIRV)
diff --git a/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
new file mode 100644
index 0000000..f84055d
--- /dev/null
+++ b/mlir/test/lib/Dialect/OpenACC/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_library(MLIROpenACCTestPasses
+ TestOpenACC.cpp
+ TestPointerLikeTypeInterface.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+)
+mlir_target_link_libraries(MLIROpenACCTestPasses PUBLIC
+ MLIRIR
+ MLIRArithDialect
+ MLIRFuncDialect
+ MLIRMemRefDialect
+ MLIROpenACCDialect
+ MLIRPass
+ MLIRSupport
+)
+
diff --git a/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
new file mode 100644
index 0000000..9886240
--- /dev/null
+++ b/mlir/test/lib/Dialect/OpenACC/TestOpenACC.cpp
@@ -0,0 +1,23 @@
+//===- TestOpenACC.cpp - OpenACC Test Registration ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains unified registration for all OpenACC test passes.
+//
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace test {
+
+// Forward declarations of individual test pass registration functions
+void registerTestPointerLikeTypeInterfacePass();
+
+// Unified registration function for all OpenACC tests
+void registerTestOpenACC() { registerTestPointerLikeTypeInterfacePass(); }
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
new file mode 100644
index 0000000..85f9283
--- /dev/null
+++ b/mlir/test/lib/Dialect/OpenACC/TestPointerLikeTypeInterface.cpp
@@ -0,0 +1,305 @@
+//===- TestPointerLikeTypeInterface.cpp - Test PointerLikeType interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for testing the OpenACC PointerLikeType
+// interface methods.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+struct OperationTracker : public OpBuilder::Listener {
+ SmallVector<Operation *> insertedOps;
+
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override {
+ insertedOps.push_back(op);
+ }
+};
+
+struct TestPointerLikeTypeInterfacePass
+ : public PassWrapper<TestPointerLikeTypeInterfacePass,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPointerLikeTypeInterfacePass)
+
+ TestPointerLikeTypeInterfacePass() = default;
+ TestPointerLikeTypeInterfacePass(const TestPointerLikeTypeInterfacePass &pass)
+ : PassWrapper(pass) {
+ testMode = pass.testMode;
+ }
+
+ Pass::Option<std::string> testMode{
+ *this, "test-mode",
+ llvm::cl::desc("Test mode: walk, alloc, copy, or free"),
+ llvm::cl::init("walk")};
+
+ StringRef getArgument() const override {
+ return "test-acc-pointer-like-interface";
+ }
+
+ StringRef getDescription() const override {
+ return "Test OpenACC PointerLikeType interface methods on any implementing "
+ "type";
+ }
+
+ void runOnOperation() override;
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<acc::OpenACCDialect>();
+ registry.insert<arith::ArithDialect>();
+ registry.insert<memref::MemRefDialect>();
+ }
+
+private:
+ void walkAndPrint();
+ void testGenAllocate(Operation *op, Value result, PointerLikeType pointerType,
+ OpBuilder &builder);
+ void testGenFree(Operation *op, Value result, PointerLikeType pointerType,
+ OpBuilder &builder);
+ void testGenCopy(Operation *srcOp, Operation *destOp, Value srcResult,
+ Value destResult, PointerLikeType pointerType,
+ OpBuilder &builder);
+
+ struct PointerCandidate {
+ Operation *op;
+ Value result;
+ PointerLikeType pointerType;
+ };
+};
+
+void TestPointerLikeTypeInterfacePass::runOnOperation() {
+ if (testMode == "walk") {
+ walkAndPrint();
+ return;
+ }
+
+ auto func = getOperation();
+ OpBuilder builder(&getContext());
+
+ if (testMode == "alloc" || testMode == "free") {
+ // Collect all candidates first
+ SmallVector<PointerCandidate> candidates;
+ func.walk([&](Operation *op) {
+ if (op->hasAttr("test.ptr")) {
+ for (auto result : op->getResults()) {
+ if (isa<PointerLikeType>(result.getType())) {
+ candidates.push_back(
+ {op, result, cast<PointerLikeType>(result.getType())});
+ break; // Only take the first PointerLikeType result
+ }
+ }
+ }
+ });
+
+ // Now test all candidates
+ for (const auto &candidate : candidates) {
+ if (testMode == "alloc")
+ testGenAllocate(candidate.op, candidate.result, candidate.pointerType,
+ builder);
+ else if (testMode == "free")
+ testGenFree(candidate.op, candidate.result, candidate.pointerType,
+ builder);
+ }
+ } else if (testMode == "copy") {
+ // Collect all source and destination candidates
+ SmallVector<PointerCandidate> sources, destinations;
+
+ func.walk([&](Operation *op) {
+ if (op->hasAttr("test.src_ptr")) {
+ for (auto result : op->getResults()) {
+ if (isa<PointerLikeType>(result.getType())) {
+ sources.push_back(
+ {op, result, cast<PointerLikeType>(result.getType())});
+ break;
+ }
+ }
+ }
+ if (op->hasAttr("test.dest_ptr")) {
+ for (auto result : op->getResults()) {
+ if (isa<PointerLikeType>(result.getType())) {
+ destinations.push_back(
+ {op, result, cast<PointerLikeType>(result.getType())});
+ break;
+ }
+ }
+ }
+ });
+
+ // Try copying from each source to each destination
+ for (const auto &src : sources)
+ for (const auto &dest : destinations)
+ testGenCopy(src.op, dest.op, src.result, dest.result, src.pointerType,
+ builder);
+ }
+}
+
+void TestPointerLikeTypeInterfacePass::walkAndPrint() {
+ auto func = getOperation();
+
+ func.walk([&](Operation *op) {
+ // Look for operations marked with "test.ptr", "test.src_ptr", or
+ // "test.dest_ptr"
+ if (op->hasAttr("test.ptr") || op->hasAttr("test.src_ptr") ||
+ op->hasAttr("test.dest_ptr")) {
+ llvm::errs() << "Operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+
+ // Check each result to see if it's a PointerLikeType
+ for (auto result : op->getResults()) {
+ if (isa<PointerLikeType>(result.getType())) {
+ llvm::errs() << " Result " << result.getResultNumber()
+ << " is PointerLikeType: ";
+ result.getType().print(llvm::errs());
+ llvm::errs() << "\n";
+ } else {
+ llvm::errs() << " Result " << result.getResultNumber()
+ << " is NOT PointerLikeType: ";
+ result.getType().print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+ }
+
+ if (op->getNumResults() == 0)
+ llvm::errs() << " Operation has no results\n";
+
+ llvm::errs() << "\n";
+ }
+ });
+}
+
+void TestPointerLikeTypeInterfacePass::testGenAllocate(
+ Operation *op, Value result, PointerLikeType pointerType,
+ OpBuilder &builder) {
+ Location loc = op->getLoc();
+
+ // Create a new builder with the listener and set insertion point
+ OperationTracker tracker;
+ OpBuilder newBuilder(op->getContext());
+ newBuilder.setListener(&tracker);
+ newBuilder.setInsertionPointAfter(op);
+
+ // Call the genAllocate API
+ Value allocRes = pointerType.genAllocate(newBuilder, loc, "test_alloc",
+ result.getType(), result);
+
+ if (allocRes) {
+ llvm::errs() << "Successfully generated alloc for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+
+ // Print all operations that were inserted
+ for (Operation *insertedOp : tracker.insertedOps) {
+ llvm::errs() << "\tGenerated: ";
+ insertedOp->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+ } else {
+ llvm::errs() << "Failed to generate alloc for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+}
+
+void TestPointerLikeTypeInterfacePass::testGenFree(Operation *op, Value result,
+ PointerLikeType pointerType,
+ OpBuilder &builder) {
+ Location loc = op->getLoc();
+
+ // Create a new builder with the listener and set insertion point
+ OperationTracker tracker;
+ OpBuilder newBuilder(op->getContext());
+ newBuilder.setListener(&tracker);
+ newBuilder.setInsertionPointAfter(op);
+
+ // Call the genFree API
+ auto typedResult = cast<TypedValue<PointerLikeType>>(result);
+ bool success =
+ pointerType.genFree(newBuilder, loc, typedResult, result.getType());
+
+ if (success) {
+ llvm::errs() << "Successfully generated free for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+
+ // Print all operations that were inserted
+ for (Operation *insertedOp : tracker.insertedOps) {
+ llvm::errs() << "\tGenerated: ";
+ insertedOp->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+ } else {
+ llvm::errs() << "Failed to generate free for operation: ";
+ op->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+}
+
+void TestPointerLikeTypeInterfacePass::testGenCopy(
+ Operation *srcOp, Operation *destOp, Value srcResult, Value destResult,
+ PointerLikeType pointerType, OpBuilder &builder) {
+ Location loc = destOp->getLoc();
+
+ // Create a new builder with the listener and set insertion point
+ OperationTracker tracker;
+ OpBuilder newBuilder(destOp->getContext());
+ newBuilder.setListener(&tracker);
+ newBuilder.setInsertionPointAfter(destOp);
+
+ // Call the genCopy API with the provided source and destination
+ auto typedSrc = cast<TypedValue<PointerLikeType>>(srcResult);
+ auto typedDest = cast<TypedValue<PointerLikeType>>(destResult);
+ bool success = pointerType.genCopy(newBuilder, loc, typedDest, typedSrc,
+ srcResult.getType());
+
+ if (success) {
+ llvm::errs() << "Successfully generated copy from source: ";
+ srcOp->print(llvm::errs());
+ llvm::errs() << " to destination: ";
+ destOp->print(llvm::errs());
+ llvm::errs() << "\n";
+
+ // Print all operations that were inserted
+ for (Operation *insertedOp : tracker.insertedOps) {
+ llvm::errs() << "\tGenerated: ";
+ insertedOp->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+ } else {
+ llvm::errs() << "Failed to generate copy from source: ";
+ srcOp->print(llvm::errs());
+ llvm::errs() << " to destination: ";
+ destOp->print(llvm::errs());
+ llvm::errs() << "\n";
+ }
+}
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pass Registration
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace test {
+void registerTestPointerLikeTypeInterfacePass() {
+ PassRegistration<TestPointerLikeTypeInterfacePass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6ea27187..6329d61 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1169,6 +1169,11 @@ def OpP : TEST_Op<"op_p"> {
let results = (outs I32);
}
+def OpQ : TEST_Op<"op_q"> {
+ let arguments = (ins AnyType, AnyType);
+ let results = (outs AnyType);
+}
+
// Test constant-folding a pattern that maps `(F32) -> SI32`.
def SignOp : TEST_Op<"sign", [SameOperandsAndResultShape]> {
let arguments = (ins RankedTensorOf<[F32]>:$operand);
@@ -1207,6 +1212,14 @@ def TestNestedSameOpAndSameArgEqualityPattern :
def TestMultipleEqualArgsPattern :
Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;
+// Test equal arguments checks are applied before user provided constraints.
+def AssertBinOpEqualArgsAndReturnTrue : Constraint<
+ CPred<"assertBinOpEqualArgsAndReturnTrue($0)">>;
+def TestEqualArgsCheckBeforeUserConstraintsPattern :
+ Pat<(OpQ:$op $x, $x),
+ (replaceWithValue $x),
+ [(AssertBinOpEqualArgsAndReturnTrue $op)]>;
+
// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f8b5144..ee4fa39 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -70,6 +70,16 @@ static Attribute opMTest(PatternRewriter &rewriter, Value val) {
return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
}
+static bool assertBinOpEqualArgsAndReturnTrue(Value v) {
+ Operation *operation = v.getDefiningOp();
+ if (operation->getOperand(0) != operation->getOperand(1)) {
+ // Name binding equality check must happen before user-defined constraints,
+ // thus this must not be triggered.
+ llvm::report_fatal_error("Arguments are not equal");
+ }
+ return true;
+}
+
namespace {
#include "TestPatterns.inc"
} // namespace