//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel each other out.
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticOffsets(),
        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold collapse_shape ops which only remove static dimensions of size `1`
/// into extract_slice.
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract slice has multiple
    // users, and it is not necessarily beneficial to only convert the
    // collapse into another extract slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away simple collapses where all removed dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
        rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel each other out.
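/// Example (illustrative; shapes chosen arbitrarily, not taken from a test):
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1]]
///     tensor<1x16xf32> to tensor<16xf32>
/// %insert = tensor.insert_slice %collapse into %dest[0, 0] [1, 16] [1, 1]
///     tensor<16xf32> into tensor<8x16xf32>
///
/// AFTER:
/// %insert = tensor.insert_slice %src into %dest[0, 0] [1, 16] [1, 1]
///     tensor<1x16xf32> into tensor<8x16xf32>
/// ```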
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold expand_shape ops which only add static dimensions of size `1`
/// into insert_slice.
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away simple expansions where all added dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
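/// Example (illustrative; shapes chosen arbitrarily, not taken from a test):
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %in [[0, 1], [2]]
///     tensor<4x8x2xf32> to tensor<32x2xf32>
/// %expand = tensor.expand_shape %collapse [[0], [1, 2]]
///     tensor<32x2xf32> to tensor<32x1x2xf32>
///
/// AFTER:
/// %expand = tensor.expand_shape %in [[0], [1], [2, 3]]
///     tensor<4x8x2xf32> to tensor<4x8x1x2xf32>
/// %collapse = tensor.collapse_shape %expand [[0, 1], [2], [3]]
///     tensor<4x8x1x2xf32> to tensor<32x1x2xf32>
/// ```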
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // Special case where the collapsed tensor to expand is a 0-D tensor:
    // the reassociation maps will be empty and not produce valid results.
    if (expandReInds.size() == 0) {
      return failure();
    }

    // The reshapes are parallel to each other (by construction the number of
    // reassociations specified in the collapse and expand are the same) if,
    // at every position, either
    // 1. the reassociation indices are of the same size, or
    // 2. the reassociation in the collapse or the expand is of size 1.
    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() == expandReassociation.size()) {
        // Even if the reassociations are the same, the collapse/expand
        // should result in the same dimensions, i.e. 4x8x2 collapsed into 64
        // should be expanded into 4x8x2 again. In presence of dynamic
        // dimensions, "equality" can only be verified when there is at most
        // one dynamic dimension present and all other static dimensions are
        // equal.
        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
            collapseReassociation.front(), collapseReassociation.size());
        int64_t numCollapsedDynamic =
            llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
            expandReassociation.front(), expandReassociation.size());
        int64_t numExpandedDynamic =
            llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
            collapsedStaticShapes != expandedStaticShapes) {
          return failure();
        }
        continue;
      }
      // If the reassociations are not the same, one or the other needs to be
      // of size one.
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> sourceSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
    SmallVector<OpFoldResult> newExpandSizes;

    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
            resultSizeIndex = 0;

    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
      auto &collapseReassociation = collapseReInds[idx];
      auto &expandReassociation = expandReInds[idx];

      // Case 1. The reassociations are the same in the collapse producer
      // and expand consumer. In the swapped expand, each of the final
      // dimensions is kept as is in the expand and the collapse. So, for
      // every element in the `ReassociationIndices` vector, add a new
      // `ReassociationIndices` vector for the swapped expand and collapse
      // (of size 1).
      if (collapseReassociation.size() == expandReassociation.size()) {
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReInds.push_back({newCollapseIndex++});
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
          sourceSizeIndex++;
        }
        continue;
      }

      // Case 2. The `ReassociationIndices` in the collapse is of size > 1
      // (and in the expand is of size == 1). In this case, the original
      // dimensions are preserved on expansion and collapsed subsequently.
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(newCollapseIndex++);
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
        }
        resultSizeIndex++;
        newCollapseReInds.push_back(newCollapseReassociation);
        continue;
      }

      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
      // in the collapse is of size == 1). In this case, the expansion happens
      // first and the expanded dimensions are preserved on collapse.
      ReassociationIndices newExpandReassociation;
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(newExpandIndex++);
        newCollapseReInds.push_back({newCollapseIndex++});
        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      sourceSizeIndex++;
    }

    // Swap reshape order.
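    // For instance (illustrative, continuing the example from the doc
    // comment above): for collapse [[0, 1], [2]] on tensor<4x8x2xf32>
    // followed by expand [[0], [1, 2]] to tensor<32x1x2xf32>, the loop above
    // produces newExpandReInds = [[0], [1], [2, 3]] with newExpandSizes =
    // [4, 8, 1, 2] and newCollapseReInds = [[0, 1], [2], [3]].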
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    Value newCollapseSrc = collapseOp.getSrc();
    // If the number of reassociation indices in the new `expand_shape` op
    // matches the number of dimensions of the result, then the expand_shape
    // is a no-op.
    if (newExpandReInds.size() != newExpandSizes.size()) {
      newCollapseSrc = tensor::ExpandShapeOp::create(
          rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
          newExpandSizes);
    }

    // If the number of reassociation indices in the new `collapse_shape` op
    // matches the number of dimensions of the source, then the
    // collapse_shape is a no-op.
    Value replacement = newCollapseSrc;
    if (newCollapseReInds.size() != newExpandSizes.size()) {
      replacement = tensor::CollapseShapeOp::create(
          rewriter, loc, newCollapseSrc, newCollapseReInds);
    }
    rewriter.replaceOp(expandOp, replacement);
    return success();
  }
};

/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must be fully
/// contiguous within each reassociation group of the expand_shape. A slice
/// is defined as fully contiguous within a reassociation group if, after
/// flattening the reassociation group to a single 1D range, the slice taken
/// out of the group can be defined as a single contiguous subrange within
/// that range.
///
/// Rank reducing slices are not supported.
///
/// Example:
/// The transformation is possible because each reassociation group has a
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
/// ```
/// BEFORE:
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note - this pattern could be extended to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but it is currently
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
struct BubbleUpExtractSliceThroughExpandShape
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }

    SmallVector<ReassociationIndices> reassociation =
        expandShapeOp.getReassociationIndices();
    SmallVector<OpFoldResult> offsets, sizes, strides;
    if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
                                            offsets, sizes, strides)))
      return failure();

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
    RankedTensorType resultType = sliceOp.getResultType();

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Location loc = sliceOp.getLoc();
    Value newSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), expandedSizes);
    return success();
  }
};

/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
/// `tensor.collapse_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the extraction - after bubbling
/// up - must be representable as a single contiguous slice obtained via
/// tensor.extract_slice within each reassociation group of the src.
///
/// In case the size and offset extracted are static, this is possible if
/// the following conditions are met within each reassociation group:
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be
/// the shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous span of elements if and only if there exists an index k in
/// {0, 1, ..., n} such that:
///   S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///   1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
///                    one dimension),
///   S_i = A_i for all i > k (that is, all trailing dimensions are preserved
///                            in full).
/// In other words, the slice shape S must be of the form:
/// [ 1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An ]
///
/// In case the size and/or offset extracted are dynamic, this is possible
/// only if there is a single dimension in the reassociation group that has
/// a size not equal to 1.
/// In other words, the tensor shape must be of the form:
/// [ 1, 1, ..., 1, A, 1, ..., 1 ]
/// Note - it might be possible to enable this pattern for more cases when
/// the size/offset are dynamic via performing an analysis of the possible
/// values that could be given to the size/offset.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
///
/// Negative example:
/// The transformation is not possible because we cannot use a single slice
/// to represent the reassociation group [2x3x10->???]. If we wanted the
/// collapse to happen after the extraction, we would need to extract
/// multiple slices and concat them together.
/// ```
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32>
///     into tensor<60xf32>
/// %extract = tensor.extract_slice %collapse[0][15][1] :
///     tensor<60xf32> to tensor<15xf32>
/// ```
/// If we wanted the collapse to happen after the extraction, a possible
/// alternative transformation could be to extract multiple slices and
/// concat them together:
/// ```
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
///     tensor<2x3x10xf32> to tensor<1x1x10xf32>
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
///     tensor<2x3x10xf32> to tensor<1x1x5xf32>
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
///     (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
///     to tensor<15xf32>
/// ```
/// But this is not the intended purpose of the transformation.
struct BubbleUpExtractSliceThroughCollapseShape
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "tensor.extract_slice source not produced by tensor.collapse_shape");
    }

    SmallVector<OpFoldResult> offsets, sizes, strides;
    if (failed(getExpandedExtractSliceInfo(
            rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
            collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
      return failure();

    Value newSliceOp = tensor::ExtractSliceOp::create(
        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
        offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        sliceOp, sliceOp.getResultType(), newSliceOp,
        collapseShapeOp.getReassociationIndices());
    return success();
  }
};
} // namespace
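
// Illustrative example for getCollapsedExtractSliceInfo (not from a test):
// for a slice of tensor<8x16xf32> (the expansion of tensor<128xf32> via
// [[0, 1]]) with offsets [3, 0] and sizes [2, 16], the collapsed slice has
// offset 3 * 16 + 0 = 48 and size 2 * 16 = 32 on the tensor<128xf32> source.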
LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
    OpBuilder &b, tensor::ExtractSliceOp sliceOp,
    ArrayRef<ReassociationIndices> reassociation,
    SmallVectorImpl<OpFoldResult> &collapsedOffsets,
    SmallVectorImpl<OpFoldResult> &collapsedSizes,
    SmallVectorImpl<OpFoldResult> &collapsedStrides) {
  if (!sliceOp.hasUnitStride()) {
    return failure();
  }

  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
    return failure();
  }

  auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
                                     OpFoldResult sliceSize,
                                     int64_t inputDim) {
    if (!isZeroInteger(offset))
      return false;
    ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(),
                                                 inputDim);
    FailureOr<bool> maybeEqual =
        ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
    return llvm::succeeded(maybeEqual) && maybeEqual.value();
  };

  // Check that the slice is contiguous within each reassociation group.
  // The slice is contiguous only if, after the first dimension where a
  // non-unit slice is taken, the slice size on all subsequent dimensions of
  // the group is equal to the entire size of the dimension.
  // Examples of contiguous slices:
  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
  //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
  // Examples of non-contiguous slices:
  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
  //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
  for (const ReassociationIndices &indices : reassociation) {
    int64_t i = 0;
    int64_t e = indices.size();
    // Find the first expanded dim after the first dim with a non-unit
    // extracted size.
    for (; i < e; ++i) {
      if (!isOneInteger(sizes[indices[i]])) {
        // +1 to skip the first non-unit size dim.
        i++;
        break;
      }
    }

    // Verify that all subsequent dimensions extract the full size of the
    // source tensor.
    for (; i < e; ++i) {
      int64_t expandedDim = indices[i];
      if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                   expandedDim)) {
        return failure();
      }
    }
  }

  // The tensor.extract_slice before applying the pattern works on the result
  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
  // referring to the state before applying the pattern are named with the
  // prefix "expanded", and ones referring to the state after applying the
  // pattern are named with the prefix "collapsed".
  Location loc = sliceOp.getLoc();
  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
  SmallVector<OpFoldResult> expandedShape =
      getMixedSizes(b, loc, sliceOp.getSource());

  // Helper variables and function for accumulating the size values.
  AffineExpr d0, d1, d2;
  bindDims(b.getContext(), d0, d1, d2);
  // Multiply two integers.
  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
    auto mulMap = AffineMap::get(2, 0, {d0 * d1});
    return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
  };

  // Compute new offsets, sizes, and strides for tensor.extract_slice.
  // The new tensor.extract_slice will work on a tensor that has a rank of
  // ReassociationIndices.size(). In the loop, a single offset, size, and
  // stride value is computed per reassociation group.
  for (const ReassociationIndices &indices : reassociation) {
    // collapsedSize will hold the size of the single dim that represents the
    // reassociation group in the non-expanded tensor.
    OpFoldResult collapsedSize = b.getIndexAttr(1);
    // reassocGroupSizes and reassocGroupOffsets are used to create an
    // affine.linearize_index op to linearize the single offset value required
    // for this reassociation group.
    SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;

    for (long expandedDim : indices) {
      // reassocGroupSizes and reassocGroupOffsets can be obtained directly
      // from the expanded state, but the collapsed size requires calculation
      // as it did not previously exist.
      reassocGroupSizes.push_back(expandedShape[expandedDim]);
      reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
      collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
    }

    SmallVector<Value> offsetVals =
        llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
          return getValueOrCreateConstantIndexOp(b, loc, ofr);
        });
    OpFoldResult collapsedOffset =
        affine::AffineLinearizeIndexOp::create(b, loc, offsetVals,
                                               reassocGroupSizes,
                                               /*disjoint=*/true)
            .getResult();
    collapsedOffsets.push_back(collapsedOffset);
    collapsedSizes.push_back(collapsedSize);

    // Only unit stride is supported.
    collapsedStrides.push_back(b.getIndexAttr(1));
  }
  return success();
}

LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
    OpBuilder &b, tensor::ExtractSliceOp sliceOp,
    ArrayRef<ReassociationIndices> reassociation,
    ArrayRef<int64_t> expandedShape,
    SmallVectorImpl<OpFoldResult> &expandedOffsets,
    SmallVectorImpl<OpFoldResult> &expandedSizes,
    SmallVectorImpl<OpFoldResult> &expandedStrides) {
  if (!sliceOp.hasUnitStride()) {
    return failure();
  }

  // The tensor.extract_slice before applying the pattern works on the result
  // of the tensor.collapse_shape, so variables (i.e. inputs for
  // ExtractSliceOp) referring to the state before applying the pattern are
  // named with the prefix "collapsed", and ones referring to the state after
  // applying the pattern are named with the prefix "expanded".
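  // Illustrative example (not from a test): for %src : tensor<8x16x20xf32>
  // collapsed via [[0, 1], [2]] into tensor<128x20xf32>, a slice with
  // offsets [0, 0] and sizes [32, 10] expands to offsets [0, 0, 0] and
  // sizes [2, 16, 10].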
  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
      collapsedSizes.size()) {
    return failure();
  }

  // Compute new offsets, sizes, and strides for tensor.extract_slice.
  // The new tensor.extract_slice will work on a tensor that has a rank equal
  // to the rank of the src of the collapse_shape. In each iteration of the
  // loop, the offsets and sizes will be computed per reassociation group.
  expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
       llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
    // CASE #1 - size and/or offset are dynamic.
    // In this case, the slice can be represented as a contiguous slice only
    // if there is a single dimension in the reassociation group that has a
    // size not equal to 1.
    if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
      int nonUnitSizeCount = 0;
      for (int64_t expandedShapeIdx : reassocIndices) {
        if (expandedShape[expandedShapeIdx] != 1) {
          nonUnitSizeCount++;
          expandedSizes.push_back(collapsedSize);
          expandedOffsets.push_back(collapsedOffset);
          continue;
        }
        expandedSizes.push_back(b.getIndexAttr(1));
        expandedOffsets.push_back(b.getIndexAttr(0));
      }
      if (nonUnitSizeCount != 1) {
        return failure();
      }
      continue;
    }

    // CASE #2 - size and offset are static.
    // Verify that the slice can be represented as a contiguous slice of the
    // src of the collapse_shape.
    // Checking this is done in order of most internal dimensions first, so
    // traversal is done in reverse order of the reassociation group.
    // If the expected slice shape is [1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An],
    // we first find the size and offset for n...k+1, then for k, and then
    // for k-1...0.

    // currentCollapsedSize and currentCollapsedOffset are initialized with
    // the original collapsed size and offset and divided by the expanded
    // shape size in each dimension as we go along the reassociation group.
    // In essence we are spreading the original collapsed size and offset
    // over the various expanded slice dimensions. The variables are used
    // both to check the validity of the slice and to compute the expanded
    // sizes and offsets.
    int64_t currentCollapsedSize = getConstantIntValue(collapsedSize).value();
    int64_t currentCollapsedOffset =
        getConstantIntValue(collapsedOffset).value();
    SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
    ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
                                                reassocIndices.rend());
    int64_t idx = 0;
    int64_t reassocGroupSize = reassocIndices.size();

    // First handle the trailing dimensions where the slice size should be
    // equal to the tensor shape and the offset should be 0 (n...k+1).
    for (; idx < reassocGroupSize; ++idx) {
      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];

      if (currentCollapsedSize < expandedShapeSize)
        break;

      // We need to make sure that the slice size can be set to the shape
      // size and the offset to 0.
      if ((currentCollapsedSize % expandedShapeSize) != 0 ||
          (currentCollapsedOffset % expandedShapeSize) != 0) {
        return failure();
      }

      groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
      groupExpandedOffsets.push_back(b.getIndexAttr(0));

      currentCollapsedSize /= expandedShapeSize;
      currentCollapsedOffset /= expandedShapeSize;
    }

    // Now handle the first dim where slicing occurs on (k).
    if (idx < reassocGroupSize) {
      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
      int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
      // We need to make sure that the slice size in this dim + offset will
      // not exceed the shape size.
      if ((currentCollapsedSize + offsetInDim) >= expandedShapeSize) {
        return failure();
      }

      groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedSize));
      groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));

      currentCollapsedOffset /= expandedShapeSize;
    }

    // Now handle the leading dimensions where the slice size is equal to 1
    // (k-1...0).
    // The size for these dimensions must be 1 because of how we constructed
    // the slice size of the expanded shape. We spread the original collapsed
    // size over the expanded shape sizes until we reached dimension k, where
    // the remaining size was smaller than the expanded shape size, and
    // spread the remaining size on it. So, now we are left with only 1s.
    for (idx++; idx < reassocGroupSize; ++idx) {
      int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
      int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
      groupExpandedSizes.push_back(b.getIndexAttr(1));
      groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
      currentCollapsedOffset /= expandedShapeSize;
    }

    expandedSizes.append(groupExpandedSizes.rbegin(),
                         groupExpandedSizes.rend());
    expandedOffsets.append(groupExpandedOffsets.rbegin(),
                           groupExpandedOffsets.rend());
  }
  return success();
}

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExtractSliceThroughExpandShape,
               BubbleUpExtractSliceThroughCollapseShape>(
      patterns.getContext());
}