aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Test')
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttrDefs.td17
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttributes.cpp18
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttributes.h1
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.h1
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.td5
-rw-r--r--mlir/test/lib/Dialect/Test/TestOpDefs.cpp44
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td15
7 files changed, 87 insertions, 14 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 5685004..9e7e4f8 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -22,6 +22,7 @@ include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/TensorEncoding.td"
// All of the attributes will extend this class.
class Test_Attr<string name, list<Trait> traits = []>
@@ -439,4 +440,20 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> {
let hasStorageCustomConstructor = 1;
}
+def TestTensorEncodingAttr : Test_Attr<"TestTensorEncoding",
+ [DeclareAttrInterfaceMethods<VerifiableTensorEncoding>]> {
+ let mnemonic = "tensor_encoding";
+
+ let parameters = (ins "mlir::StringAttr":$dummy);
+ let assemblyFormat = "`<` $dummy `>`";
+}
+
+def TestMemRefLayoutAttr : Test_Attr<"TestMemRefLayout",
+ [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
+ let mnemonic = "memref_layout";
+
+ let parameters = (ins "mlir::StringAttr":$dummy);
+ let assemblyFormat = "`<` $dummy `>`";
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index fe1e916..9db7b01 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -542,6 +542,24 @@ test::detail::TestCustomStorageCtorAttrAttrStorage::construct(
}
//===----------------------------------------------------------------------===//
+// TestTensorEncodingAttr
+//===----------------------------------------------------------------------===//
+
+::llvm::LogicalResult TestTensorEncodingAttr::verifyEncoding(
+ mlir::ArrayRef<int64_t> shape, mlir::Type elementType,
+ llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) const {
+ return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestMemRefLayoutAttr
+//===----------------------------------------------------------------------===//
+
+mlir::AffineMap TestMemRefLayoutAttr::getAffineMap() const {
+ return mlir::AffineMap::getMultiDimIdentityMap(1, getContext());
+}
+
+//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
index 778d84fa..0ad5ab6 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -24,6 +24,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/IR/TensorEncoding.h"
// generated files require above includes to come first
#include "TestAttrInterfaces.h.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
index f2adca6..bcf3b55d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -18,6 +18,7 @@
#include "TestInterfaces.h"
#include "TestTypes.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td
index 2b5491f..37a263f 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.td
+++ b/mlir/test/lib/Dialect/Test/TestDialect.td
@@ -24,7 +24,10 @@ def Test_Dialect : Dialect {
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 1;
let isExtensible = 1;
- let dependentDialects = ["::mlir::DLTIDialect"];
+ let dependentDialects = [
+ "::mlir::DLTIDialect",
+ "::mlir::bufferization::BufferizationDialect"
+ ];
let discardableAttrs = (ins
"mlir::IntegerAttr":$discardable_attr_key,
"SimpleAAttr":$other_discardable_attr_key
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 53055fe..b211e24 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete(
return createNewMultiAllocaWithoutSlot(slot, builder, *this);
}
+namespace {
+/// Returns test dialect's memref layout for test dialect's tensor encoding when
+/// applicable.
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+ if (auto encoding =
+ dyn_cast<test::TestTensorEncodingAttr>(tensorType.getEncoding())) {
+ return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+ tensorType.getContext(), encoding.getDummy()));
+ }
+ return {};
+}
+
+/// Auxiliary bufferization function for test and builtin tensors.
+bufferization::BufferLikeType
+convertTensorToBuffer(mlir::Operation *op,
+ const bufferization::BufferizationOptions &options,
+ bufferization::TensorLikeType tensorLike) {
+ auto buffer =
+ *tensorLike.getBufferType(options, [&]() { return op->emitError(); });
+ if (auto memref = dyn_cast<MemRefType>(buffer)) {
+ // Note: For the sake of testing, we want to ensure that encoding -> layout
+ // bufferization happens. This is currently achieved manually.
+ auto layout =
+ getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
+ return cast<bufferization::BufferLikeType>(
+ MemRefType::get(memref.getShape(), memref.getElementType(), layout,
+ memref.getMemorySpace()));
+ }
+ return buffer;
+}
+} // namespace
+
::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options,
@@ -1435,8 +1468,8 @@ TestMultiSlotAlloca::handleDestructuringComplete(
return failure();
const auto outType = getOutput().getType();
- const auto bufferizedOutType = test::TestMemrefType::get(
- getContext(), outType.getShape(), outType.getElementType(), nullptr);
+ const auto bufferizedOutType =
+ convertTensorToBuffer(getOperation(), options, outType);
// replace op with memref analogy
auto dummyMemrefOp = test::TestDummyMemrefOp::create(
rewriter, getLoc(), bufferizedOutType, *buffer);
@@ -1470,13 +1503,12 @@ TestMultiSlotAlloca::handleDestructuringComplete(
mlir::FailureOr<mlir::bufferization::BufferLikeType>
test::TestCreateTensorOp::getBufferType(
- mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+ mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
const mlir::bufferization::BufferizationState &,
llvm::SmallVector<::mlir::Value> &) {
- const auto type = dyn_cast<test::TestTensorType>(value.getType());
+ const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
if (type == nullptr)
return failure();
- return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
- getContext(), type.getShape(), type.getElementType(), nullptr));
+ return convertTensorToBuffer(getOperation(), options, type);
}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 6329d61..05a33cf 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -32,6 +32,7 @@ include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
// Include the attribute definitions.
include "TestAttrDefs.td"
@@ -2335,7 +2336,7 @@ def SideEffectWithRegionOp : TEST_Op<"side_effect_with_region_op",
}
//===----------------------------------------------------------------------===//
-// Copy Operation Test
+// Copy Operation Test
//===----------------------------------------------------------------------===//
def CopyOp : TEST_Op<"copy", []> {
@@ -3676,10 +3677,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
["bufferize", "bufferizesToMemoryRead",
"bufferizesToMemoryWrite", "getAliasingValues"]>]> {
let arguments = (ins
- Arg<TestTensorType>:$input
+ Arg<Bufferization_TensorLikeTypeInterface>:$input
);
let results = (outs
- Arg<TestTensorType>:$output
+ Arg<Bufferization_TensorLikeTypeInterface>:$output
);
let extraClassDefinition = [{
@@ -3701,10 +3702,10 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
let arguments = (ins
- Arg<TestMemrefType>:$input
+ Arg<Bufferization_BufferLikeTypeInterface>:$input
);
let results = (outs
- Arg<TestMemrefType>:$output
+ Arg<Bufferization_BufferLikeTypeInterface>:$output
);
}
@@ -3714,7 +3715,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
"bufferizesToMemoryWrite", "getAliasingValues",
"bufferizesToAllocation"]>]> {
let arguments = (ins);
- let results = (outs Arg<TestTensorType>:$output);
+ let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output);
let extraClassDefinition = [{
bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
const ::mlir::bufferization::AnalysisState&) {
@@ -3738,7 +3739,7 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
let arguments = (ins);
- let results = (outs Arg<TestMemrefType>:$output);
+ let results = (outs Arg<Bufferization_BufferLikeTypeInterface>:$output);
}
//===----------------------------------------------------------------------===//