diff options
Diffstat (limited to 'mlir/test/Transforms')
-rw-r--r-- | mlir/test/Transforms/canonicalize.mlir | 48 |
1 files changed, 48 insertions, 0 deletions
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 3204185..3603c47 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1062,3 +1062,51 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso return %0 : tensor<3x?x?x7x?xindex> } +// ----- + +// CHECK-LABEL: @tensor_cast_chain_ok +// CHECK-SAME: %[[IN:.*]]: tensor<*xi32> +func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> { + // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32> + %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32> + %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32> + // CHECK-NEXT: return %[[RES]] + return %1 : tensor<4x8xi32> +} + +// ----- + +// CHECK-LABEL: @tensor_cast_chain_regain +// CHECK-SAME: %[[IN:.*]]: tensor<4xi32> +func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> { + %0 = tensor_cast %input : tensor<4xi32> to tensor<?xi32> + %1 = tensor_cast %0 : tensor<?xi32> to tensor<4xi32> + // CHECK-NEXT: return %[[IN]] + return %1 : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: @tensor_cast_chain_keep +// CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32> +func @tensor_cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> { + // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]] + %0 = tensor_cast %input : tensor<?x?xi32> to tensor<4x?xi32> + // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]] + %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<?x8xi32> + // CHECK-NEXT: return %[[C2]] + return %1 : tensor<?x8xi32> +} + +// ----- + +// CHECK-LABEL: @tensor_cast_chain_invalid +// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32> +func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> { + // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]] + %0 = tensor_cast %input : tensor<4x8xi32> to tensor<?x?xi32> + // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]] + %1 = tensor_cast %0 : tensor<?x?xi32> to tensor<8x4xi32> + // CHECK-NEXT: return %[[C2]] + return %1 : tensor<8x4xi32> +} |