//===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements rewrites that replace slices of reshape results with // aggregated slices of the reshape source. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; using namespace mlir::affine; using namespace mlir::tensor; /// A tuple that represents (dimension number, dimension value). using DimAndIndex = std::tuple; /// Transform `dimAndIndex` from the output index space of a (non-rank-reducing) /// slice described by `sliceParams` into the input index space. static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc, ArrayRef sliceParams, const DimAndIndex &dimAndIndex) { AffineExpr d0, s0, s1; bindDims(b.getContext(), d0); bindSymbols(b.getContext(), s0, s1); auto [dim, indexValue] = dimAndIndex; assert(dim < sliceParams.size() && "slice should be non rank-reducing"); return std::make_pair( dim, affine::makeComposedAffineApply( b, loc, s0 + d0 * s1, {indexValue, sliceParams[dim].offset, sliceParams[dim].stride})); } /// Transform `dimAndIndex` from the result tensor index space of a /// CollapseShapeOp to the source tensor index space. static ValueRange invertCollapseShapeIndexing( OpBuilder &b, Location loc, ArrayRef reassociation, ArrayRef reshapeSourceShape, const DimAndIndex &dimAndIndex) { const auto &[dim, indexValue] = dimAndIndex; SmallVector basis; for (int64_t i : reassociation[dim]) basis.push_back(reshapeSourceShape[i]); auto delinearized = AffineDelinearizeIndexOp::create(b, loc, indexValue, basis); return delinearized->getResults(); } FailureOr tensor::ExtractSliceFromCollapseHelper::create( OpBuilder &b, tensor::CollapseShapeOp collapseOp, tensor::ExtractSliceOp extractOp) { if (extractOp.getSource().getDefiningOp() != collapseOp) return failure(); SmallVector ranges; ranges.reserve(extractOp.getSourceType().getRank()); for (const auto &[o, s, st] : llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(), extractOp.getMixedStrides())) { ranges.push_back({o, s, st}); } return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges); } FailureOr tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef sliceParams) { // Don't perform this pattern if the collapse op can be simplified by // a rank-reducing extract slice. if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo( op.getSrcType(), op.getReassociationIndices()))) return failure(); // Materialize the output shape of the collapse_shape operation. This will // create IR describing the output shape in terms of the input shape. ReifiedRankedShapedTypeDims reifiedShapes; if (failed(reifyResultShapes(b, op, reifiedShapes))) return failure(); SmallVector &collapseShapeOutputShape = reifiedShapes[0]; SmallVector reassociationIndices = op.getReassociationIndices(); // Determine which of the CollapseShapeOp's result dimensions are sliced // and/or linearized. llvm::SmallBitVector linearizedDimensions = getLinearizedDimensions(reassociationIndices); llvm::SmallBitVector slicedDimensions = getSlicedDimensions(collapseShapeOutputShape, sliceParams); auto collapseShapeInputShape = tensor::getMixedSizes(b, op.getLoc(), op.getSrc()); SmallVector tileSizes; for (unsigned i = 0; i < sliceParams.size(); i++) { if (slicedDimensions[i] && linearizedDimensions[i]) tileSizes.push_back( getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size)); } return ExtractSliceFromCollapseHelper( op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams, linearizedDimensions, slicedDimensions, tileSizes); } std::pair> tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody( OpBuilder &builder, Location loc, ValueRange tileInductionVars) { // Create the helper class for forming the slice parameters. const SmallVector reassociationIndices = collapseShapeOp.getReassociationIndices(); SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape, collapseShapeOutputShape, sliceParams); // Get the indices of the tiled dims (linearized by the collapse_shape // and sliced by the extract_slice) invert the index spaces // transformations. SmallVector multiIndices; unsigned loopIdx = 0; for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) { if (linearizedDimensions[i] && slicedDimensions[i]) { DimAndIndex tb = invertSliceIndexing(builder, loc, sliceParams, std::make_tuple(i, tileInductionVars[loopIdx++])); multiIndices.push_back(invertCollapseShapeIndexing( builder, loc, reassociationIndices, collapseShapeInputShape, tb)); } } SmallVector extractParams = helper.getExtractSliceParams(builder.getContext(), multiIndices); Value subTileResult = tensor::ExtractSliceOp::create( builder, loc, collapseShapeOp.getSrc(), extractParams); SmallVector insertParams = helper.getInsertSliceParams(builder.getContext(), tileInductionVars); // Collapse the dimensions of the source slice back down. Value collapsedResult = tensor::CollapseShapeOp::create( builder, loc, subTileResult, reassociationIndices); return std::make_pair(collapsedResult, insertParams); } FailureOr tensor::simplifyCollapseShapeWithRankReducingExtractSlice( tensor::CollapseShapeOp op, RewriterBase &rewriter) { SmallVector reassociationIndices = op.getReassociationIndices(); RankedTensorType sourceType = op.getSrcType(); FailureOr info = getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType, reassociationIndices); if (failed(info)) return failure(); // Create the rank-reducing extract slice op. auto zero = rewriter.getIndexAttr(0); auto one = rewriter.getIndexAttr(1); SmallVector offsets(sourceType.getRank(), zero); SmallVector sizes = tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc()); SmallVector strides(sourceType.getRank(), one); auto sliceOp = tensor::ExtractSliceOp::create( rewriter, op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes, strides); if (!info->newReassociationIndices.has_value()) { rewriter.replaceOp(op, sliceOp.getResult()); return sliceOp.getOperation(); } return rewriter .replaceOpWithNewOp( op, sliceOp.getResult(), *info->newReassociationIndices) .getOperation(); }