aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Transforms
diff options
context:
space:
mode:
authorStephan Herhut <herhut@google.com>2020-09-16 10:01:54 +0200
committerStephan Herhut <herhut@google.com>2020-09-17 16:50:38 +0200
commit5e0ded268929b87ddf2c5e077c9185554342f602 (patch)
tree45fc08983399685f99bee158b6f10d7c8edbd6e7 /mlir/test/Transforms
parente7de267910e935ab885dae22b5191bfb118ca5f9 (diff)
downloadllvm-5e0ded268929b87ddf2c5e077c9185554342f602.zip
llvm-5e0ded268929b87ddf2c5e077c9185554342f602.tar.gz
llvm-5e0ded268929b87ddf2c5e077c9185554342f602.tar.bz2
[mlir][Standard] Canonicalize chains of tensor_cast operations
Adds a pattern that replaces a chain of two tensor_cast operations by a single tensor_cast operation if doing so will not remove constraints on the shapes.
Diffstat (limited to 'mlir/test/Transforms')
-rw-r--r--mlir/test/Transforms/canonicalize.mlir48
1 files changed, 48 insertions, 0 deletions
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3204185..3603c47 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1062,3 +1062,51 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
return %0 : tensor<3x?x?x7x?xindex>
}
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+ // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+ %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
+ %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+ %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
+ %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
+ // CHECK-NEXT: return %[[IN]]
+ return %1 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+ %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+ %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+ %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+ %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<8x4xi32>
+}