aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Herhut <herhut@google.com>2020-09-16 10:01:54 +0200
committerStephan Herhut <herhut@google.com>2020-09-17 16:50:38 +0200
commit5e0ded268929b87ddf2c5e077c9185554342f602 (patch)
tree45fc08983399685f99bee158b6f10d7c8edbd6e7
parente7de267910e935ab885dae22b5191bfb118ca5f9 (diff)
downloadllvm-5e0ded268929b87ddf2c5e077c9185554342f602.zip
llvm-5e0ded268929b87ddf2c5e077c9185554342f602.tar.gz
llvm-5e0ded268929b87ddf2c5e077c9185554342f602.tar.bz2
[mlir][Standard] Canonicalize chains of tensor_cast operations
Adds a pattern that replaces a chain of two tensor_cast operations by a single tensor_cast operation if doing so will not remove constraints on the shapes.
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/IR/Ops.td2
-rw-r--r--mlir/lib/Dialect/StandardOps/IR/Ops.cpp81
-rw-r--r--mlir/test/Transforms/canonicalize.mlir48
3 files changed, 131 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b0aa9b9..2113dfe 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2997,6 +2997,8 @@ def TensorCastOp : CastOp<"tensor_cast"> {
/// The result of a tensor_cast is always a tensor.
TensorType getType() { return getResult().getType().cast<TensorType>(); }
}];
+
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0c86c87..c0dc872 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3163,6 +3163,87 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
+/// Compute a TensorType that has the joined shape knowledge of the two
+/// given TensorTypes. The element types need to match.
+static TensorType joinShapes(TensorType one, TensorType two) {
+ assert(one.getElementType() == two.getElementType());
+
+ if (!one.hasRank())
+ return two;
+ if (!two.hasRank())
+ return one;
+
+ int64_t rank = one.getRank();
+ if (rank != two.getRank())
+ return {};
+
+ SmallVector<int64_t, 4> join;
+ join.reserve(rank);
+ for (int64_t i = 0; i < rank; ++i) {
+ if (one.isDynamicDim(i)) {
+ join.push_back(two.getDimSize(i));
+ continue;
+ }
+ if (two.isDynamicDim(i)) {
+ join.push_back(one.getDimSize(i));
+ continue;
+ }
+ if (one.getDimSize(i) != two.getDimSize(i))
+ return {};
+ join.push_back(one.getDimSize(i));
+ }
+ return RankedTensorType::get(join, one.getElementType());
+}
+
+namespace {
+
+/// Replaces chains of two tensor_cast operations by a single tensor_cast
+/// operation if doing so does not remove runtime constraints.
+struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
+ using OpRewritePattern<TensorCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TensorCastOp tensorCast,
+ PatternRewriter &rewriter) const final {
+ auto tensorCastOperand =
+ tensorCast.getOperand().getDefiningOp<TensorCastOp>();
+
+ if (!tensorCastOperand)
+ return failure();
+
+ auto sourceType =
+ tensorCastOperand.getOperand().getType().cast<TensorType>();
+ auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
+ auto resultType = tensorCast.getType().cast<TensorType>();
+
+ // We can remove the intermediate cast if joining all three produces the
+ // same result as just joining the source and result shapes.
+ auto firstJoin =
+ joinShapes(joinShapes(sourceType, intermediateType), resultType);
+
+ // The join might not exist if the cast sequence would fail at runtime.
+ if (!firstJoin)
+ return failure();
+
+ // The newJoin always exists if the above join exists; it might just contain
+ // less information. If so, we cannot drop the intermediate cast, as doing
+ // so would remove runtime checks.
+ auto newJoin = joinShapes(sourceType, resultType);
+ if (firstJoin != newJoin)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
+ tensorCastOperand.getOperand());
+ return success();
+ }
+};
+
+} // namespace
+
+void TensorCastOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ChainedTensorCast>(context);
+}
+
//===----------------------------------------------------------------------===//
// Helpers for Tensor[Load|Store]Op
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3204185..3603c47 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1062,3 +1062,51 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
return %0 : tensor<3x?x?x7x?xindex>
}
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+ // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+ %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
+ %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+ %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
+ %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
+ // CHECK-NEXT: return %[[IN]]
+ return %1 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+ %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+ %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+ %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+ %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<8x4xi32>
+}