diff options
-rw-r--r-- | mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 10 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/canonicalize.mlir | 15 |
2 files changed, 5 insertions, 20 deletions
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 49cfec6..edddfb8 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1417,11 +1417,11 @@ struct InsertSliceOpSourceCastInserter final return failure(); SmallVector<int64_t> newSrcShape(srcType.getShape().begin(), srcType.getShape().end()); - // Offsets / sizes / strides can be a subprefix of the rank; take only the - // leading dimensions. - for (auto en : llvm::enumerate(insertSliceOp.getMixedSizes())) - if (Optional<int64_t> constInt = getConstantIntValue(en.value())) - newSrcShape[en.index()] = *constInt; + for (int64_t i = 0; i < srcType.getRank(); ++i) { + if (Optional<int64_t> constInt = + getConstantIntValue(insertSliceOp.getMixedSizes()[i])) + newSrcShape[i] = *constInt; + } RankedTensorType newSrcType = RankedTensorType::get(newSrcShape, srcType.getElementType()); diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 50fda25..fc9abe4 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -536,21 +536,6 @@ func @insert_tensor_cast_on_insert_slice_src( // ----- -// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src_prefix( -// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32> -// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x?xf32> -// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1] [64, 5] [1, 1] : tensor<64x5x?xf32> into tensor<?x?x?xf32> -// CHECK: return %[[r]] -func @insert_tensor_cast_on_insert_slice_src_prefix( - %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> { - %c64 = arith.constant 64: index - %r = tensor.insert_slice %arg0 into %arg1[0, 1] [%c64, 5] [1, 1] - : tensor<?x5x?xf32> into tensor<?x?x?xf32> - return %r : tensor<?x?x?xf32> -} - -// ----- - // CHECK-LABEL: func @fold_extract_insert // CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32> func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) { |