diff options
author | Matthias Springer <springerm@google.com> | 2022-04-08 18:03:18 +0900 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2022-04-08 18:11:10 +0900 |
commit | 8b09141909329d93b0de987ee18ee9cfaa7223ba (patch) | |
tree | a9d8be6e5b445cde5e7286ca2e36b9c0853d1552 | |
parent | 5626bd428930aa2cb2b5fdd69e93620e0d2b0532 (diff) | |
download | llvm-8b09141909329d93b0de987ee18ee9cfaa7223ba.zip llvm-8b09141909329d93b0de987ee18ee9cfaa7223ba.tar.gz llvm-8b09141909329d93b0de987ee18ee9cfaa7223ba.tar.bz2 |
[mlir][arith][bufferize] Fix tensors with different layouts after bufferization
Insert a cast if the two tensors with identical layout (that are passed to `arith.select`) have different layout maps after bufferization.
Differential Revision: https://reviews.llvm.org/D123321
-rw-r--r-- | mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp | 21 | ||||
-rw-r--r-- | mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir | 14 |
2 files changed, 35 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp index 12726a1..4f1add5 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -129,6 +129,7 @@ struct SelectOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto selectOp = cast<arith::SelectOp>(op); + Location loc = selectOp.getLoc(); // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. // TODO: It would be more efficient to copy the result of the `select` op @@ -139,6 +140,26 @@ struct SelectOpInterface *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); Value falseBuffer = *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); + + // The "true" and the "false" operands must have the same type. If the + // buffers have different types, they differ only in their layout map. Cast + // both of them to the most dynamic MemRef type. + if (trueBuffer.getType() != falseBuffer.getType()) { + auto trueType = trueBuffer.getType().cast<MemRefType>(); + auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>(); + int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; + SmallVector<int64_t> dynamicStrides(tensorType.getRank(), + ShapedType::kDynamicStrideOrOffset); + AffineMap stridedLayout = makeStridedLinearLayoutMap( + dynamicStrides, dynamicOffset, op->getContext()); + BaseMemRefType castedType = bufferization::getMemRefType( + tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout), + trueType.getMemorySpace()); + trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer); + falseBuffer = + rewriter.create<memref::CastOp>(loc, castedType, falseBuffer); + } + replaceOpWithNewBufferizedOp<arith::SelectOp>( rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); return success(); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir index e711392..ac2249d 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir @@ -105,4 +105,18 @@ func @copy_deallocated() -> tensor<10xf32> { return %0 : tensor<10xf32> } +// ----- +// CHECK-LABEL: func @select_different_tensors( +// CHECK-SAME: %[[t:.*]]: tensor<?xf32> +func @select_different_tensors(%t: tensor<?xf32>, %sz: index, %c: i1) -> tensor<?xf32> { + // CHECK-DAG: %[[m:.*]] = bufferization.to_memref %[[t]] : memref<?xf32, #{{.*}}> + // CHECK-DAG: %[[alloc:.*]] = memref.alloc(%{{.*}}) {{.*}} : memref<?xf32> + %0 = linalg.init_tensor [%sz] : tensor<?xf32> + + // A cast must be inserted because %t and %0 have different memref types. + // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<?xf32> to memref<?xf32, #{{.*}}> + // CHECK: arith.select %{{.*}}, %[[casted]], %[[m]] + %1 = arith.select %c, %0, %t : tensor<?xf32> + return %1 : tensor<?xf32> +} |