diff options
Diffstat (limited to 'mlir/test/lib')
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 28 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestTypeDefs.td | 7 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 3 |
3 files changed, 32 insertions, 6 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index fd2b943..9b64bc6 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1418,6 +1418,22 @@ public: } }; +class TestPostOrderLegalization : public ConversionPattern { +public: + TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + for (Region &r : op->getRegions()) + if (failed(rewriter.legalize(&r))) + return failure(); + rewriter.modifyOpInPlace( + op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); }); + return success(); + } +}; + /// Test unambiguous overload resolution of replaceOpWithMultiple. This /// function is just to trigger compiler errors. It is never executed. [[maybe_unused]] void testReplaceOpWithMultipleOverloads( @@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, TestPassthroughInvalidOp, TestMultiple1ToNReplacement, TestValueReplace, TestReplaceWithValidConsumer, - TestTypeConsumerOpPattern>(&getContext(), converter); + TestTypeConsumerOpPattern, TestPostOrderLegalization>( + &getContext(), converter); patterns.add<TestConvertBlockArgs>(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1553,14 +1570,16 @@ struct TestLegalizePatternDriver [](Type type) { return type.isF32(); }); }); target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { - return converter.isSignatureLegal(op.getFunctionType()) && - converter.isLegal(&op.getBody()); + return converter.isSignatureLegal(op.getFunctionType()); }); target.addDynamicallyLegalOp<func::CallOp>( [&](func::CallOp op) { return converter.isLegal(op); }); target.addDynamicallyLegalOp( OperationName("test.value_replace", &getContext()), [](Operation *op) { return op->hasAttr("is_legal"); }); + target.addDynamicallyLegalOp( + OperationName("test.post_order_legalization", &getContext()), + [](Operation *op) { return op->hasAttr("is_legal"); }); // TestCreateUnregisteredOp creates `arith.constant` operation, // which was not added to target intentionally to test @@ -2156,8 +2175,7 @@ struct TestTypeConversionDriver recursiveType.getName() == "outer_converted_type"); }); target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { - return converter.isSignatureLegal(op.getFunctionType()) && - converter.isLegal(&op.getBody()); + return converter.isSignatureLegal(op.getFunctionType()); }); target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { // Allow casts from F64 to F32. diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index ea20597..9859bd0 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -470,4 +470,11 @@ def TestMemrefType : Test_Type<"TestMemref", }]; } +// Test implementation of an interface with methods specifying a +// method body +def TestBaseBody : Test_Type<"TestBaseBody", + [DeclareTypeInterfaceMethods<TestBaseTypeInterfacePrintTypeA>]> { + let mnemonic = "test_base_body"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp index 76d4611..93d5144 100644 --- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> { Value sgId = gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape); + auto maybeOffsets = + sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape); if (failed(maybeOffsets)) return failure(); |
