diff options
author | ofri frishman <ofri4321@gmail.com> | 2025-04-02 23:06:43 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-04-02 21:06:43 +0100 |
commit | 6f1347d57bdaed75b73b2013a96a4a69c8969ebe (patch) | |
tree | f45781cbaed8f13fdad2e91b88ddae18dafa9eb7 | |
parent | c87dc2b7d4ac0131cb97f096be522a50a4b3068b (diff) | |
download | llvm-6f1347d57bdaed75b73b2013a96a4a69c8969ebe.zip llvm-6f1347d57bdaed75b73b2013a96a4a69c8969ebe.tar.gz llvm-6f1347d57bdaed75b73b2013a96a4a69c8969ebe.tar.bz2 |
[MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape (#131982)
Add a pattern that bubbles up tensor.extract_slice through
tensor.collapse_shape.
The pattern is registered in a pattern population function that is used
by the transform op
transform.apply_patterns.tensor.bubble_up_extract_slice and by the
tranform op transform.structured.fuse as a cleanup pattern.
This pattern enables tiling and fusing op chains which contain
tensor.collapse_shape if added as a cleanup pattern of tile and fuse
utility.
Without this pattern that would not be possible, as
tensor.collapse_shape does not implement the tiling interface. This is
an additional pattern to the one added in PR #126898
-rw-r--r-- | mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp | 254 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/transform-op-fuse.mlir | 49 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir | 174 |
3 files changed, 476 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index acedf51d0..eed44e6 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/LogicalResult.h" @@ -428,6 +429,256 @@ struct BubbleUpExpandShapeThroughExtractSlice } }; +/// Converts `tensor.extract_slice(tensor.collapse_shape)` to +/// `tensor.collapse_shape(tensor.extract_slice)`. +/// +/// For this transformation to be possible - after bubbling up, the extraction +/// of the contiguous slice must be representable as a single slice obtained via +/// tensor.extract_slice within each reassociation group of the src. +/// +/// In case the size and offset extracted are static then this is possible if +/// the following conditions are met within each reassociation group: +/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the +/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the +/// shape of a desired slice. A slice of shape S can be extracted as a +/// contiguous span of elements if and only if there exists an index k in {0, 1, +/// ..., n} such that: +/// S_i = 1 for all i < k (that is, all leading dimensions are singleton), +/// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly +/// one dimension), +/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved +/// in full). +/// In other words, the slice shape S must be of the form: +/// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] +/// +/// In case the size and/or offset extracted are dynamic then this is possible +/// only if there is single dimension in the reassociation group that has a size +/// not equal to 1. +/// In other words, the tensor shape must be of the form: +/// [ 1, 1, ..., 1, A, 1, ...,1 ] +/// Note - it might be possible to enable this pattern for more cases when the +/// size/offset are dynamic via performing an analysis of the possible values +/// that could be given to the size/offset. +/// +/// Example: +/// The transformation is possible because each reassociation group can be +/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?], +/// [20->10]). +/// ``` +/// BEFORE: +/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ... +/// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32> +/// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1] +/// tensor<128x7x20xf32> to tensor<32x?x10xf32> +/// +/// AFTER: +/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10] +// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32> +/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ... +/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32> +/// ``` +/// +/// Negative example: +/// The transformation is not possible because we cannot use a single slice to +/// represent the reassociation group [2x3x10->???]. If we would want the +/// collapse to be after the extraction, we would need to extract multiple +/// slices and concat them together. +/// ``` +/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into +/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] : +/// tensor<60xf32> to tensor<15xf32> +/// ``` +/// If we would want the collapse to be after the extraction, a possible +/// alternate transformation could be to extract multiple slices and concat them +/// together: +/// ``` +/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] : +/// tensor<2x3x10xf32> to tensor <1x1x10xf32> +/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] : +/// tensor<2x3x10xf32> to tensor <1x1x5xf32> +/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} : +/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32> +/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32> +/// to tensor<15xf32> +/// ``` +/// But this is not the intended purpose of the transformation. +struct BubbleUpCollapseShapeThroughExtractSlice + : public OpRewritePattern<tensor::ExtractSliceOp> { + using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, + PatternRewriter &rewriter) const override { + auto collapseShapeOp = + sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); + if (!collapseShapeOp) { + return rewriter.notifyMatchFailure( + sliceOp, + "tensor.extract_slice source not produced by tensor.collapse_shape"); + } + + if (!sliceOp.hasUnitStride()) { + return rewriter.notifyMatchFailure( + sliceOp, "unsupported: non-unit stride. Only contiguous slices can " + "be supported in this transformation."); + } + + // The tensor.extract_slice before applying the pattern works on the result + // of the tensor.collapse_shape, so variables (i.e. inputs for + // ExtractSliceOp) referring to the state before applying the pattern are + // named with the prefix "collapsed", and ones referring to the state after + // applying the pattern are named with the prefix "expanded". + SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); + + if (static_cast<size_t>(sliceOp.getResultType().getRank()) != + collapsedSizes.size()) { + return rewriter.notifyMatchFailure(sliceOp, + "unimplemented: rank reducing slice"); + } + + ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape(); + SmallVector<ReassociationIndices, 4> reassociationIndices = + collapseShapeOp.getReassociationIndices(); + + // Compute new offsets, sizes, and strides for tensor.extract_slice. + // The new tensor.extract_slice will work on a tensor that has has a rank + // equal to the rank of the src of the collapse_shape. In each iteration of + // the loop, the offsets and sizes will be computed per reassociation group. + SmallVector<OpFoldResult> expandedOffsets, expandedSizes; + SmallVector<OpFoldResult> expandedStrides(srcShape.size(), + rewriter.getIndexAttr(1)); + + for (auto [collapsedSize, collapsedOffset, reassocIndices] : + llvm::zip_equal(collapsedSizes, collapsedOffsets, + collapseShapeOp.getReassociationIndices())) { + // CASE #1 - size and/or offset are dynamic. + // In this case, the slice can be represented as a contiguous slice only + // if there is a single dimension in the reassociation group that has a + // size not equal to 1. + if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { + int nonUnitSizeCount = 0; + for (int64_t expandedShapeIdx : reassocIndices) { + if (srcShape[expandedShapeIdx] != 1) { + nonUnitSizeCount++; + expandedSizes.push_back(collapsedSize); + expandedOffsets.push_back(collapsedOffset); + continue; + } + + expandedSizes.push_back(rewriter.getIndexAttr(1)); + expandedOffsets.push_back(rewriter.getIndexAttr(0)); + } + + if (nonUnitSizeCount != 1) { + return rewriter.notifyMatchFailure( + sliceOp, + "unsupported: slice cannot be verified to be contiguous"); + } + continue; + } + + // CASE #2 = size and offset are static. + // Verify that the slice can be represented as a contiguous slice of the + // src of the collapse_shape. + // Checking this is done on order of most internal dimensions first, + // so traversal is done in reverse order of the reassociation group. + // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, + // ...,An] then we first find the size and offset for n...k+1 then for k + // and then for k-1...0. + + // currentCollapsedsize and currentCollapsedOffset are initialized with + // the original collapsed size and offset and divided by the expanded + // shape size in each dimension as we go along the reassociation group. + // In essence we are spreading the original collapsed size and offset over + // the various expanded slice dimensions. + // The variables are used both to check the validity of the slice and to + // compute the expanded sizes and offsets. + int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); + int64_t currentCollapsedOffset = + getConstantIntValue(collapsedOffset).value(); + + SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; + + ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), + reassocIndices.rend()); + int64_t idx = 0; + int64_t reassocGroupSize = reassocIndices.size(); + + // First handle the trailing dimensions where the slice size should be + // equal to the tensor shape and the offset should be 0 (n...k+1). + for (; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; + + if (currentCollapsedsize < expandedShapeSize) + break; + + // We need to make sure that the slice size can be set to the shape size + // and the offset to 0. + if ((currentCollapsedsize % expandedShapeSize) != 0 || + (currentCollapsedOffset % expandedShapeSize) != 0) { + return rewriter.notifyMatchFailure( + sliceOp, "unsupported: cannot be extracted as a contiguous slice " + "of the src of the collapse_shape"); + } + + groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); + groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); + + currentCollapsedsize /= expandedShapeSize; + currentCollapsedOffset /= expandedShapeSize; + } + + // Now handle the first dim where slicing occurs on (k). + if (idx < reassocGroupSize) { + int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; + // We need to make sure that the slice size in this dim + offset will + // not exceed the shape size. + if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { + return rewriter.notifyMatchFailure( + sliceOp, "unsupported: slice cannot be extracted as a contiguous " + "slice of the src of the collapse_shape"); + } + + groupExpandedSizes.push_back( + rewriter.getIndexAttr(currentCollapsedsize)); + groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); + + currentCollapsedOffset /= expandedShapeSize; + } + + // Now handle the leading dimensions where the slice size is equal to 1 + // (k-1...0). + // The size for these dimensions must be 1 because of how we constructed + // the slice size of the expanded shape. We spread the original collapsed + // size over the expanded shape sizes until we reached dimension k where + // the remaining size was smaller than the expanded shape size, and spread + // the remaining size on it. So, now we are left with only 1s. + for (idx++; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; + groupExpandedSizes.push_back(rewriter.getIndexAttr(1)); + groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); + currentCollapsedOffset /= expandedShapeSize; + } + + expandedSizes.append(groupExpandedSizes.rbegin(), + groupExpandedSizes.rend()); + expandedOffsets.append(groupExpandedOffsets.rbegin(), + groupExpandedOffsets.rend()); + } + + Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>( + collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets, + expandedSizes, expandedStrides); + rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( + sliceOp, sliceOp.getResultType(), newSliceOp, + collapseShapeOp.getReassociationIndices()); + + return success(); + } +}; + } // namespace void mlir::tensor::populateReassociativeReshapeFoldingPatterns( @@ -448,5 +699,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns( void mlir::tensor::populateBubbleUpExtractSliceOpPatterns( RewritePatternSet &patterns) { - patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext()); + patterns.add<BubbleUpExpandShapeThroughExtractSlice, + BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext()); } diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 9bcc125..9628580 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -438,3 +438,52 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape( +// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] +// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> { + %expand = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32> + %empty = tensor.empty() : tensor<8x1800x32xf32> + %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32> + return %exp : tensor<8x1800x32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer( +// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +// CHECK: %[[ABS:.*]] = linalg.abs ins(%[[EXTRACT]] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]] +// CHECK: %[[EXP:.*]] = linalg.exp ins(%[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> { + %empty1 = tensor.empty() : tensor<1x8x1800x32xf32> + %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32> + %expand = tensor.collapse_shape %abs [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32> + %empty2 = tensor.empty() : tensor<8x1800x32xf32> + %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty2 : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32> + return %exp : tensor<8x1800x32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : + (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">) + transform.yield + } +} diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir index 3900bc5..34128d6 100644 --- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir +++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir @@ -1,5 +1,15 @@ // RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s +///---------------------------------------------------------------------------------------- +/// [Pattern: BubbleUpExpandShapeThroughExtractSlice] +/// +/// IN: tensor.expand_shape(tensor.extract_slice) +/// OUT:tensor.extract_slice(tensor.expand_shape) +/// +/// Note: tensor.extract_slice is bubbled up to be before tensor.expand_shape. +/// Some tests are negative tests for cases where the pattern cannot be applied. +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape( // CHECK-SAME: %[[SRC:.*]]: tensor<60xf32>) -> tensor<1x1x5xf32> { // CHECK: %[[C1:.+]] = arith.constant 5 : index @@ -113,6 +123,170 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>, return %extract : tensor<?x5x2xf32> } +///---------------------------------------------------------------------------------------- +/// [Pattern: BubbleUpCollapseShapeThroughExtractSlice] +/// +/// IN: tensor.collapse_shape(tensor.extract_slice) +/// OUT:tensor.extract_slice(tensor.collapse_shape) +/// +/// Note: tensor.extract_slice is bubbled up to be before tensor.collapse_shape. +/// Some tests are negative tests for cases where the pattern cannot be applied. +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group( +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%src: tensor<6x5x2xf32>) -> tensor<1xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[0][1][1] : tensor<60xf32> to tensor<1xf32> + return %extract : tensor<1xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group( +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> { + %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<6x5x3x10xf32> into tensor<30x30xf32> + %extract = tensor.extract_slice %collapse[5, 10][15, 10][1, 1] : tensor<30x30xf32> to tensor<15x10xf32> + return %extract : tensor<15x10xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim( +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][2, 0, 0] [1, 2, 2] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[20][4][1] : tensor<60xf32> to tensor<4xf32> + return %extract : tensor<4xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(%src: tensor<1x5x1xf32>, %size : index) -> tensor<?xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32> + %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<5xf32> to tensor<?xf32> + return %extract : tensor<?xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x?x1xf32>, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(%src: tensor<1x?x1xf32>, %size : index) -> tensor<?xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x?x1xf32> into tensor<?xf32> + %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<?xf32> to tensor<?xf32> + return %extract : tensor<?xf32> +} + + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<3xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, %[[OFFSET]], 0] [1, 3, 1] [1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(%src: tensor<1x5x1xf32>, %offset : index) -> tensor<3xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32> + %extract = tensor.extract_slice %collapse[%offset][3][1] : tensor<5xf32> to tensor<3xf32> + return %extract : tensor<3xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<14x1xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[OFFSET]], 0] {{\[}}%[[SIZE]], 1] [1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size(%src: tensor<14x1xf32>, %offset : index, %size : index) -> tensor<?xf32> { + %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<14x1xf32> into tensor<14xf32> + %extract = tensor.extract_slice %collapse[%offset][%size][1] : tensor<14xf32> to tensor<?xf32> + return %extract : tensor<?xf32> +} + +// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups( +// CHECK-SAME: %[[SRC:.*]]: tensor<5x10x1x1x40xf32>, +// CHECK-SAME: %[[OFFSET:.*]]: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<20x?xf32> { +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 0, 0, %[[OFFSET]]] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1] +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3, 4]] +// CHECK: return %[[COLLAPSE]] +func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(%src: tensor<5x10x1x1x40xf32>, %offset : index, %size : index) -> tensor<20x?xf32> { + %collapse = tensor.collapse_shape %src [[0, 1], [2, 3, 4]] : tensor<5x10x1x1x40xf32> into tensor<50x40xf32> + %extract = tensor.extract_slice %collapse[10, %offset][20, %size][1, 1] : tensor<50x40xf32> to tensor<20x?xf32> + return %extract : tensor<20x?xf32> +} + +/// The 2 following tests are cases where the bubble up cannot occur because the contiguous size extracted +/// from the collapsed shape cannot be expressed via a single extract_slice op. +/// In the first test it is because the size extracted cannot be expressed as a slice +/// of the form [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] (see the pattern documentation for more details). +/// In the second test, the size can be expressed as the required form, but the offset is such that the pattern +/// cannot be applied. + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1( +// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<15xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(%src: tensor<2x3x10xf32>) -> tensor<15xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[0][15][1] : tensor<60xf32> to tensor<15xf32> + return %extract : tensor<15xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2( +// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<20xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(%src: tensor<2x3x10xf32>) -> tensor<20xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[20][20][1] : tensor<60xf32> to tensor<20xf32> + return %extract : tensor<20xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride( +// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<5xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride(%src: tensor<2x3x10xf32>) -> tensor<5xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32> + %extract = tensor.extract_slice %collapse[0][5][2] : tensor<60xf32> to tensor<5xf32> + return %extract : tensor<5xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing( +// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2x1xf32>) -> tensor<1xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing(%src: tensor<6x5x2x1xf32>) -> tensor<1xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<6x5x2x1xf32> into tensor<60x1xf32> + %extract = tensor.extract_slice %collapse[0, 0][1, 1][1, 1] : tensor<60x1xf32> to tensor<1xf32> + return %extract : tensor<1xf32> +} + +// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x2xf32>, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> { +// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape +// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice +func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic(%src: tensor<1x5x2xf32>, %size : index) -> tensor<?xf32> { + %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x2xf32> into tensor<10xf32> + %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<10xf32> to tensor<?xf32> + return %extract : tensor<?xf32> +} + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> |