diff options
author | Matthias Springer <me@m-sp.org> | 2023-05-25 19:10:05 +0200 |
---|---|---|
committer | Matthias Springer <me@m-sp.org> | 2023-05-25 19:15:13 +0200 |
commit | 7d36a468aa6c5b35058e8d4855c5bf9bba583c99 (patch) | |
tree | dc3dcbd3ca8ee815b34b4252498ad1245052d1a7 | |
parent | 810c7410b5f63607c75edaa2200e03f7400396fd (diff) | |
download | llvm-7d36a468aa6c5b35058e8d4855c5bf9bba583c99.zip llvm-7d36a468aa6c5b35058e8d4855c5bf9bba583c99.tar.gz llvm-7d36a468aa6c5b35058e8d4855c5bf9bba583c99.tar.bz2 |
[mlir][tensor] TrackingListener: Support cast-like InsertSliceOps with dynamic shape
When looking for payload op replacements, rank-expanding InsertSliceOps of dynamically-typed tensors are now supported.
Differential Revision: https://reviews.llvm.org/D151444
4 files changed, 46 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt index 27ea475..ff603c9 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt @@ -14,4 +14,5 @@ add_mlir_dialect_library(MLIRTensorTransformOps MLIRTensorDialect MLIRTensorTransforms MLIRTransformDialect + MLIRValueBoundsOpInterface ) diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 6fd32f6..92f7dbd 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -23,18 +24,27 @@ using namespace tensor; // TrackingListener //===----------------------------------------------------------------------===// -/// A tensor.insert_slice is a cast-like operation if it the source tensor and -/// the destination tensor have the same number of elements. I.e., the result -/// tensor data equals the source tensor data, maybe rank-extended to a -/// different shape. +/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the +/// source tensor or inserts the source tensor into a destination tensor with +/// the same shape. static bool isCastLikeInsertSliceOp(InsertSliceOp op) { - // TODO: Support dynamically shaped tensors. Utilize ValueBoundsOpInterface - // to check if source and destination have the same shape. - if (!op.getSourceType().hasStaticShape() || - !op.getDestType().hasStaticShape()) - return false; - return op.getSourceType().getNumElements() == - op.getDestType().getNumElements(); + llvm::SmallBitVector droppedDims = op.getDroppedDims(); + int64_t srcDim = 0; + // Source dims and destination dims (apart from dropped dims) must have the + // same size. + for (int64_t resultDim = 0; resultDim < op.getDestType().getRank(); + ++resultDim) { + if (droppedDims.test(resultDim)) { + continue; + } + FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( + op.getSource(), op.getResult(), srcDim, resultDim); + if (failed(equalDimSize) || !*equalDimSize) + return false; + ++srcDim; + } + + return true; } Operation * diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir index c046f16..369dcec 100644 --- a/mlir/test/Dialect/Tensor/tracking-listener.mlir +++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir @@ -82,3 +82,26 @@ func.func @non_cast_like_insert_slice(%t: tensor<7xf32>) { : tensor<5xf32> into tensor<7xf32> return } + +// ----- + +func.func @cast_like_insert_slice_dynamic( + %t: tensor<1x?x1xf32>, %f: f32, %pos: index) { + %c0 = arith.constant 0 : index + %0 = tensor.insert %f into %t[%c0, %pos, %c0] {replaced} : tensor<1x?x1xf32> + + // Rank reduction + %c1 = arith.constant 1 : index + %dim1 = tensor.dim %t, %c1 : tensor<1x?x1xf32> + %1 = tensor.extract_slice %t[0, 0, 0][1, %dim1, 1][1, 1, 1] + : tensor<1x?x1xf32> to tensor<?xf32> + // expected-remark @below {{replacement found}} + %2 = tensor.insert %f into %1[%c0] : tensor<?xf32> + // Rank expansion + // Throw in a wrench: Do not use %dim1 directly, but another SSA value that + // has the same runtime value. + %dim1b = tensor.dim %1, %c0 : tensor<?xf32> + %3 = tensor.insert_slice %2 into %t[0, 0, 0][1, %dim1b, 1][1, 1, 1] + {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32> + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index c3d105f4..e6748cd 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6003,6 +6003,7 @@ cc_library( ":TensorTransformOpsIncGen", ":TensorTransforms", ":TransformDialect", + ":ValueBoundsOpInterface", "//llvm:Support", ], ) |