aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-05-25 19:10:05 +0200
committerMatthias Springer <me@m-sp.org>2023-05-25 19:15:13 +0200
commit7d36a468aa6c5b35058e8d4855c5bf9bba583c99 (patch)
treedc3dcbd3ca8ee815b34b4252498ad1245052d1a7
parent810c7410b5f63607c75edaa2200e03f7400396fd (diff)
downloadllvm-7d36a468aa6c5b35058e8d4855c5bf9bba583c99.zip
llvm-7d36a468aa6c5b35058e8d4855c5bf9bba583c99.tar.gz
llvm-7d36a468aa6c5b35058e8d4855c5bf9bba583c99.tar.bz2
[mlir][tensor] TrackingListener: Support cast-like InsertSliceOps with dynamic shape
When looking for payload op replacements, rank-expanding InsertSliceOps of dynamically-typed tensors are now supported. Differential Revision: https://reviews.llvm.org/D151444
-rw-r--r--mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp32
-rw-r--r--mlir/test/Dialect/Tensor/tracking-listener.mlir23
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel1
4 files changed, 46 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
index 27ea475..ff603c9 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
@@ -14,4 +14,5 @@ add_mlir_dialect_library(MLIRTensorTransformOps
MLIRTensorDialect
MLIRTensorTransforms
MLIRTransformDialect
+ MLIRValueBoundsOpInterface
)
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 6fd32f6..92f7dbd 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -23,18 +24,27 @@ using namespace tensor;
// TrackingListener
//===----------------------------------------------------------------------===//
-/// A tensor.insert_slice is a cast-like operation if it the source tensor and
-/// the destination tensor have the same number of elements. I.e., the result
-/// tensor data equals the source tensor data, maybe rank-extended to a
-/// different shape.
+/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
+/// source tensor or inserts the source tensor into a destination tensor with
+/// the same shape.
static bool isCastLikeInsertSliceOp(InsertSliceOp op) {
- // TODO: Support dynamically shaped tensors. Utilize ValueBoundsOpInterface
- // to check if source and destination have the same shape.
- if (!op.getSourceType().hasStaticShape() ||
- !op.getDestType().hasStaticShape())
- return false;
- return op.getSourceType().getNumElements() ==
- op.getDestType().getNumElements();
+ llvm::SmallBitVector droppedDims = op.getDroppedDims();
+ int64_t srcDim = 0;
+ // Source dims and destination dims (apart from dropped dims) must have the
+ // same size.
+ for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
+ ++resultDim) {
+ if (droppedDims.test(resultDim)) {
+ continue;
+ }
+ FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+ op.getSource(), op.getResult(), srcDim, resultDim);
+ if (failed(equalDimSize) || !*equalDimSize)
+ return false;
+ ++srcDim;
+ }
+
+ return true;
}
Operation *
diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir
index c046f16..369dcec 100644
--- a/mlir/test/Dialect/Tensor/tracking-listener.mlir
+++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir
@@ -82,3 +82,26 @@ func.func @non_cast_like_insert_slice(%t: tensor<7xf32>) {
: tensor<5xf32> into tensor<7xf32>
return
}
+
+// -----
+
+func.func @cast_like_insert_slice_dynamic(
+ %t: tensor<1x?x1xf32>, %f: f32, %pos: index) {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.insert %f into %t[%c0, %pos, %c0] {replaced} : tensor<1x?x1xf32>
+
+ // Rank reduction
+ %c1 = arith.constant 1 : index
+ %dim1 = tensor.dim %t, %c1 : tensor<1x?x1xf32>
+ %1 = tensor.extract_slice %t[0, 0, 0][1, %dim1, 1][1, 1, 1]
+ : tensor<1x?x1xf32> to tensor<?xf32>
+ // expected-remark @below {{replacement found}}
+ %2 = tensor.insert %f into %1[%c0] : tensor<?xf32>
+ // Rank expansion
+ // Throw in a wrench: Do not use %dim1 directly, but another SSA value that
+ // has the same runtime value.
+ %dim1b = tensor.dim %1, %c0 : tensor<?xf32>
+ %3 = tensor.insert_slice %2 into %t[0, 0, 0][1, %dim1b, 1][1, 1, 1]
+ {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c3d105f4..e6748cd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6003,6 +6003,7 @@ cc_library(
":TensorTransformOpsIncGen",
":TensorTransforms",
":TransformDialect",
+ ":ValueBoundsOpInterface",
"//llvm:Support",
],
)