aboutsummaryrefslogtreecommitdiff
path: root/mlir/test
diff options
context:
space:
mode:
authorAdam Siemieniuk <adam.siemieniuk@intel.com>2024-06-18 10:28:24 +0200
committerGitHub <noreply@github.com>2024-06-18 10:28:24 +0200
commita945f55d3e6af6be6648fb92a20c80e88e3fc2b2 (patch)
tree84c4c3dc42de39b38fe024a157e636f5bcd41ec9 /mlir/test
parent0e21f125c69f0e3204ea76d931717c88493e5cb3 (diff)
downloadllvm-a945f55d3e6af6be6648fb92a20c80e88e3fc2b2.zip
llvm-a945f55d3e6af6be6648fb92a20c80e88e3fc2b2.tar.gz
llvm-a945f55d3e6af6be6648fb92a20c80e88e3fc2b2.tar.bz2
[mlir][linalg] Add pattern to bubble-up pack through expand shape op (#93529)
Extends bubble-up pack through reshape pattern to handle pack propagation through expand shape ops. --------- Co-authored-by: Prashant Kumar <pk5561@gmail.com>
Diffstat (limited to 'mlir/test')
-rw-r--r--mlir/test/Dialect/Linalg/data-layout-propagation.mlir260
1 files changed, 260 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 9140904..78505d0 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -988,6 +988,266 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4
// -----
+func.func @bubble_up_pack_outer_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x64x4xf32> {
+ %empty = tensor.empty() : tensor<4x2x64x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [1] inner_tiles = [4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x64x4xf32>
+ return %pack : tensor<4x2x64x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_outer_expanded_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]]
+// CHECK-SAME: output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x64x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> {
+ %empty = tensor.empty() : tensor<32x4x4x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32>
+ return %pack : tensor<32x4x4x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_inner_expanded_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<32x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
+// CHECK-SAME: output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<32x4x4x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x32x16x4xf32> {
+ %empty = tensor.empty() : tensor<8x2x32x16x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [4] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x32x16x4xf32>
+ return %pack : tensor<8x2x32x16x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack
+// CHECK-SAME: %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]]
+// CHECK-SAME: output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x2x32x16x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_through_expand_dynamic(%arg0: tensor<?x64xf32>) -> tensor<?x4x2x8xf32> {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x64xf32>
+ %empty = tensor.empty(%dim) : tensor<?x4x2x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%dim, 4, 16] : tensor<?x64xf32> into tensor<?x4x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [8] into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
+ return %pack : tensor<?x4x2x8xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_through_expand_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM_INPUT:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x64xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM_INPUT]]) : tensor<?x8x8xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<?x64xf32> -> tensor<?x8x8xf32>
+// CHECK: %[[DIM_PACK:.+]] = tensor.dim %[[PACK]], %[[C0]] : tensor<?x8x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]]
+// CHECK-SAME: output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
+// CHECK: return %[[EXPANDED]] : tensor<?x4x2x8xf32>
+
+// -----
+
+func.func @bubble_up_pack_non_expanded_padding_through_expand(%arg0: tensor<32x60xf32>) -> tensor<4x2x8x4x8xf32> {
+ %cst = arith.constant 3.000000e+00 : f32
+ %empty = tensor.empty() : tensor<4x2x8x4x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x60xf32> into tensor<4x8x60xf32>
+ %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1, 2] inner_tiles = [4, 8] into %empty : tensor<4x8x60xf32> -> tensor<4x2x8x4x8xf32>
+ return %pack : tensor<4x2x8x4x8xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_padding_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x60xf32> -> tensor<8x8x4x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
+// CHECK-SAME: output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x8x4x8xf32>
+
+// -----
+
+func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x32x4x2xf32> {
+ %empty = tensor.empty() : tensor<4x2x32x4x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<4x2x32x4x2xf32>
+ return %pack : tensor<4x2x32x4x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32x4x2xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x32x4x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
+// CHECK-SAME: output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x32x4x2xf32>
+
+// -----
+
+func.func @bubble_up_pack_multiple_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x4x8x4x8x2xf32> {
+ %empty = tensor.empty() : tensor<8x2x4x8x4x8x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0, 2, 3] inner_tiles = [4, 8, 2] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x4x8x4x8x2xf32>
+ return %pack : tensor<8x2x4x8x4x8x2xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_multiple_dims_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x8x4x8x2xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]]
+// CHECK-SAME: output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32>
+// CHECK: return %[[EXPANDED]] : tensor<8x2x4x8x4x8x2xf32>
+
+// -----
+
+func.func @bubble_up_pack_inner_dims_reorder_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x4x16x4xf32> {
+ %empty = tensor.empty() : tensor<4x2x4x16x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [2, 1] inner_tiles = [16, 4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x4x16x4xf32>
+ return %pack : tensor<4x2x4x16x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_inner_dims_reorder_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x4x16x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]]
+// CHECK-SAME: output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x4x16x4xf32>
+
+// -----
+
+func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<4x2x2x8x16x4x4xf32> {
+ %empty = tensor.empty() : tensor<4x2x2x8x16x4x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [4, 8, 2, 32, 16] : tensor<32x64x16xf32> into tensor<4x8x2x32x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [1, 3] inner_tiles = [4, 4] into %empty : tensor<4x8x2x32x16xf32> -> tensor<4x2x2x8x16x4x4xf32>
+ return %pack : tensor<4x2x2x8x16x4x4xf32>
+}
+// CHECK-LABEL: func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x16x16x4x4xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]]
+// CHECK-SAME: output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32>
+// CHECK: return %[[EXPANDED]] : tensor<4x2x2x8x16x4x4xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x2x4x2xf32> {
+ %empty = tensor.empty() : tensor<32x4x2x4x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
+ return %pack : tensor<32x4x2x4x2xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x4x2x4x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32>
+// CHECK: return %[[PACK]] : tensor<32x4x2x4x2xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x2x64x2x4xf32> {
+ %empty = tensor.empty() : tensor<2x2x64x2x4xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
+ return %pack : tensor<2x2x64x2x4xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x2x64x2x4xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]]
+// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32>
+// CHECK: return %[[PACK]] : tensor<2x2x64x2x4xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x8x64x2xf32> {
+ %empty = tensor.empty() : tensor<2x8x64x2xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [2] into %empty : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
+ return %pack : tensor<2x8x64x2xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x2xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]]
+// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32>
+// CHECK: return %[[PACK]] : tensor<2x8x64x2xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(%arg0: tensor<30x60xf32>) -> tensor<3x2x60x8xf32> {
+ %cst = arith.constant 3.000000e+00 : f32
+ %empty = tensor.empty() : tensor<3x2x60x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
+ %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [8] into %empty : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
+ return %pack : tensor<3x2x60x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2x60x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] padding_value(%[[CST]] : f32)
+// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32>
+// CHECK: return %[[PACK]] : tensor<3x2x60x8xf32>
+
+// -----
+
+func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
+ %empty = tensor.empty() : tensor<8x4x16x8xf32>
+ %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+ %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+ return %pack : tensor<8x4x16x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
+// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
+
+// -----
+
func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
%6 = tensor.empty(%dim) : tensor<?x256xf32>
%unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>