Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestOpDefs.cpp')
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 44
1 file changed, 38 insertions, 6 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 53055fe..b211e24 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1425,6 +1425,39 @@ TestMultiSlotAlloca::handleDestructuringComplete(
   return createNewMultiAllocaWithoutSlot(slot, builder, *this);
 }
 
+namespace {
+/// Returns test dialect's memref layout for test dialect's tensor encoding when
+/// applicable.
+MemRefLayoutAttrInterface
+getMemRefLayoutForTensorEncoding(RankedTensorType tensorType) {
+  if (auto encoding =
+          dyn_cast<test::TestTensorEncodingAttr>(tensorType.getEncoding())) {
+    return cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
+        tensorType.getContext(), encoding.getDummy()));
+  }
+  return {};
+}
+
+/// Auxiliary bufferization function for test and builtin tensors.
+bufferization::BufferLikeType
+convertTensorToBuffer(mlir::Operation *op,
+                      const bufferization::BufferizationOptions &options,
+                      bufferization::TensorLikeType tensorLike) {
+  auto buffer =
+      *tensorLike.getBufferType(options, [&]() { return op->emitError(); });
+  if (auto memref = dyn_cast<MemRefType>(buffer)) {
+    // Note: For the sake of testing, we want to ensure that encoding -> layout
+    // bufferization happens. This is currently achieved manually.
+    auto layout =
+        getMemRefLayoutForTensorEncoding(cast<RankedTensorType>(tensorLike));
+    return cast<bufferization::BufferLikeType>(
+        MemRefType::get(memref.getShape(), memref.getElementType(), layout,
+                        memref.getMemorySpace()));
+  }
+  return buffer;
+}
+} // namespace
+
 ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
     ::mlir::RewriterBase &rewriter,
     const ::mlir::bufferization::BufferizationOptions &options,
@@ -1435,8 +1468,8 @@ TestMultiSlotAlloca::handleDestructuringComplete(
     return failure();
 
   const auto outType = getOutput().getType();
-  const auto bufferizedOutType = test::TestMemrefType::get(
-      getContext(), outType.getShape(), outType.getElementType(), nullptr);
+  const auto bufferizedOutType =
+      convertTensorToBuffer(getOperation(), options, outType);
   // replace op with memref analogy
   auto dummyMemrefOp = test::TestDummyMemrefOp::create(
       rewriter, getLoc(), bufferizedOutType, *buffer);
@@ -1470,13 +1503,12 @@ TestMultiSlotAlloca::handleDestructuringComplete(
 
 mlir::FailureOr<mlir::bufferization::BufferLikeType>
 test::TestCreateTensorOp::getBufferType(
-    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &options,
     const mlir::bufferization::BufferizationState &,
     llvm::SmallVector<::mlir::Value> &) {
-  const auto type = dyn_cast<test::TestTensorType>(value.getType());
+  const auto type = dyn_cast<bufferization::TensorLikeType>(value.getType());
   if (type == nullptr)
     return failure();
-  return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
-      getContext(), type.getShape(), type.getElementType(), nullptr));
+  return convertTensorToBuffer(getOperation(), options, type);
 }
 
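For orientation, here is a minimal sketch (not part of the patch) of the tensor-to-buffer mapping the new helpers establish: a builtin ranked tensor carrying test::TestTensorEncodingAttr is expected to bufferize to a builtin memref whose layout is the corresponding test::TestMemRefLayoutAttr, while other tensor-like types keep whatever TensorLikeType::getBufferType returns. The function name expectedBufferType is hypothetical; the sketch assumes it sits in TestOpDefs.cpp, where the used names are already in scope, and only mirrors the logic visible in the diff.

// Hypothetical sketch, not part of the patch: mirrors the encoding -> layout
// mapping performed by getMemRefLayoutForTensorEncoding/convertTensorToBuffer
// for builtin ranked tensors.
static MemRefType expectedBufferType(RankedTensorType tensorType) {
  MemRefLayoutAttrInterface layout; // null layout -> identity layout
  if (auto encoding = llvm::dyn_cast_if_present<test::TestTensorEncodingAttr>(
          tensorType.getEncoding()))
    layout = cast<MemRefLayoutAttrInterface>(test::TestMemRefLayoutAttr::get(
        tensorType.getContext(), encoding.getDummy()));
  return MemRefType::get(tensorType.getShape(), tensorType.getElementType(),
                         layout);
}

The sketch uses dyn_cast_if_present rather than a plain dyn_cast so that builtin tensors without an encoding fall back to the identity layout instead of tripping on a null attribute.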