aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Test')
-rw-r--r--mlir/test/lib/Dialect/Test/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp1
-rw-r--r--mlir/test/lib/Dialect/Test/TestOpDefs.cpp138
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.h1
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td136
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp43
-rw-r--r--mlir/test/lib/Dialect/Test/TestTypeDefs.td7
-rw-r--r--mlir/test/lib/Dialect/Test/TestTypes.cpp18
8 files changed, 327 insertions, 18 deletions
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01..9354a85 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
)
mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRControlFlowInterfaces
+ MLIRControlFlowTransforms
MLIRDataLayoutInterfaces
MLIRDerivedAttributeOpInterface
MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 21d75f5..43392d7 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -37,7 +37,6 @@
#include "llvm/Support/Base64.h"
#include "llvm/Support/Casting.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Interfaces/FoldInterfaces.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 4d4ec02..8689265 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -320,10 +320,10 @@ LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
}
//===----------------------------------------------------------------------===//
-// OpWithResultShapePerDimInterfaceOp
+// ReifyShapedTypeUsingReifyResultShapesOp
//===----------------------------------------------------------------------===//
-LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
+LogicalResult ReifyShapedTypeUsingReifyResultShapesOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
Location loc = getLoc();
shapes.reserve(getNumOperands());
@@ -345,6 +345,103 @@ LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
}
//===----------------------------------------------------------------------===//
+// ReifyShapedTypeUsingReifyShapeOfResultOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReifyShapedTypeUsingReifyShapeOfResultOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+ReifyShapedTypeUsingReifyShapeOfResultOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ Location loc = getLoc();
+ Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
+ SmallVector<OpFoldResult> shape =
+ tensor::getMixedSizes(builder, loc, sourceOperand);
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// ReifyShapedTypeUsingReifyDimOfResultOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ReifyShapedTypeUsingReifyDimOfResultOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+ReifyShapedTypeUsingReifyDimOfResultOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ return failure();
+}
+
+FailureOr<OpFoldResult>
+ReifyShapedTypeUsingReifyDimOfResultOp::reifyDimOfResult(OpBuilder &builder,
+ int resultIndex,
+ int dim) {
+ Location loc = getLoc();
+ Value sourceOperand = getOperand(getNumOperands() - 1 - resultIndex);
+ OpFoldResult shape = tensor::getMixedSize(builder, loc, sourceOperand, dim);
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifableResultShapesOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableResultShapesOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ Location loc = getLoc();
+ shapes.resize(1);
+ shapes[0] = {tensor::getMixedSize(builder, loc, getOperand(), 0),
+ OpFoldResult()};
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifableResultShapeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableResultShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+UnreifiableResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ SmallVector<OpFoldResult> shape = {
+ tensor::getMixedSize(builder, getLoc(), getOperand(), 0), OpFoldResult()};
+ return shape;
+}
+
+//===----------------------------------------------------------------------===//
+// UnreifableResultShapeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreifiableDimOfResultShapeOp::reifyResultShapes(
+ OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
+ return failure();
+}
+
+FailureOr<SmallVector<OpFoldResult>>
+UnreifiableDimOfResultShapeOp::reifyShapeOfResult(OpBuilder &builder,
+ int resultIndex) {
+ return failure();
+}
+
+FailureOr<OpFoldResult>
+UnreifiableDimOfResultShapeOp::reifyDimOfResult(OpBuilder &builder,
+ int resultIndex, int dim) {
+ if (dim == 0)
+ return tensor::getMixedSize(builder, getLoc(), getOperand(), 0);
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
// SideEffectOp
//===----------------------------------------------------------------------===//
@@ -1052,6 +1149,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
}
//===----------------------------------------------------------------------===//
+// TilingNoDpsOp
+//===----------------------------------------------------------------------===//
+
+SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
+ return {};
+}
+
+SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
+ return {};
+}
+
+FailureOr<TilingResult>
+TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes) {
+ return failure();
+}
+
+LogicalResult TilingNoDpsOp::getResultTilePosition(
+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+ SmallVector<OpFoldResult> &resultSizes) {
+ return failure();
+}
+
+//===----------------------------------------------------------------------===//
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
//===----------------------------------------------------------------------===//
@@ -1514,3 +1637,14 @@ test::TestCreateTensorOp::getBufferType(
return convertTensorToBuffer(getOperation(), options, type);
}
+
+// Define a custom builder for ManyRegionsOp declared in TestOps.td.
+// OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$firstRegion,
+// "::std::unique_ptr<::mlir::Region>":$secondRegion)>
+void test::ManyRegionsOp::build(
+ mlir::OpBuilder &builder, mlir::OperationState &state,
+ llvm::SmallVectorImpl<std::unique_ptr<mlir::Region>> &&regions) {
+ for (auto &&regionPtr : std::move(regions))
+ state.addRegion(std::move(regionPtr));
+ ManyRegionsOp::build(builder, state, {}, regions.size());
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h
index 4201ade..6792743 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.h
+++ b/mlir/test/lib/Dialect/Test/TestOps.h
@@ -42,6 +42,7 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
namespace test {
class TestDialect;
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a3430ba..5417ae9 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -119,6 +120,13 @@ def SymbolOp : TEST_Op<"symbol", [NoMemoryEffect, Symbol]> {
OptionalAttr<StrAttr>:$sym_visibility);
}
+def SymbolWithResultOp : TEST_Op<"symbol_with_result", [Symbol]> {
+ let summary = "invalid symbol operation that produces an SSA result";
+ let arguments = (ins StrAttr:$sym_name,
+ OptionalAttr<StrAttr>:$sym_visibility);
+ let results = (outs AnyType:$result);
+}
+
def OverriddenSymbolVisibilityOp : TEST_Op<"overridden_symbol_visibility", [
DeclareOpInterfaceMethods<Symbol, ["getVisibility", "setVisibility"]>,
]> {
@@ -914,13 +922,97 @@ def OpWithResultShapeInterfaceOp : TEST_Op<"op_with_result_shape_interface",
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
-def OpWithResultShapePerDimInterfaceOp :
- TEST_Op<"op_with_result_shape_per_dim_interface",
- [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+def ReifyShapedTypeUsingReifyResultShapesOp :
+ TEST_Op<"reify_shaped_type_using_reify_result_shapes",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
+ let description = [{
+ Test that when resolving a single dimension of a result for an operation
+ that doesnt implement `reifyShapeOfResult` nor implements `reifyDimOfResult`
+ calls into the implementation of `reifyResultShapes` to get the required value.
+ The op semantics is that the first result has the same shape as the second operand
+ and the second result has the same shape as the first operand.
+ }];
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def ReifyShapedTypeUsingReifyShapeOfResultOp :
+ TEST_Op<"reify_shaped_type_using_reify_shape_of_result",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult"]>]> {
+ let description = [{
+ Test that when resolving a single dimension of a result for an operation
+ that doesnt implement `reifyDimOfResult` but implements `reifyShapeOfResult`, which
+ is used to get the required value. `reifyResultShapes` is implemented as a failure
+ (which is also the default implementation) to ensure it is not called.
+ The op semantics is that the first result has the same shape as the second operand
+ and the second result has the same shape as the first operand.
+ }];
+ let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
+ let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
+}
+
+def ReifyShapedTypeUsingReifyDimOfResultOp :
+ TEST_Op<"reify_shaped_type_using_reify_dim_of_result",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> {
+ let description = [{
+ Test that when resolving a single dimension of a result for an operation
+ that implements `reifyDimOfResult`, which is used to get the required value.
+ `reifyResultShapes` and `reifyShapeOfResult` are implemented as failures
+ to ensure they are not called. The op semantics is that the first result has
+ the same shape as the second operand and the second result has the same shape
+ as the first operand.
+ }];
let arguments = (ins AnyRankedTensor:$operand1, AnyRankedTensor:$operand2);
let results = (outs AnyRankedTensor:$result1, AnyRankedTensor:$result2);
}
+def UnreifiableResultShapesOp : TEST_Op<"unreifiable_result_shapes",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes"]>]> {
+ let description = [{
+ Test handling of case where some dimension of the result cannot be
+ reified. This tests the path when `reifyResultShapes` is implemented.
+
+ Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but
+ dim 1 of `result` is not reifiable.
+ }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
+def UnreifiableResultShapeOp : TEST_Op<"unreifiable_result_shape",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult"]>]> {
+ let description = [{
+ Test handling of case where some dimension of the result cannot be
+ reified. This tests the path when `reifyShapeOfResult` is implemented,
+ but not `reifyDimOfResult` with `reifyResultShapes` implemented as a failure.
+
+ Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but
+ dim 1 of `result` is not reifiable.
+ }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
+def UnreifiableDimOfResultShapeOp : TEST_Op<"unreifiable_dim_of_result_shape",
+ [DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface,
+ ["reifyResultShapes", "reifyShapeOfResult", "reifyDimOfResult"]>]> {
+ let description = [{
+ Test handling of case where some dimension of the result cannot be
+ reified. This tests the path when `reifyDimOfResult` is implemented,
+ and `reifyDimOfResult` with `reifyResultShapes` are implemented as a failure.
+
+ Expected that dim 0 of `result` is reifable as dim 0 of `operand`, but
+ dim 1 of `result` is not reifiable.
+ }];
+ let arguments = (ins 2DTensorOf<[AnyType]>:$operand);
+ let results = (outs 2DTensorOf<[AnyType]>:$result);
+}
+
def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
@@ -1107,6 +1199,12 @@ def TestLocationDstNoResOp : TEST_Op<"loc_dst_no_res"> {
let results = (outs);
}
+def TestLocationAttrOp : TEST_Op<"op_with_loc_attr"> {
+ let arguments = (ins LocationAttr:$loc_attr);
+ let results = (outs );
+ let assemblyFormat = "$loc_attr attr-dict";
+}
+
//===----------------------------------------------------------------------===//
// Test Patterns
//===----------------------------------------------------------------------===//
@@ -2254,6 +2352,24 @@ def IsolatedGraphRegionOp : TEST_Op<"isolated_graph_region", [
let assemblyFormat = "attr-dict-with-keyword $region";
}
+def ManyRegionsOp : TEST_Op<"many_regions", []> {
+ let summary = "operation created with move-only objects";
+ let description = [{
+ Test op with multiple regions with a `create` function that
+ takes parameters containing move-only objects.
+ }];
+
+ let regions = (region VariadicRegion<AnyRegion>:$regions);
+ let builders =
+ [OpBuilder<(ins "::std::unique_ptr<::mlir::Region>":$singleRegion), [{
+ $_state.addRegion(std::move(singleRegion));
+ build($_builder, $_state, {}, /*regionsCount=*/1);
+ }]>,
+ // Define in TestOps.cpp.
+ OpBuilder<(ins "::llvm::SmallVectorImpl<::std::unique_ptr<::mlir::"
+ "Region>>&&":$regions)>];
+}
+
def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
let summary = "affine scope operation";
let description = [{
@@ -2888,6 +3004,20 @@ def TestLinalgFillOp :
}
//===----------------------------------------------------------------------===//
+// Test TilingInterface.
+//===----------------------------------------------------------------------===//
+
+def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
+ [Pure, DeclareOpInterfaceMethods<TilingInterface,
+ ["getIterationDomain",
+ "getLoopIteratorTypes",
+ "getResultTilePosition",
+ "getTiledImplementation"]>]> {
+ let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
+ let results = (outs AnyRankedTensor:$result);
+}
+
+//===----------------------------------------------------------------------===//
// Test NVVM RequiresSM trait.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb..7eabaae 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -976,7 +977,13 @@ struct TestValueReplace : public ConversionPattern {
// Replace the first operand with 2x the second operand.
Value from = op->getOperand(0);
Value repl = op->getOperand(1);
- rewriter.replaceAllUsesWith(from, {repl, repl});
+ if (op->hasAttr("conditional")) {
+ rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) {
+ return use.getOwner()->hasAttr("replace_uses");
+ });
+ } else {
+ rewriter.replaceAllUsesWith(from, {repl, repl});
+ }
rewriter.modifyOpInPlace(op, [&] {
// If the "trigger_rollback" attribute is set, keep the op illegal, so
// that a rollback is triggered.
@@ -1417,6 +1424,22 @@ public:
}
};
+class TestPostOrderLegalization : public ConversionPattern {
+public:
+ TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ for (Region &r : op->getRegions())
+ if (failed(rewriter.legalize(&r)))
+ return failure();
+ rewriter.modifyOpInPlace(
+ op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@@ -1531,7 +1554,8 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestValueReplace, TestReplaceWithValidConsumer,
- TestTypeConsumerOpPattern>(&getContext(), converter);
+ TestTypeConsumerOpPattern, TestPostOrderLegalization>(
+ &getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1552,14 +1576,16 @@ struct TestLegalizePatternDriver
[](Type type) { return type.isF32(); });
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
target.addDynamicallyLegalOp(
OperationName("test.value_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
+ target.addDynamicallyLegalOp(
+ OperationName("test.post_order_legalization", &getContext()),
+ [](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test
@@ -2042,6 +2068,10 @@ struct TestTypeConversionDriver
});
converter.addConversion([](IndexType type) { return type; });
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+ if (type.isInteger(1)) {
+ // i1 is legal.
+ types.push_back(type);
+ }
if (type.isInteger(38)) {
// i38 is legal.
types.push_back(type);
@@ -2151,8 +2181,7 @@ struct TestTypeConversionDriver
recursiveType.getName() == "outer_converted_type");
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
- return converter.isSignatureLegal(op.getFunctionType()) &&
- converter.isLegal(&op.getBody());
+ return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
// Allow casts from F64 to F32.
@@ -2175,6 +2204,8 @@ struct TestTypeConversionDriver
converter);
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);
+ mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+ patterns, target);
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index ea20597..9859bd0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -470,4 +470,11 @@ def TestMemrefType : Test_Type<"TestMemref",
}];
}
+// Test implementation of an interface with methods specifying a
+// method body
+def TestBaseBody : Test_Type<"TestBaseBody",
+ [DeclareTypeInterfaceMethods<TestBaseTypeInterfacePrintTypeA>]> {
+ let mnemonic = "test_base_body";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 614121f..9cf64a8 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -569,11 +569,17 @@ TestTensorType::getBufferType(
::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType(
::mlir::bufferization::BufferLikeType bufferType,
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {
- auto testMemref = dyn_cast<TestMemrefType>(bufferType);
- if (!testMemref)
- return emitError() << "expected TestMemrefType";
+ if (auto testMemref = dyn_cast<TestMemrefType>(bufferType)) {
+ const bool valid = getShape() == testMemref.getShape() &&
+ getElementType() == testMemref.getElementType();
+ return mlir::success(valid);
+ }
+
+ if (auto builtinMemref = dyn_cast<MemRefType>(bufferType)) {
+ const bool valid = getShape() == builtinMemref.getShape() &&
+ getElementType() == builtinMemref.getElementType();
+ return mlir::success(valid);
+ }
- const bool valid = getShape() == testMemref.getShape() &&
- getElementType() == testMemref.getElementType();
- return mlir::success(valid);
+ return emitError() << "expected MemRefType or TestMemrefType";
}