aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/Transforms')
-rw-r--r--mlir/test/Transforms/canonicalize.mlir48
1 files changed, 48 insertions, 0 deletions
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 3204185..3603c47 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1062,3 +1062,51 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso
return %0 : tensor<3x?x?x7x?xindex>
}
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_ok
+// CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
+func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
+ // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
+ %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32>
+ %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
+ // CHECK-NEXT: return %[[RES]]
+ return %1 : tensor<4x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_regain
+// CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
+func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
+ %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32>
+ %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32>
+ // CHECK-NEXT: return %[[IN]]
+ return %1 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_keep
+// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
+func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+ %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+ %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<?x8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @tensor_cast_chain_invalid
+// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
+func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
+ // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]]
+ %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32>
+ // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]]
+ %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
+ // CHECK-NEXT: return %[[C2]]
+ return %1 : tensor<8x4xi32>
+}