From 5e0ded268929b87ddf2c5e077c9185554342f602 Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Wed, 16 Sep 2020 10:01:54 +0200 Subject: [mlir][Standard] Canonicalize chains of tensor_cast operations Adds a pattern that replaces a chain of two tensor_cast operations by a single tensor_cast operation if doing so will not remove constraints on the shapes. --- mlir/test/Transforms/canonicalize.mlir | 48 ++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) (limited to 'mlir/test/Transforms') diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir index 3204185..3603c47 100644 --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1062,3 +1062,51 @@ func @static_dynamic_tensor_from_elements(%size1: index, %size4: index) -> tenso return %0 : tensor<3x?x?x7x?xindex> } +// ----- + +// CHECK-LABEL: @tensor_cast_chain_ok +// CHECK-SAME: %[[IN:.*]]: tensor<*xi32> +func @tensor_cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> { + // CHECK-NEXT: %[[RES:.*]] = tensor_cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32> + %0 = tensor_cast %input : tensor<*xi32> to tensor<4x?xi32> + %1 = tensor_cast %0 : tensor<4x?xi32> to tensor<4x8xi32> + // CHECK-NEXT: return %[[RES]] + return %1 : tensor<4x8xi32> +} + +// ----- + +// CHECK-LABEL: @tensor_cast_chain_regain +// CHECK-SAME: %[[IN:.*]]: tensor<4xi32> +func @tensor_cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> { + %0 = tensor_cast %input : tensor<4xi32> to tensor + %1 = tensor_cast %0 : tensor to tensor<4xi32> + // CHECK-NEXT: return %[[IN]] + return %1 : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: @tensor_cast_chain_keep +// CHECK-SAME: %[[IN:.*]]: tensor +func @tensor_cast_chain_keep(%input: tensor) -> tensor { + // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]] + %0 = tensor_cast %input : tensor to tensor<4x?xi32> + // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]] + %1 = tensor_cast %0 : tensor<4x?xi32> to tensor + // CHECK-NEXT: return %[[C2]] + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @tensor_cast_chain_invalid +// CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32> +func @tensor_cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> { + // CHECK-NEXT: %[[C1:.*]] = tensor_cast %[[IN]] + %0 = tensor_cast %input : tensor<4x8xi32> to tensor + // CHECK-NEXT: %[[C2:.*]] = tensor_cast %[[C1]] + %1 = tensor_cast %0 : tensor to tensor<8x4xi32> + // CHECK-NEXT: return %[[C2]] + return %1 : tensor<8x4xi32> +} -- cgit v1.1