aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib')
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp28
-rw-r--r--mlir/test/lib/Dialect/Test/TestTypeDefs.td7
-rw-r--r--mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp3
3 files changed, 32 insertions, 6 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index fd2b943..9b64bc6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1418,6 +1418,22 @@ public:
}
};
+class TestPostOrderLegalization : public ConversionPattern {
+public:
+ TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ for (Region &r : op->getRegions())
+ if (failed(rewriter.legalize(&r)))
+ return failure();
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestValueReplace, TestReplaceWithValidConsumer,
- TestTypeConsumerOpPattern>(&getContext(), converter);
+ TestTypeConsumerOpPattern, TestPostOrderLegalization>(
+ &getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1553,14 +1570,16 @@ struct TestLegalizePatternDriver
[](Type type) { return type.isF32(); });
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
target.addDynamicallyLegalOp(
OperationName("test.value_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.post_order_legalization", &getContext()),
+ [](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
@@ -2156,8 +2175,7 @@ struct TestTypeConversionDriver
recursiveType.getName() == "outer_converted_type");
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
// Allow casts from F64 to F32.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index ea20597..9859bd0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -470,4 +470,11 @@ def TestMemrefType : Test_Type<"TestMemref",
}];
}
+// Test implementation of an interface with methods specifying a
+// method body
+def TestBaseBody : Test_Type<"TestBaseBody",
+ [DeclareTypeInterfaceMethods<TestBaseTypeInterfacePrintTypeA>]> {
+ let mnemonic = "test_base_body";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 76d4611..93d5144 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeOffsets =
+ sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(maybeOffsets))
return failure();