aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp')
-rw-r--r--mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp26
1 files changed, 26 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
index 1e2d4a7..4069a74 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotModuleBufferize.cpp
@@ -11,11 +11,25 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
+#include "TestAttributes.h" // TestTensorEncodingAttr, TestMemRefLayoutAttr
+#include "TestDialect.h"
+
using namespace mlir;
namespace {
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+ if (auto encoding = dyn_cast_if_present<test::TestTensorEncodingAttr>(
+ tensorType.getEncoding())) {
+ return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+ tensorType.getContext(), encoding.getDummy()));
+ }
+ return {};
+}
+
struct TestOneShotModuleBufferizePass
: public PassWrapper<TestOneShotModuleBufferizePass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotModuleBufferizePass)
@@ -25,6 +39,7 @@ struct TestOneShotModuleBufferizePass
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<test::TestDialect>();
registry.insert<bufferization::BufferizationDialect>();
}
StringRef getArgument() const final {
@@ -41,6 +56,17 @@ struct TestOneShotModuleBufferizePass
bufferization::OneShotBufferizationOptions opt;
opt.bufferizeFunctionBoundaries = true;
+ opt.functionArgTypeConverterFn =
+ [&](bufferization::TensorLikeType tensor, Attribute memSpace,
+ func::FuncOp, const bufferization::BufferizationOptions &) {
+ assert(isa<RankedTensorType>(tensor) && "tests only builtin tensors");
+ auto tensorType = cast<RankedTensorType>(tensor);
+ auto layout = getMemRefLayoutForTensorEncoding(tensorType);
+ return cast<bufferization::BufferLikeType>(
+ MemRefType::get(tensorType.getShape(),
+ tensorType.getElementType(), layout, memSpace));
+ };
+
bufferization::BufferizationState bufferizationState;
if (failed(bufferization::runOneShotModuleBufferize(getOperation(), opt,