diff options
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 130 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/data-layout-propagation.mlir | 260 |
2 files changed, 390 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 2bea083..e51ae22 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -17,6 +17,8 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include <optional> @@ -694,6 +696,131 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, return success(); } +/// Project dimsPos to their collapsed positions in the reassocIndices. +/// +/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices +/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0, +/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos +/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3. +static SmallVector<int64_t> +projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos, + ArrayRef<ReassociationIndices> reassocIndices) { + SmallVector<int64_t> projectedPos; + + // Map each dimension to the position of corresponding reassociation index. + for (auto pos : dimsPos) { + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + // If the dimension is present in the current indices group, the group + // position within the reassociation map is the desired projected + // dimension position. + if (llvm::any_of(indices, + [&](int64_t expandDim) { return expandDim == pos; })) { + projectedPos.push_back(idx); + break; + } + } + } + assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection"); + + return projectedPos; +} + +/// Bubble up pack op through expand shape op. 
+/// +/// For example: +/// +/// %expand = tensor.expand_shape %in [[0], [1, 2]] +/// : tensor<?x64xf32> into tensor<?x4x16xf32> +/// %pack = tensor.pack %expand outer_dims_perm = [0, 1, 2] +/// inner_dims_pos = [2] inner_tiles = [8] into %empty +/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32> +/// +/// can be transformed into: +/// +/// %pack = tensor.pack %in outer_dims_perm = [0, 1] +/// inner_dims_pos = [1] inner_tiles = [8] into %empty +/// : tensor<?x64xf32> -> tensor<?x8x8xf32> +/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]] +/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32> +static LogicalResult +bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, + tensor::PackOp packOp, + PatternRewriter &rewriter) { + // Outer dimensions permutation is not supported currently. + // TODO: Handle outer_dims_perm variants. + ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); + if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { + return rewriter.notifyMatchFailure(packOp, + "non-identity outer dims perm NYI"); + } + + // Validate dimensions' relations between shape expansion and packing. + SmallVector<ReassociationIndices, 4> reassoc = + expandOp.getReassociationIndices(); + ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos(); + llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(), + packInnerDims.end()); + + for (auto [idx, indices] : llvm::enumerate(reassoc)) { + // For each expand_shape reassociation, figure out which dimensions get + // packed if any. + llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end()); + llvm::SetVector<int64_t> packedDims = + llvm::set_intersection(packDimsPos, expandDimPos); + + // The expanded dimension is not packed so, it does not affect moving pack + // before shape expansion - simply continue. 
+ if (packedDims.empty()) + continue; + // Shape expansion cannot be propagated when multiple expanded dimensions are + // packed - in this case operation reordering would affect final element + // positions and/or shapes can no longer be projected. + if (packedDims.size() != 1) + return rewriter.notifyMatchFailure( + packOp, "only one of the expanded dimensions can be packed"); + // Only the inner-most expanded dimension should be packed. Otherwise, + // elements order will be affected after operation reordering. + if (packedDims.front() != indices.back()) + return rewriter.notifyMatchFailure( + packOp, "can only pack the inner-most expanded dimension"); + } + + // Project pack.inner_dims_pos to positions before shape expansion. + SmallVector<int64_t> projectedInnerDimsPos = + projectDimsPosIntoReassocPos(packInnerDims, reassoc); + + // Project the shape expansion to new packed shape. + // The pack.outer_dims_perm is restricted to identity so, the permutation can + // be omitted for simplicity. + // TODO: Account for outer dimensions permutation. + // + // If reassociation is not possible, then reordering cannot happen. + // This can be caused by pack padding affecting previously expanded + // dimensions or packing extending dimensions. 
+ RankedTensorType newPackType = tensor::PackOp::inferPackedType( + expandOp.getSrcType(), packOp.getStaticInnerTiles(), + projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); + auto reassocExpand = + getReassociationIndicesForReshape(newPackType, packOp.getDestType()); + if (!reassocExpand) + return rewriter.notifyMatchFailure( + packOp, "could not reassociate dims after bubbling up"); + + Value destTensor = tensor::PackOp::createDestinationTensor( + rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), + projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); + Value packedVal = rewriter.create<tensor::PackOp>( + packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, + packOp.getMixedTiles(), packOp.getPaddingValue(), + /*outerDimsPerm=*/SmallVector<int64_t>{}); + + Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>( + packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); + rewriter.replaceOp(packOp, newExpandOp); + + return success(); +} + class BubbleUpPackOpThroughReshapeOp final : public OpRewritePattern<tensor::PackOp> { public: @@ -723,6 +850,9 @@ public: .Case([&](tensor::CollapseShapeOp op) { return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter); }) + .Case([&](tensor::ExpandShapeOp op) { + return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter); + }) .Default([](Operation *) { return failure(); }); } diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 9140904..78505d0 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -988,6 +988,266 @@ func.func @no_bubble_up_pack_through_non_divisible_collapse(%1: tensor<3072x64x4 // ----- +func.func @bubble_up_pack_outer_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x64x4xf32> { + %empty = tensor.empty() : tensor<4x2x64x4xf32> + %expanded = 
tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> + %pack = tensor.pack %expanded inner_dims_pos = [1] inner_tiles = [4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x64x4xf32> + return %pack : tensor<4x2x64x4xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_outer_expanded_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x4xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] : tensor<32x64xf32> -> tensor<8x64x4xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3]] +// CHECK-SAME: output_shape [4, 2, 64, 4] : tensor<8x64x4xf32> into tensor<4x2x64x4xf32> +// CHECK: return %[[EXPANDED]] : tensor<4x2x64x4xf32> + +// ----- + +func.func @bubble_up_pack_inner_expanded_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x4x4xf32> { + %empty = tensor.empty() : tensor<32x4x4x4xf32> + %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> + %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [4] into %empty : tensor<32x4x16xf32> -> tensor<32x4x4x4xf32> + return %pack : tensor<32x4x4x4xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_inner_expanded_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x16x4xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [4] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x64xf32> -> tensor<32x16x4xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] +// CHECK-SAME: output_shape [32, 4, 4, 4] : tensor<32x16x4xf32> into tensor<32x4x4x4xf32> +// CHECK: return %[[EXPANDED]] : tensor<32x4x4x4xf32> + +// ----- + +func.func @bubble_up_pack_non_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> 
tensor<8x2x32x16x4xf32> { + %empty = tensor.empty() : tensor<8x2x32x16x4xf32> + %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32> + %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [4] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x32x16x4xf32> + return %pack : tensor<8x2x32x16x4xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_dims_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x64x16x4xf32> +// CHECK: %[[PACK:.+]] = tensor.pack +// CHECK-SAME: %[[ARG0]] inner_dims_pos = [0] inner_tiles = [4] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x64x16x4xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4]] +// CHECK-SAME: output_shape [8, 2, 32, 16, 4] : tensor<8x64x16x4xf32> into tensor<8x2x32x16x4xf32> +// CHECK: return %[[EXPANDED]] : tensor<8x2x32x16x4xf32> + +// ----- + +func.func @bubble_up_pack_through_expand_dynamic(%arg0: tensor<?x64xf32>) -> tensor<?x4x2x8xf32> { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor<?x64xf32> + %empty = tensor.empty(%dim) : tensor<?x4x2x8xf32> + %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [%dim, 4, 16] : tensor<?x64xf32> into tensor<?x4x16xf32> + %pack = tensor.pack %expanded inner_dims_pos = [2] inner_tiles = [8] into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32> + return %pack : tensor<?x4x2x8xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_through_expand_dynamic( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM_INPUT:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x64xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM_INPUT]]) : tensor<?x8x8xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] +// CHECK-SAME: : 
tensor<?x64xf32> -> tensor<?x8x8xf32> +// CHECK: %[[DIM_PACK:.+]] = tensor.dim %[[PACK]], %[[C0]] : tensor<?x8x8xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3]] +// CHECK-SAME: output_shape [%[[DIM_PACK]], 4, 2, 8] : tensor<?x8x8xf32> into tensor<?x4x2x8xf32> +// CHECK: return %[[EXPANDED]] : tensor<?x4x2x8xf32> + +// ----- + +func.func @bubble_up_pack_non_expanded_padding_through_expand(%arg0: tensor<32x60xf32>) -> tensor<4x2x8x4x8xf32> { + %cst = arith.constant 3.000000e+00 : f32 + %empty = tensor.empty() : tensor<4x2x8x4x8xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 60] : tensor<32x60xf32> into tensor<4x8x60xf32> + %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1, 2] inner_tiles = [4, 8] into %empty : tensor<4x8x60xf32> -> tensor<4x2x8x4x8xf32> + return %pack : tensor<4x2x8x4x8xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_non_expanded_padding_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32 +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x4x8xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[CST]] : f32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x60xf32> -> tensor<8x8x4x8xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] +// CHECK-SAME: output_shape [4, 2, 8, 4, 8] : tensor<8x8x4x8xf32> into tensor<4x2x8x4x8xf32> +// CHECK: return %[[EXPANDED]] : tensor<4x2x8x4x8xf32> + +// ----- + +func.func @bubble_up_pack_outer_dims_perm_identity_through_expand(%arg0: tensor<32x64xf32>) -> tensor<4x2x32x4x2xf32> { + %empty = tensor.empty() : tensor<4x2x32x4x2xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> + %pack = tensor.pack %expanded outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] 
inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<4x2x32x4x2xf32> + return %pack : tensor<4x2x32x4x2xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_outer_dims_perm_identity_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32x4x2xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 2] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x32x4x2xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] +// CHECK-SAME: output_shape [4, 2, 32, 4, 2] : tensor<8x32x4x2xf32> into tensor<4x2x32x4x2xf32> +// CHECK: return %[[EXPANDED]] : tensor<4x2x32x4x2xf32> + +// ----- + +func.func @bubble_up_pack_multiple_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<8x2x4x8x4x8x2xf32> { + %empty = tensor.empty() : tensor<8x2x4x8x4x8x2xf32> + %expanded = tensor.expand_shape %arg0 [[0], [1, 2], [3]] output_shape [32, 2, 32, 16] : tensor<32x64x16xf32> into tensor<32x2x32x16xf32> + %pack = tensor.pack %expanded inner_dims_pos = [0, 2, 3] inner_tiles = [4, 8, 2] into %empty : tensor<32x2x32x16xf32> -> tensor<8x2x4x8x4x8x2xf32> + return %pack : tensor<8x2x4x8x4x8x2xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_multiple_dims_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8x8x4x8x2xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [0, 1, 2] inner_tiles = [4, 8, 2] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x8x8x4x8x2xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1, 2], [3], [4], [5], [6]] +// CHECK-SAME: output_shape [8, 2, 4, 8, 4, 8, 2] : tensor<8x8x8x4x8x2xf32> into tensor<8x2x4x8x4x8x2xf32> +// CHECK: return %[[EXPANDED]] : tensor<8x2x4x8x4x8x2xf32> + +// ----- + +func.func @bubble_up_pack_inner_dims_reorder_through_expand(%arg0: 
tensor<32x64xf32>) -> tensor<4x2x4x16x4xf32> { + %empty = tensor.empty() : tensor<4x2x4x16x4xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> + %pack = tensor.pack %expanded inner_dims_pos = [2, 1] inner_tiles = [16, 4] into %empty : tensor<4x8x64xf32> -> tensor<4x2x4x16x4xf32> + return %pack : tensor<4x2x4x16x4xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_inner_dims_reorder_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x4xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x64xf32> -> tensor<8x4x16x4xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] +// CHECK-SAME: output_shape [4, 2, 4, 16, 4] : tensor<8x4x16x4xf32> into tensor<4x2x4x16x4xf32> +// CHECK: return %[[EXPANDED]] : tensor<4x2x4x16x4xf32> + +// ----- + +func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand(%arg0: tensor<32x64x16xf32>) -> tensor<4x2x2x8x16x4x4xf32> { + %empty = tensor.empty() : tensor<4x2x2x8x16x4x4xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [4, 8, 2, 32, 16] : tensor<32x64x16xf32> into tensor<4x8x2x32x16xf32> + %pack = tensor.pack %expanded inner_dims_pos = [1, 3] inner_tiles = [4, 4] into %empty : tensor<4x8x2x32x16xf32> -> tensor<4x2x2x8x16x4x4xf32> + return %pack : tensor<4x2x2x8x16x4x4xf32> +} +// CHECK-LABEL: func.func @bubble_up_pack_multiple_different_expanded_dims_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x16x16x4x4xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x64x16xf32> -> tensor<8x16x16x4x4xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape 
%[[PACK]] {{\[}}[0, 1], [2, 3], [4], [5], [6]] +// CHECK-SAME: output_shape [4, 2, 2, 8, 16, 4, 4] : tensor<8x16x16x4x4xf32> into tensor<4x2x2x8x16x4x4xf32> +// CHECK: return %[[EXPANDED]] : tensor<4x2x2x8x16x4x4xf32> + +// ----- + +func.func @no_bubble_up_pack_outer_dims_permutation_through_expand(%arg0: tensor<32x64xf32>) -> tensor<32x4x2x4x2xf32> { + %empty = tensor.empty() : tensor<32x4x2x4x2xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> + %pack = tensor.pack %expanded outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %empty : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32> + return %pack : tensor<32x4x2x4x2xf32> +} +// CHECK-LABEL: func.func @no_bubble_up_pack_outer_dims_permutation_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x4x2x4x2xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] +// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] +// CHECK-SAME: outer_dims_perm = [2, 0, 1] inner_dims_pos = [1, 2] inner_tiles = [4, 2] into %[[EMPTY]] +// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<32x4x2x4x2xf32> +// CHECK: return %[[PACK]] : tensor<32x4x2x4x2xf32> + +// ----- + +func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x2x64x2x4xf32> { + %empty = tensor.empty() : tensor<2x2x64x2x4xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> + %pack = tensor.pack %expanded inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %empty : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32> + return %pack : tensor<2x2x64x2x4xf32> +} +// CHECK-LABEL: func.func @no_bubble_up_pack_multiple_same_expanded_dim_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: 
%[[EMPTY:.+]] = tensor.empty() : tensor<2x2x64x2x4xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] +// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [2, 4] into %[[EMPTY]] +// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x2x64x2x4xf32> +// CHECK: return %[[PACK]] : tensor<2x2x64x2x4xf32> + +// ----- + +func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand(%arg0: tensor<32x64xf32>) -> tensor<2x8x64x2xf32> { + %empty = tensor.empty() : tensor<2x8x64x2xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> + %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [2] into %empty : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32> + return %pack : tensor<2x8x64x2xf32> +} +// CHECK-LABEL: func.func @no_bubble_up_pack_non_innermost_expanded_dim_through_expand( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x8x64x2xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] +// CHECK-SAME: output_shape [4, 8, 64] : tensor<32x64xf32> into tensor<4x8x64xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [2] into %[[EMPTY]] +// CHECK-SAME: : tensor<4x8x64xf32> -> tensor<2x8x64x2xf32> +// CHECK: return %[[PACK]] : tensor<2x8x64x2xf32> + +// ----- + +func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(%arg0: tensor<30x60xf32>) -> tensor<3x2x60x8xf32> { + %cst = arith.constant 3.000000e+00 : f32 + %empty = tensor.empty() : tensor<3x2x60x8xf32> + %expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32> + %pack = tensor.pack %expanded padding_value(%cst : f32) inner_dims_pos = [1] inner_tiles = [8] into 
%empty : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32> + return %pack : tensor<3x2x60x8xf32> +} +// CHECK-LABEL: func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[CST:.+]] = arith.constant 3.000000e+00 : f32 +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<3x2x60x8xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2]] +// CHECK-SAME: output_shape [3, 10, 60] : tensor<30x60xf32> into tensor<3x10x60xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] padding_value(%[[CST]] : f32) +// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [8] into %[[EMPTY]] +// CHECK-SAME: : tensor<3x10x60xf32> -> tensor<3x2x60x8xf32> +// CHECK: return %[[PACK]] : tensor<3x2x60x8xf32> + +// ----- + +func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> { + %empty = tensor.empty() : tensor<8x4x16x8xf32> + %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> + %pack = tensor.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> + return %pack : tensor<8x4x16x8xf32> +} +// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32> +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]] +// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32> +// CHECK: %[[PACK:.+]] = tensor.pack %[[EXPANDED]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]] +// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32> +// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32> + +// ----- + func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) 
-> tensor<?x256x256xf32> { %6 = tensor.empty(%dim) : tensor<?x256xf32> %unpack = tensor.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32> |