aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2022-04-08 18:03:18 +0900
committerMatthias Springer <springerm@google.com>2022-04-08 18:11:10 +0900
commit8b09141909329d93b0de987ee18ee9cfaa7223ba (patch)
treea9d8be6e5b445cde5e7286ca2e36b9c0853d1552
parent5626bd428930aa2cb2b5fdd69e93620e0d2b0532 (diff)
downloadllvm-8b09141909329d93b0de987ee18ee9cfaa7223ba.zip
llvm-8b09141909329d93b0de987ee18ee9cfaa7223ba.tar.gz
llvm-8b09141909329d93b0de987ee18ee9cfaa7223ba.tar.bz2
[mlir][arith][bufferize] Fix tensors with different layouts after bufferization
Insert a cast if the two tensors with identical layout (that are passed to `arith.select`) have different layout maps after bufferization. Differential Revision: https://reviews.llvm.org/D123321
-rw-r--r--mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp21
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir14
2 files changed, 35 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 12726a1..4f1add5 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -129,6 +129,7 @@ struct SelectOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto selectOp = cast<arith::SelectOp>(op);
+ Location loc = selectOp.getLoc();
// `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
// TODO: It would be more efficient to copy the result of the `select` op
@@ -139,6 +140,26 @@ struct SelectOpInterface
*state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
Value falseBuffer =
*state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
+
+ // The "true" and the "false" operands must have the same type. If the
+ // buffers have different types, they differ only in their layout map. Cast
+ // both of them to the most dynamic MemRef type.
+ if (trueBuffer.getType() != falseBuffer.getType()) {
+ auto trueType = trueBuffer.getType().cast<MemRefType>();
+ auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>();
+ int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
+ SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+ ShapedType::kDynamicStrideOrOffset);
+ AffineMap stridedLayout = makeStridedLinearLayoutMap(
+ dynamicStrides, dynamicOffset, op->getContext());
+ BaseMemRefType castedType = bufferization::getMemRefType(
+ tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout),
+ trueType.getMemorySpace());
+ trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
+ falseBuffer =
+ rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
+ }
+
replaceOpWithNewBufferizedOp<arith::SelectOp>(
rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
return success();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index e711392..ac2249d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -105,4 +105,18 @@ func @copy_deallocated() -> tensor<10xf32> {
return %0 : tensor<10xf32>
}
+// -----
+// CHECK-LABEL: func @select_different_tensors(
+// CHECK-SAME: %[[t:.*]]: tensor<?xf32>
+func @select_different_tensors(%t: tensor<?xf32>, %sz: index, %c: i1) -> tensor<?xf32> {
+ // CHECK-DAG: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?xf32, #{{.*}}>
+ // CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<?xf32>
+ %0 = linalg.init_tensor [%sz] : tensor<?xf32>
+
+ // A cast must be inserted because %t and %0 have different memref types.
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<?xf32> to memref<?xf32, #{{.*}}>
+ // CHECK: arith.select %{{.*}}, %[[casted]], %[[m]]
+ %1 = arith.select %c, %0, %t : tensor<?xf32>
+ return %1 : tensor<?xf32>
+}