diff options
Diffstat (limited to 'mlir/test/Dialect/Linalg/canonicalize.mlir')
-rw-r--r-- | mlir/test/Dialect/Linalg/canonicalize.mlir | 71 |
1 files changed, 59 insertions, 12 deletions
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 9cbb56e4..5c5f7e8 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) // ----- +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x3xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x3xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x3xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [2] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + +// CHECK-LABEL: @broadcast_broadcast_fold +// CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> +// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> +// CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32> +// CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2] +// CHECK-NOT: linalg.broadcast +// CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32> +func.func @broadcast_broadcast_fold(%input: tensor<2xf32>, + %init1: tensor<2x4xf32>, + %init2: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + %broadcast1 = linalg.broadcast + ins(%input: tensor<2xf32>) + outs(%init1: tensor<2x4xf32>) + dimensions = [1] + %broadcast2 = linalg.broadcast + ins(%broadcast1: tensor<2x4xf32>) + outs(%init2: tensor<2x3x4xf32>) + dimensions = [1] + func.return %broadcast2 : tensor<2x3x4xf32> +} + +// ----- + func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { %transpose = linalg.transpose @@ -1387,42 +1433,43 @@ func.func @recursive_effect(%arg : tensor<1xf32>) { // CHECK-LABEL: @recursive_effect // CHECK: linalg.map +// ----- + //===----------------------------------------------------------------------===// // linalg.pack //===----------------------------------------------------------------------===// // CHECK-LABEL: func @fold_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32> %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- // CHECK-LABEL: func @fold_padding_value_pack_constant_splat // CHECK-NOT: linalg.pack -// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32> -func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32> +func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 1.000000e-01 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst padding_value(%pad : f32) outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] - inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } - // ----- // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32> // CHECK: linalg.pack -func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> { +func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> { %pad = arith.constant 0.0 : f32 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32> %0 = linalg.pack %cst @@ -1430,8 +1477,8 @@ func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] - into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32> - return %0 : tensor<8x16x8x32xf32> + into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32> + return %0 : tensor<4x8x8x32xf32> } // ----- |