author    | Stephan Herhut <herhut@google.com> | 2020-09-16 10:01:54 +0200
committer | Stephan Herhut <herhut@google.com> | 2020-09-17 16:50:38 +0200
commit    | 5e0ded268929b87ddf2c5e077c9185554342f602 (patch)
tree      | 45fc08983399685f99bee158b6f10d7c8edbd6e7
parent    | e7de267910e935ab885dae22b5191bfb118ca5f9 (diff)
[mlir][Standard] Canonicalize chains of tensor_cast operations
Adds a pattern that replaces a chain of two tensor_cast operations by a single tensor_cast operation if doing so will not remove constraints on the shapes.
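For illustration, here is a minimal before/after sketch of the new canonicalization (the function name is hypothetical; the shapes mirror the @tensor_cast_chain_ok test added below).

Before: the intermediate cast only partially refines the unranked input.

  func @example(%input: tensor<*xi32>) -> tensor<4x8xi32> {
    %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
    %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
    return %1 : tensor<4x8xi32>
  }

After canonicalization: every constraint the intermediate cast enforces is also enforced by the final cast, so the chain collapses into a single cast.

  func @example(%input: tensor<*xi32>) -> tensor<4x8xi32> {
    %0 = tensor_cast %input : tensor<*xi32> to tensor<4x8xi32>
    return %0 : tensor<4x8xi32>
  }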
-rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/IR/Ops.td |  2
-rw-r--r-- | mlir/lib/Dialect/StandardOps/IR/Ops.cpp         | 81
-rw-r--r-- | mlir/test/Transforms/canonicalize.mlir          | 48
3 files changed, 131 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index b0aa9b9..2113dfe 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2997,6 +2997,8 @@ def TensorCastOp : CastOp<"tensor_cast"> {
     /// The result of a tensor_cast is always a tensor.
     TensorType getType() { return getResult().getType().cast<TensorType>(); }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 0c86c87..c0dc872 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3163,6 +3163,87 @@ OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
   return impl::foldCastOp(*this);
 }
 
+/// Compute a TensorType that has the joined shape knowledge of the two
+/// given TensorTypes. The element types need to match.
+static TensorType joinShapes(TensorType one, TensorType two) {
+  assert(one.getElementType() == two.getElementType());
+
+  if (!one.hasRank())
+    return two;
+  if (!two.hasRank())
+    return one;
+
+  int64_t rank = one.getRank();
+  if (rank != two.getRank())
+    return {};
+
+  SmallVector<int64_t, 4> join;
+  join.reserve(rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    if (one.isDynamicDim(i)) {
+      join.push_back(two.getDimSize(i));
+      continue;
+    }
+    if (two.isDynamicDim(i)) {
+      join.push_back(one.getDimSize(i));
+      continue;
+    }
+    if (one.getDimSize(i) != two.getDimSize(i))
+      return {};
+    join.push_back(one.getDimSize(i));
+  }
+  return RankedTensorType::get(join, one.getElementType());
+}
+
+namespace {
+
+/// Replaces chains of two tensor_cast operations by a single tensor_cast
+/// operation if doing so does not remove runtime constraints.
+struct ChainedTensorCast : public OpRewritePattern<TensorCastOp> {
+  using OpRewritePattern<TensorCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorCastOp tensorCast,
+                                PatternRewriter &rewriter) const final {
+    auto tensorCastOperand =
+        tensorCast.getOperand().getDefiningOp<TensorCastOp>();
+
+    if (!tensorCastOperand)
+      return failure();
+
+    auto sourceType =
+        tensorCastOperand.getOperand().getType().cast<TensorType>();
+    auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
+    auto resultType = tensorCast.getType().cast<TensorType>();
+
+    // We can remove the intermediate cast if joining all three produces the
+    // same result as just joining the source and result shapes.
+    auto firstJoin =
+        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!firstJoin)
+      return failure();
+
+    // The newJoin always exists if the above join exists, it might just contain
+    // less information. If so, we cannot drop the intermediate cast, as doing
+    // so would remove runtime checks.
+    auto newJoin = joinShapes(sourceType, resultType);
+    if (firstJoin != newJoin)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<TensorCastOp>(tensorCast, resultType,
+                                              tensorCastOperand.getOperand());
+    return success();
+  }
+};
+
+} // namespace
+
+void TensorCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<ChainedTensorCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Helpers for Tensor[Load|Store]Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3204185..3603c47 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1062,3 +1062,51 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
   return %0 : tensor<3x?x?x7x?xindex>
 }
 
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+  // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+  %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
+  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+  // CHECK-NEXT: return %[[RES]]
+  return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+  %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
+  %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
+  // CHECK-NEXT: return %[[IN]]
+  return %1 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+  %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+  %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+  // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+  %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+  // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+  %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+  // CHECK-NEXT: return %[[C2]]
+  return %1 : tensor<8x4xi32>
+}
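To see why the pattern sometimes has to keep both casts, here is a worked example of the shape-join check, mirroring the @tensor_cast_chain_keep test above (the function name is hypothetical). Running it through mlir-opt -canonicalize leaves the chain untouched:

  // source       : tensor<?x?xi32>
  // intermediate : tensor<4x?xi32>
  // result       : tensor<?x8xi32>
  //
  // joinShapes(joinShapes(source, intermediate), result) = tensor<4x8xi32>
  // joinShapes(source, result)                           = tensor<?x8xi32>
  //
  // The two joins differ: dropping the intermediate cast would lose the
  // runtime check that dimension 0 equals 4, so the pattern returns failure
  // and both casts stay in place.
  func @keep_example(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
    %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
    %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
    return %1 : tensor<?x8xi32>
  }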