//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the tiling using TilingInterface. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "tile-using-interface" using namespace mlir; scf::SCFTilingOptions & scf::SCFTilingOptions::setTileSizes(ArrayRef ts) { assert(!tileSizeComputationFunction && "tile sizes already set"); auto tileSizes = llvm::to_vector(ts); tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { return tileSizes; }; return *this; } scf::SCFTilingOptions & scf::SCFTilingOptions::setNumThreads(ArrayRef nt) { assert(!numThreadsComputationFunction && "num tiles already set"); auto numThreads = llvm::to_vector(nt); numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) { return numThreads; }; return *this; } /// Helper method to adjust the interchange vector to match the iteration /// domain. 
static SmallVector fillInterchangeVector(ArrayRef interchangeVector, size_t iterationDomainSize) { SmallVector filledVector = llvm::to_vector(interchangeVector); if (filledVector.size() < iterationDomainSize) { auto range = llvm::seq(filledVector.size(), iterationDomainSize); filledVector.append(range.begin(), range.end()); } if (filledVector.size() > iterationDomainSize) filledVector.resize(iterationDomainSize); return filledVector; } //===----------------------------------------------------------------------===// // tileUsingSCF implementation. //===----------------------------------------------------------------------===// /// Verify the tile size options are set in a consistent manner. static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options) { // Specifying number of threads is only supported on `scf.forall` op. if (options.numThreadsComputationFunction && options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) { return rewriter.notifyMatchFailure( loc, "number of threads can only by specified when loop type is " "set to use `scf.forall`"); } // If specified, check that the interchange vector is a permutation. if (!options.interchangeVector.empty()) { if (!isPermutationVector(options.interchangeVector)) { return rewriter.notifyMatchFailure( loc, "invalid interchange vector, not a permutation of the entire " "iteration space"); } } return success(); } /// Method to instantiate the tile sizes and/or number of threads specified /// by the user. static std::tuple, SmallVector> getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef iterationDomain, const scf::SCFTilingOptions &options) { OpFoldResult zero = rewriter.getIndexAttr(0); SmallVector tileSizes, numThreads; size_t numLoops = iterationDomain.size(); // Check whether the number of tiles to use is specified. 
if (options.numThreadsComputationFunction) { numThreads = options.numThreadsComputationFunction(rewriter, op); numThreads.resize(numLoops, zero); // If the number of tiles is also specified, use that. if (options.tileSizeComputationFunction) { tileSizes = options.tileSizeComputationFunction(rewriter, op); tileSizes.resize(numLoops, zero); return {tileSizes, numThreads}; } // Compute the tile sizes from the iteration domain and number // of tiles as follows // - niters = ceilDiv(ub - lb, step) // - tileSize = ceilDiv(niters, numThreads) AffineExpr s0, s1, s2; bindSymbols(rewriter.getContext(), s0, s1, s2); // TODO: The step here is assumed to be 1. AffineExpr numItersExpr = (s1 - s0); AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2); tileSizes.resize(numLoops, zero); for (auto [index, range, nt] : llvm::enumerate(iterationDomain, numThreads)) { if (isZeroInteger(nt)) continue; tileSizes[index] = affine::makeComposedFoldedAffineApply( rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt}); } tileSizes.resize(numLoops, zero); return {tileSizes, numThreads}; } // Enforce the convention that "tiling by zero" // skips tiling a particular dimension. This convention is significantly // simpler to handle instead of adjusting affine maps to account for missing // dimensions. assert(options.tileSizeComputationFunction && "expected tile sizes to be specified"); tileSizes = options.tileSizeComputationFunction(rewriter, op); tileSizes.resize(numLoops, zero); return {tileSizes, numThreads}; } /// Checks if any of the tiled loops are not parallel. 
static LogicalResult checkTileSizes(TilingInterface op,
                                    scf::SCFTilingOptions::LoopType loopType,
                                    ReductionTilingStrategy reductionStrategy,
                                    ArrayRef<OpFoldResult> givenTileSizes,
                                    ArrayRef<OpFoldResult> numThreads) {
  auto iterators = op.getLoopIteratorTypes();
  assert(iterators.size() == givenTileSizes.size() &&
         "expected as many tile size values as number of loops");
  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
         "when specified, expected number of threads to use for each loop");

  // Thread-safety warnings are only emitted for `scf.forall` with the full
  // reduction strategy.
  const bool warnOnUnsafeAxes =
      loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
      reductionStrategy == ReductionTilingStrategy::FullReduction;

  bool tilesParallelDim = false;
  for (auto [index, iterator, givenTileSize] :
       llvm::enumerate(iterators, givenTileSizes)) {
    const bool isParallel = iterator == utils::IteratorType::parallel;

    // Track whether any parallel loop has a tile size that is not the
    // constant zero.
    if (isParallel && !isConstantIntValue(givenTileSize, 0))
      tilesParallelDim = true;

    // Parallel axes never produce a warning.
    if (!warnOnUnsafeAxes || isParallel)
      continue;

    // Non-parallel axis: warn when it is distributed. If num threads is
    // specified the axis is distributed when the constant thread count is
    // greater than one; otherwise when the constant tile size is positive.
    // Non-constant values are not diagnosed.
    const bool useNumThreads = !numThreads.empty();
    std::optional<int64_t> constValue =
        useNumThreads ? getConstantIntValue(numThreads[index])
                      : getConstantIntValue(givenTileSize);
    const int64_t threshold = useNumThreads ? 1 : 0;
    if (constValue && *constValue > threshold)
      op.emitWarning() << "tiling is not thread safe at axis #" << index;
  }

  // The partial reduction strategies do not support tiling parallel loops.
  if (reductionStrategy != ReductionTilingStrategy::FullReduction &&
      tilesParallelDim) {
    return op->emitOpError("tiling parallel dimensions is not supported with "
                           "partial reduction tiling strategies");
  }
  return success();
}

/// Get the reduction dims that are tiled. This accounts for reduction dims
/// that are specified as tiled, but the tile size is 0.
static SetVector getSanitizedReductionDims(ArrayRef givenTileSizes, const scf::SCFTilingOptions &options) { SetVector reductionDims; for (auto dim : options.reductionDims) { if (isConstantIntValue(givenTileSizes[dim], 0)) continue; reductionDims.insert(dim); } return reductionDims; } /// Check if `stride` evenly divides the trip count `size - offset`. static bool tileDividesIterationDomain(Range loopRange) { std::optional offsetAsInt = getConstantIntValue(loopRange.offset); if (!offsetAsInt) return false; std::optional sizeAsInt = getConstantIntValue(loopRange.size); if (!sizeAsInt) return false; std::optional strideAsInt = getConstantIntValue(loopRange.stride); if (!strideAsInt) return false; return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); } /// Returns the bounded tile size given the current `offset`, `loopRange` and /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`. static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult givenTileSize) { std::optional ts = getConstantIntValue(givenTileSize); if (ts && ts.value() == 1) return givenTileSize; if (tileDividesIterationDomain( Range{loopRange.offset, loopRange.size, givenTileSize})) return givenTileSize; // The tile size to use (to avoid out of bounds access) is minimum of // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled // loop. AffineExpr s0, s1, d0; bindDims(b.getContext(), d0); bindSymbols(b.getContext(), s0, s1); AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext()); Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); return affine::makeComposedFoldedAffineMin( b, loc, minMap, SmallVector{offset, size, givenTileSize}); } /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less /// than `iterationSize`. 
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize) { std::optional tileSizeConst = getConstantIntValue(givenTileSize); std::optional numThreadsConst = getConstantIntValue(numThreads); std::optional iterSizeConst = getConstantIntValue(iterationSize); if (!tileSizeConst || !numThreadsConst || !iterSizeConst) return false; return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst; } /// Compute the `OpFoldResult`s that represents the multi-dimensional /// `offset`s and `size`s of the tile of the iteration space that the /// innermost loop body of the generated tiled loops corresponds to. static std::tuple, SmallVector> getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef iterationDomain, ArrayRef givenTileSizes) { SmallVector offsets, sizes; int materializedLoopNum = 0; for (auto [givenTileSize, loopRange] : llvm::zip_equal(givenTileSizes, iterationDomain)) { // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. if (isZeroInteger(givenTileSize)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; } Value iv = ivs[materializedLoopNum++]; OpFoldResult offset = getAsOpFoldResult(iv); offsets.push_back(offset); OpFoldResult size = getBoundedTileSize(rewriter, loc, loopRange, offset, givenTileSize); sizes.push_back(size); } return {offsets, sizes}; } /// Function to return the bounds of the loops to be generated. static std::tuple, SmallVector, SmallVector> getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef loopRanges, ArrayRef givenTileSizes) { SmallVector lbs, ubs, steps; for (auto [loopRange, givenTileSize] : llvm::zip_equal(loopRanges, givenTileSizes)) { // No loop if the tile size is 0. 
if (isZeroInteger(givenTileSize)) continue; lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); steps.push_back(givenTileSize); } return {lbs, ubs, steps}; } /// Typedef for function that allows returning additional yielded values during /// `yieldTiledValuesAndReplace`. /// - `ivs` induction variable for the loop. /// - `newBbArgs` basic block arguments corresponding to newly added iter_args. /// - `tiledValues` the tiled values to return. Must be of same size as /// `newbbArgs`, each element of this array is inserted into the corresponding /// element in `newbbArgs`. /// - `resultOffsets` is of the same size as `tiledValues` and represents /// the offsets to use when inserting corresponding element from `tiledValues` /// into the element from `newBbArgs`. /// - `resultSizes` is of the same size as `tiledValues` and represents /// the size of the corresponding element from `tiledValues` inserted into /// the element from `newBbArgs`. /// In case the method needs to return `failure()` the method is expected /// to clean up any inserted operations. using YieldTiledValuesFn = std::function &tiledValues, SmallVector> &resultOffsets, SmallVector> &resultSizes)>; /// Typedef for function that implements the body of a tiled loop. /// - `ivs` induction variable for the loop. /// - `tileOffsets` represents offsets for the tiled iteration space. /// - `tileSizes` represents the sizes for the tiled iteraiton space. /// - `outerDestinationTensors` tensor that holds the result. Is same size /// as the destination operands of the original operations. /// - `tiledResults` results of the tiled computation, corresponds to /// tiles of the original operation computed by the loop body. /// Should be same size as the `destinationTensors` /// - `resultOffsets` is of the same size as `tiledResults` and represents /// the offset to use when writing the corresponding element from /// `tiledResults` into `destinationTensors`. 
/// - `resultOffsets` is of the same size as `tiledResults` and represents /// the size to use when writing the corresponding element from /// `tiledResults` into `destinationTensors`. /// In case the method needs to return `failure()` the method is expected /// to clean up any inserted operations. using GenerateTiledBodyFn = std::function tileOffsets, ArrayRef tileSizes, ValueRange outerDestinationTensors, SmallVector &tiledResults, SmallVector> &resultOffsets, SmallVector> &resultSizes)>; /// Clones the operation and updates the destination if the operation /// implements the `DestinationStyleOpInterface`. static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs) { Operation *clonedOp = rewriter.clone(*op); if (newDestArgs.empty()) return clonedOp; if (auto destinationStyleOp = dyn_cast(clonedOp)) destinationStyleOp.getDpsInitsMutable().assign(newDestArgs); return clonedOp; } /// Generate the tile-loop nest using `scf.for` operation. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops. /// - `outerDestinationTensors` are the init values to use for the outer most /// loop. /// - `tiledBodyFn` is called to generated the loop body of the inner /// most /// loop. /// Returns the generated `scf.for` loops on success. 
static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNestUsingForOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> givenTileSizes, ValueRange outerDestinationTensors,
    GenerateTiledBodyFn tiledBodyFn) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == givenTileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  // Materialize the bounds of the loops to generate (one per non-zero tile
  // size) as SSA values.
  auto [lbs, ubs, steps] =
      getLoopBounds(rewriter, loc, loopRanges, givenTileSizes);
  SmallVector<Value> lbVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
  SmallVector<Value> ubVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
  SmallVector<Value> stepVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, steps);

  // Create the loop nest outermost-first. Each loop is created with an empty
  // body builder (its terminator is added below) and threads the destination
  // tensors of the enclosing loop through its iter_args.
  SmallVector<Value> ivs;
  SmallVector<LoopLikeOpInterface> loops;
  ValueRange innerDestinationTensors(outerDestinationTensors);
  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
    auto loop =
        scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors,
                           [](OpBuilder &bodyBuilder, Location bodyLoc,
                              Value iv, ValueRange /*iterArgs*/) {});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPointToEnd(loop.getBody());
    innerDestinationTensors = loop.getRegionIterArgs();
  }

  // NOTE: there is deliberately no early exit for an empty loop nest here.
  // `return success()` cannot initialize a `FailureOr<T>` holding a value; the
  // empty case is handled after the body has been generated, see below.

  // Compute the `offsets` and `sizes` to use for tiling.
  auto [offsets, sizes] =
      getTileOffsetAndSizes(rewriter, loc, ivs, loopRanges, givenTileSizes);

  // Generate the body of the innermost loop.
  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
                         innerDestinationTensors, tiledResults, resultOffsets,
                         resultSizes))) {
    return rewriter.notifyMatchFailure(
        loc, "failed to generate inner tile loop body");
  }

  // Without loops there is nothing to yield: report success with an empty
  // list of loops.
  if (loops.empty())
    return loops;

  assert(tiledResults.size() == innerDestinationTensors.size() &&
         "Number of results of body should be equal to number of iter args");

  // Yield all the results of the tiled operation from the innermost loop,
  // after inserting each tile into its destination with unit strides.
  SmallVector<Value> yieldedValues;
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    auto insertSlice = tensor::InsertSliceOp::create(
        rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
    yieldedValues.push_back(insertSlice);
  }
  scf::YieldOp::create(rewriter, loc, yieldedValues);

  // Add the scf.yield operations for all the outer loops: each one forwards
  // the results of the loop nested directly inside it.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(MutableArrayRef(loops).drop_back(),
                       MutableArrayRef(loops).drop_front())) {
    rewriter.setInsertionPointToEnd(
        cast<scf::ForOp>(outerLoop.getOperation()).getBody());
    scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
  }
  return loops;
}

/// Compute the `OpFoldResult`s that represents the multi-dimensional
/// `offset`s and `size`s of the tile of the iteration space that the
/// innermost loop body of the generated tiled loops corresponds to
/// when tiling using `forall` op. This is handled separately due to
/// the special case handling needed for when the tiling is done by
/// specifying number of threads.
static std::tuple, SmallVector> getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef iterationDomain, ArrayRef givenTileSizes, ArrayRef numThreads) { if (numThreads.empty()) { return getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain, givenTileSizes); } SmallVector offsets, sizes; int materializedLoopNum = 0; AffineExpr d0, d1, s0, s1; AffineExpr offsetExpr, residualTileSizeExpr; bindDims(rewriter.getContext(), d0, d1); bindSymbols(rewriter.getContext(), s0, s1); offsetExpr = d0 + d1 * s0; residualTileSizeExpr = s1 - (d0 + d1 * s0); for (auto [index, nt, givenTileSize, loopRange] : llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) { // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. if (isZeroInteger(nt)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; } Value iv = ivs[materializedLoopNum++]; OpFoldResult offset = affine::makeComposedFoldedAffineApply( rewriter, loc, offsetExpr, ArrayRef{loopRange.offset, iv, givenTileSize}); OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply( rewriter, loc, residualTileSizeExpr, {loopRange.offset, nt, givenTileSize, loopRange.size}); OpFoldResult size = givenTileSize; if (!isZeroInteger(residualTileSize)) { OpFoldResult sizeMinusOffsetPerThread = affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, {offset, loopRange.size}); size = affine::makeComposedFoldedAffineMin( rewriter, loc, AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()), {sizeMinusOffsetPerThread, givenTileSize}); } // Consider the case where the original loop was `[0, 100)`. // If number of threads are `7`, the tile size would be computed as // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6) // - `offset = 0 + 6 * 15 = 105` // - `tileSize = min(15, 100 - 105) = -5` // To avoid negative tile sizes, we need to do a further // `nonNegativeTileSize = affine.max(0, tileSize)`. 
// This `max` can be avoided if // `offset + tileSize * (numThreads - 1) < (ub - lb)` if (!canOmitTileOffsetInBoundsCheck(givenTileSize, nt, loopRange.size)) { AffineMap maxMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); size = affine::makeComposedFoldedAffineMax( rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size}); } offsets.push_back(offset); sizes.push_back(size); } return {offsets, sizes}; } /// Generate the tile-loop nest using `scf.forall` operation. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `giventileSizes` is the tile sizes to use. Zero represent untiled loops. /// - `outerDestinationTensors` are the init values to use for the loop. /// - `mappingVector` is the mapping attributes to use for loop construction. /// Can be empty. /// - `tiledBodyFn` is called to generated the loop body of the inner /// most /// loop. /// Returns the generated `scf.forall` loop on success. static FailureOr> generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef loopRanges, ArrayRef givenTileSizes, ArrayRef numThreads, ArrayRef mappingVector, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn) { assert(!loopRanges.empty() && "unexpected empty loop ranges"); assert(loopRanges.size() == givenTileSizes.size() && "expected as many tile sizes as loop ranges"); OpBuilder::InsertionGuard guard(rewriter); std::optional mappingAttr; if (!mappingVector.empty()) mappingAttr = rewriter.getArrayAttr(mappingVector); scf::ForallOp forallOp; bool useNumThreads = !numThreads.empty(); SmallVector loops; if (useNumThreads) { // Prune the zero numthreads. 
SmallVector nonZeroNumThreads; for (auto nt : numThreads) { if (isZeroInteger(nt)) continue; nonZeroNumThreads.push_back(nt); } forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads, outerDestinationTensors, mappingAttr); } else { SmallVector lbs, ubs, steps; std::tie(lbs, ubs, steps) = getLoopBounds(rewriter, loc, loopRanges, givenTileSizes); forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps, outerDestinationTensors, mappingAttr); } loops.push_back(forallOp); rewriter.setInsertionPoint(forallOp.getTerminator()); ValueRange innerDestinationTensors = forallOp.getRegionOutArgs(); SmallVector ivs = forallOp.getInductionVars(); // Compute the `offsets` and `sizes` to use for tiling. SmallVector offsets, sizes; std::tie(offsets, sizes) = getTileOffsetAndSizesWithForAllOp( rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads); SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes, innerDestinationTensors, tiledResults, resultOffsets, resultSizes))) return rewriter.notifyMatchFailure(loc, "failed to generate loop body"); rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody()); for (auto [tiledValue, destinationTensor, resultOffset, resultSize] : llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize, resultStride); } return loops; } /// Generate the tile-loop nest using custom loop operation. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. /// - `destinationTensors` are the init values to use for the outer most loop. /// - `mappingVector` is the mapping attributes to use for loop construction. /// Can be empty. 
/// - `tiledBodyFn` is called to generated the loop body of the inner /// most /// loop. /// Returns the generated `scf.forall` loop on success. static FailureOr> generateLoopNestUsingCustomOp( RewriterBase &rewriter, Location loc, ArrayRef loopRanges, ArrayRef givenTileSizes, ValueRange outerDestinationTensors, const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, const scf::SCFTilingOptions::GenerateLoopTerminatorFn &generateLoopTerminatorFn, GenerateTiledBodyFn tiledBodyFn) { assert(!loopRanges.empty() && "unexpected empty loop ranges"); assert(loopRanges.size() == givenTileSizes.size() && "expected as many tile sizes as loop ranges"); assert(generateLoopHeaderFn && generateLoopTerminatorFn && "expected loop header/terminator generation function"); OpBuilder::InsertionGuard guard(rewriter); FailureOr loopHeaderInfo = generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes, outerDestinationTensors); if (failed(loopHeaderInfo)) { return failure(); } SmallVector ivs; SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset, loopHeaderInfo->tileSizes, loopHeaderInfo->destinationTensors, tiledResults, resultOffsets, resultSizes))) { return failure(); } if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults, resultOffsets, resultSizes, loopHeaderInfo->destinationTensors))) { return failure(); } return loopHeaderInfo->loops; } /// Generate the tile-loop nest using the loop construct specifed in `options`. /// - `options`: Tiling options specified. /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops. /// - `outerDestinationTensors` are the init values to use for the outer most /// loop. /// - `yieldTiledValuesFn` is called to generated the loop body of the inner /// most /// loop. /// Returns the generated loops on success. 
static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNest(
    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> givenTileSizes,
    ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
    GenerateTiledBodyFn tiledBodyFn) {
  using LoopType = scf::SCFTilingOptions::LoopType;

  // If the tile sizes are all zero, no loops are generated. Just call the
  // callback function to handle untiled case: the "tile" is the whole
  // iteration domain.
  if (llvm::all_of(givenTileSizes, isZeroInteger)) {
    auto tileOffsets =
        llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; });
    auto tileSizes =
        llvm::map_to_vector(loopRanges, [](Range r) { return r.size; });
    SmallVector<Value> tiledResults;
    SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
    if (failed(tiledBodyFn(rewriter, loc, ValueRange{}, tileOffsets, tileSizes,
                           destinationTensors, tiledResults, resultOffsets,
                           resultSizes))) {
      return failure();
    }
    return SmallVector<LoopLikeOpInterface>{};
  }

  // Otherwise dispatch on the requested loop construct.
  const LoopType loopType = options.loopType;
  if (loopType == LoopType::ForOp) {
    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, givenTileSizes,
                                      destinationTensors, tiledBodyFn);
  }
  if (loopType == LoopType::ForallOp) {
    return generateLoopNestUsingForallOp(
        rewriter, loc, loopRanges, givenTileSizes, numThreads,
        options.mappingVector, destinationTensors, tiledBodyFn);
  }
  if (loopType == LoopType::CustomOp) {
    return generateLoopNestUsingCustomOp(
        rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
        options.generateLoopHeaderFn, options.generateLoopTerminatorFn,
        tiledBodyFn);
  }
  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}

/// Creates the initial tensors the tiled loops write into.
/// - For `FullReduction` these are the destinations obtained from
///   `tensor::getOrCreateDestinations`.
/// - Otherwise `op` must implement `PartialReductionOpInterface`; a size is
///   computed for every loop and handed to
///   `generateInitialTensorForPartialReduction`.
static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
    RewriterBase &rewriter, TilingInterface op,
    ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain,
    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> givenTileSizes,
    const SetVector<unsigned> &reductionDims) {
  SmallVector<Value> initTensors;
  Location loc = op->getLoc();
  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
    if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
      return failure();
    return initTensors;
  }

  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    return op->emitOpError(
        "PartialReductionOuterReduction tiling strategy is only supported for "
        "operations implementing PartialReductionOpInterface");
  }

  AffineExpr s0, s1, s2;
  bindSymbols(rewriter.getContext(), s0, s1, s2);
  // Number of iterations of a range: ceilDiv(size - offset, stride).
  AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
  AffineExpr divExpr = s0.ceilDiv(s1);
  auto getNormalizedRange = [&](const Range &domain) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(
        rewriter, op.getLoc(), sizeExpr,
        {domain.size, domain.offset, domain.stride});
  };

  SmallVector<OpFoldResult> sizes(iterationDomain.size());
  for (auto [index, domain, tileSize] :
       llvm::enumerate(iterationDomain, givenTileSizes)) {
    // Thread-count based tiling: the size is the thread count, except for
    // untiled loops (zero threads) which keep their full extent.
    if (!numThreads.empty()) {
      if (isConstantIntValue(numThreads[index], 0))
        sizes[index] = getNormalizedRange(domain);
      else
        sizes[index] = numThreads[index];
      continue;
    }

    // Non reduction dimensions/non-tiled dimensions keep their full extent.
    if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) {
      sizes[index] = getNormalizedRange(domain);
      continue;
    }

    // Tiled reduction dimension, outer-reduction strategy: one slot per
    // element of a tile.
    if (reductionStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      sizes[index] = tileSize;
      continue;
    }

    // Tiled reduction dimension, outer-parallel strategy: one slot per tile,
    // i.e. ceilDiv(extent, tileSize).
    assert(reductionStrategy ==
           ReductionTilingStrategy::PartialReductionOuterParallel);
    OpFoldResult normalizedRange = getNormalizedRange(domain);
    sizes[index] = affine::makeComposedFoldedAffineApply(
        rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
  }
  return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
                                                        reductionDims);
}

/// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel`
/// the `PartialReductionOpInterface` methods need the index of the parallel
/// split reduction being executed.
static SmallVector<OpFoldResult>
getSplitReductionIvs(RewriterBase &rewriter, Location loc,
                     ReductionTilingStrategy reductionStrategy, ValueRange ivs,
                     ArrayRef<OpFoldResult> numThreads,
                     ArrayRef<OpFoldResult> givenTileSizes,
                     const SetVector<unsigned> &reductionDims) {
  // One entry per reduction dimension, zero unless computed below.
  SmallVector<OpFoldResult> splitReductionIvs;
  splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
  if (reductionStrategy !=
      ReductionTilingStrategy::PartialReductionOuterParallel)
    return splitReductionIvs;

  AffineExpr s0, s1;
  bindSymbols(rewriter.getContext(), s0, s1);
  AffineExpr divExpr = s0.floorDiv(s1);
  int ivIndex = 0;
  for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
    Value iv = ivs[ivIndex++];
    // With a thread count the induction variable already is the split index.
    if (!numThreads.empty()) {
      splitReductionIvs[index] = iv;
      continue;
    }
    // Otherwise the induction variable is an offset; dividing it by the tile
    // size yields the index of the tile.
    splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, divExpr,
        ArrayRef<OpFoldResult>{iv, givenTileSizes[reductionDim]});
  }
  return splitReductionIvs;
}

/// Returns the tiled implementation of `op` for the tile described by
/// `offsets` and `sizes`. `FullReduction` uses
/// `TilingInterface::getTiledImplementation`; the partial reduction
/// strategies use `PartialReductionOpInterface::tileToPartialReduction` and
/// report a match failure when `op` does not implement that interface.
static FailureOr<TilingResult>
getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                       ReductionTilingStrategy reductionStrategy,
                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
                       ArrayRef<OpFoldResult> sizes, ValueRange ivs,
                       ArrayRef<OpFoldResult> numThreads,
                       ArrayRef<OpFoldResult> givenTileSizes,
                       const SetVector<unsigned> &reductionDims) {
  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
    return op.getTiledImplementation(rewriter, offsets, sizes);
  }

  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    return rewriter.notifyMatchFailure(
        op, "PartialReductionOuterReduction tiling strategy is only "
            "supported for operations "
            "implementing PartialReductionOpInterface");
  }

  SmallVector<OpFoldResult> splitReductionIvs =
      getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
                           numThreads, givenTileSizes, reductionDims);
  return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
                                      regionIterArg, offsets, sizes,
                                      reductionDims, splitReductionIvs);
}

/// Computes into `resultOffset`/`resultSize` where the tile of result number
/// `index` is located. `FullReduction` uses
/// `TilingInterface::getResultTilePosition`; the partial reduction strategies
/// use `PartialReductionOpInterface::getPartialResultTilePosition` and report
/// a match failure when `op` does not implement that interface.
static LogicalResult getResultTilePosition(
    RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy,
    int64_t index, Value tiledResult, TilingInterface op,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ValueRange ivs, ArrayRef<OpFoldResult> numThreads,
    ArrayRef<OpFoldResult> givenTileSizes,
    const SetVector<unsigned> &reductionDims,
    SmallVector<OpFoldResult> &resultOffset,
    SmallVector<OpFoldResult> &resultSize) {
  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
    return op.getResultTilePosition(rewriter, index, offsets, sizes,
                                    resultOffset, resultSize);
  }

  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    // The two literals are concatenated: the trailing space after
    // "supported" is required to keep the words apart in the diagnostic.
    return rewriter.notifyMatchFailure(
        op, "PartialReductionOuterReduction tiling strategy is only supported "
            "for operations implementing PartialReductionOpInterface");
  }

  SmallVector<OpFoldResult> splitReductionIvs =
      getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
                           numThreads, givenTileSizes, reductionDims);
  return redOp.getPartialResultTilePosition(
      rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
      splitReductionIvs, resultOffset, resultSize);
}

/// Merges the `partialResults` of a partial reduction tiling through
/// `PartialReductionOpInterface::mergeReductions`. Must not be called for
/// `FullReduction`; reports a match failure when `op` does not implement the
/// interface.
static FailureOr<MergeResult>
mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
                   ReductionTilingStrategy reductionStrategy,
                   const SetVector<unsigned> &reductionDims,
                   ValueRange partialResults) {
  assert(reductionStrategy != ReductionTilingStrategy::FullReduction &&
         "expected merge to be called for only partial reduction cases");

  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    return rewriter.notifyMatchFailure(
        op, "PartialReductionOuterReduction tiling strategy is only "
            "supported for operations "
            "implementing PartialReductionOpInterface");
  }
  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
                               reductionDims);
}

/// Append the specified additional `newInitOperands` operands to the
/// loops existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
/// this loop is moved over to the new loop. `yieldTiledValuesFn`
/// is called to get the new tiled values returned, and the offset
/// and sizes at which the tiled value is inserted into the
/// new region iter_args that correspond to the newly added init operands.
template FailureOr yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); } /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`. template <> FailureOr yieldTiledValuesAndReplaceLoop( scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { OpBuilder::InsertionGuard g(rewriter); Location loc = loopOp.getLoc(); rewriter.setInsertionPoint(loopOp); auto inits = llvm::to_vector(loopOp.getInitArgs()); inits.append(newInitOperands.begin(), newInitOperands.end()); auto newLoop = scf::ForOp::create( rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}, loopOp.getUnsignedCmp()); // Move the loop body to the new op. Block *loopBody = loopOp.getBody(); Block *newLoopBody = newLoop.getBody(); rewriter.mergeBlocks( loopBody, newLoopBody, newLoopBody->getArguments().take_front(loopBody->getNumArguments())); auto yieldOp = cast(newLoopBody->getTerminator()); rewriter.setInsertionPoint(yieldOp); SmallVector tiledValues; SmallVector> resultOffsets, resultSizes; ValueRange newRegionIterArgs = newLoop.getRegionIterArgs().take_back(newInitOperands.size()); if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(), newRegionIterArgs, tiledValues, resultOffsets, resultSizes))) { rewriter.eraseOp(newLoop); return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values"); } SmallVector newYieldValues = llvm::to_vector(yieldOp.getOperands()); for (auto [tiledValue, regionIterArg, resultOffset, resultSize] : llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); Value insert = tensor::InsertSliceOp::create( rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, 
resultOffset, resultSize, resultStride); newYieldValues.push_back(insert); } rewriter.replaceOpWithNewOp(yieldOp, newYieldValues); rewriter.replaceOp(loopOp, newLoop->getResults().take_front(loopOp.getNumResults())); return cast(newLoop.getOperation()); } /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall` template <> FailureOr yieldTiledValuesAndReplaceLoop( scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { OpBuilder::InsertionGuard g(rewriter); Location loc = loopOp.getLoc(); rewriter.setInsertionPoint(loopOp); auto inits = llvm::to_vector(loopOp.getOutputs()); inits.append(newInitOperands.begin(), newInitOperands.end()); auto newLoop = scf::ForallOp::create( rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), loopOp.getMixedStep(), inits, loopOp.getMapping(), [](OpBuilder &, Location, ValueRange) {}); // Move the region of the current block to the newly created op. Block *loopBody = loopOp.getBody(); Block *newLoopBody = newLoop.getBody(); rewriter.mergeBlocks( loopBody, newLoopBody, newLoopBody->getArguments().take_front(loopBody->getNumArguments())); auto terminator = cast(newLoopBody->getTerminator()); rewriter.setInsertionPoint(terminator); SmallVector tiledValues; SmallVector> resultOffsets, resultSizes; ValueRange regionIterArgs = newLoop.getRegionIterArgs().take_back(newInitOperands.size()); if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(), regionIterArgs, tiledValues, resultOffsets, resultSizes))) { rewriter.eraseOp(newLoop); return rewriter.notifyMatchFailure(loopOp, "failed to get yielded tiled values"); } // Update the terminator. 
rewriter.setInsertionPointToEnd(terminator.getBody()); for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal( tiledValues, regionIterArgs, resultOffsets, resultSizes)) { SmallVector resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, resultStride); } rewriter.replaceOp(loopOp, newLoop->getResults().take_front(loopOp.getNumResults())); return cast(newLoop.getOperation()); } /// Implementation of `yieldTiledValuesAndReplaceLoop` for /// `LoopLikeOpInterface`, that just dispatches to the implementation for each /// supported loop type. FailureOr yieldTiledValuesAndReplaceLoop( LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) { return TypeSwitch>( loopLikeOp.getOperation()) .Case( [&](auto loopOp) -> FailureOr { return yieldTiledValuesAndReplaceLoop( loopOp, rewriter, newInitOperands, yieldTiledValuesFn); }) .Default([&](auto loopOp) -> FailureOr { return rewriter.notifyMatchFailure(loopOp, "unhandled loop type"); }); } /// Method to add new init values to a loop nest. Updates `loops` in-place /// with new loops that use the `newInitValues`. The outer-loops are updated /// to yield the new result values of the inner loop. For the innermost loop, /// the call back `getNewYields` is invoked to get the additional values to /// yield form the innermost loop. static LogicalResult addInitOperandsToLoopNest( RewriterBase &rewriter, MutableArrayRef loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) { if (loops.empty()) return success(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(loops.front()); SmallVector ivs; for (auto &loop : loops.drop_back()) { rewriter.setInsertionPoint(loop); // if loops.size() > 1 we assume that scf.for is used for the loops. 
auto forLoop = cast(loop.getOperation()); // Create a new loop with the new init values for this loop. SmallVector newInits = llvm::to_vector(forLoop.getInitArgs()); newInits.append(newInitValues.begin(), newInitValues.end()); auto newLoop = scf::ForOp::create( rewriter, forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), forLoop.getStep(), newInits, [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}, forLoop.getUnsignedCmp()); // Merge the body of the new loop with the body of the old loops. SmallVector sourceBlockArgs; sourceBlockArgs.push_back(newLoop.getInductionVar()); auto newRegionIterArgs = newLoop.getRegionIterArgs(); sourceBlockArgs.append( newRegionIterArgs.begin(), std::next(newRegionIterArgs.begin(), forLoop.getNumResults())); rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs); rewriter.replaceOp( forLoop, newLoop.getResults().take_front(forLoop.getNumResults())); loop = newLoop; ivs.push_back(newLoop.getInductionVar()); newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size()); } // Update the loop body of the innermost loop to get new yield values. LoopLikeOpInterface innerMostLoop = loops.back(); FailureOr newInnerMostLoop = yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues, getNewTiledYieldsFn); if (failed(newInnerMostLoop)) return innerMostLoop.emitOpError("failed to return additional yields"); loops.back() = newInnerMostLoop.value(); // Make all other loops except the innermost loops yield the values returned // by the inner loop. for (auto [outerLoop, innerLoop] : llvm::zip_equal(loops.drop_back(), loops.drop_front())) { // Again assume that all the outer loops are scf.for operations. 
auto outerForLoop = cast(outerLoop); auto outerLoopYield = cast(outerForLoop.getBody()->getTerminator()); SmallVector newYields = llvm::to_vector(outerLoopYield.getOperands()); ValueRange additionalYields = innerLoop->getResults().take_back(newInitValues.size()); newYields.append(additionalYields.begin(), additionalYields.end()); rewriter.setInsertionPoint(outerLoopYield); rewriter.replaceOpWithNewOp(outerLoopYield, newYields); } return success(); } /// Implementation of tiling transformation of `op` that implements the /// `TilingInterface` using `scf.for` to iterate over the tiles. FailureOr mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const scf::SCFTilingOptions &options) { if (failed(verifyOptions(rewriter, op.getLoc(), options))) { return failure(); } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); // 1. Get the range of the loops that are represented by the operation. SmallVector iterationDomain = op.getIterationDomain(rewriter); // 2. Materialize the tile sizes and/or number of threads; SmallVector givenTileSizes, numThreads; std::tie(givenTileSizes, numThreads) = getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); // Check if it is safe to tile. This is hold over from previous iterations // of tile to for-all. Consider dropping it. if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy, givenTileSizes, numThreads))) { return failure(); } // Get the reduction dims SetVector reductionDims = getSanitizedReductionDims(givenTileSizes, options); // 3. If there is an interchange specified, permute the iteration domain and // the tile sizes. 
SmallVector interchangeVector; if (!options.interchangeVector.empty()) { interchangeVector = fillInterchangeVector(options.interchangeVector, iterationDomain.size()); assert(isPermutationVector(interchangeVector) && "expected interchange vector to be a permutation"); applyPermutationToVector(iterationDomain, interchangeVector); applyPermutationToVector(givenTileSizes, interchangeVector); if (!numThreads.empty()) applyPermutationToVector(numThreads, interchangeVector); } FailureOr tilingResult; // 4. Define the lambda function used later to generate the body of the // innermost tiled loop. GenerateTiledBodyFn innerYieldTiledValuesFn = [&](RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef tileOffsets, ArrayRef tileSizes, ValueRange regionIterArgs, SmallVector &tiledResults, SmallVector> &resultOffsets, SmallVector> &resultSizes) -> LogicalResult { // 4b. If interchange was provided, apply inverse of the interchange // to get back the offsets/sizes in the order to be specified. SmallVector tileOffsetsVec = llvm::to_vector(tileOffsets); SmallVector tileSizesVec = llvm::to_vector(tileSizes); if (!interchangeVector.empty()) { auto inversePermutation = invertPermutationVector(interchangeVector); applyPermutationToVector(tileOffsetsVec, inversePermutation); applyPermutationToVector(tileSizesVec, inversePermutation); } // 5. Generate the tiled implementation within the inner most loop. // 5a. Clone the operation within the loop body. auto clonedOp = cast( cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs)); // 5b. Early return cloned op if tiling is not happening. We can not // return the original op because it could lead to `rewriter.replaceOp(op, // op->getResults())` and users would get crash. if (llvm::all_of(givenTileSizes, isZeroInteger)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), /*generatedSlices=*/{}}; return success(); } // 5c. 
Tile the cloned operation. tilingResult = getTiledImplementation(rewriter, clonedOp, options.reductionStrategy, regionIterArgs, tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes, reductionDims); if (failed(tilingResult)) { rewriter.eraseOp(clonedOp); return op.emitOpError("faild to tile operation"); } // 5d. Delete the cloned operation. rewriter.eraseOp(clonedOp); // 5e. Compute the offsets at which the result values are to be inserted // back into its destinations. for (auto [index, tiledValue] : llvm::enumerate(tilingResult->tiledValues)) { tiledResults.push_back(tiledValue); SmallVector resultOffset, resultSize; if (failed(getResultTilePosition( rewriter, options.reductionStrategy, index, tiledValue, op, tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes, reductionDims, resultOffset, resultSize))) { for (auto op : tilingResult->tiledOps) { rewriter.eraseOp(op); } return rewriter.notifyMatchFailure( op, "failed to get slice of result produced"); } resultOffsets.emplace_back(std::move(resultOffset)); resultSizes.emplace_back(std::move(resultSize)); } return success(); }; // 6. Find the destination tensors to use for the operation. FailureOr> maybeInits = createInitialTensorsForTiling( rewriter, op, options.reductionStrategy, iterationDomain, numThreads, givenTileSizes, reductionDims); if (failed(maybeInits)) { return rewriter.notifyMatchFailure( op, "unable to create initial tensors for tiling"); } SmallVector &initTensors = maybeInits.value(); // 7. Generate the tiled loops nest using the callback defined above. 
SmallVector loops; { FailureOr> loopsOr = generateLoopNest( rewriter, op.getLoc(), options, iterationDomain, givenTileSizes, numThreads, initTensors, innerYieldTiledValuesFn); if (failed(loopsOr)) return op.emitOpError("failed to generate tiling loops"); assert(succeeded(tilingResult) && "expected tiling result to be computed after loop generation"); std::swap(loops, loopsOr.value()); } if (loops.empty()) { // If loops are empty, the tiled op is used as the replacement for the // untiled op. return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, tilingResult->tiledValues, tilingResult->generatedSlices, {}}; } auto loopResults = llvm::map_to_vector(loops.front()->getResults(), [](OpResult r) -> Value { return r; }); // For the full reduction case, there is nothing more to do. if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) { return scf::SCFTilingResult{ tilingResult->tiledOps, initTensors, loops, loopResults, tilingResult->generatedSlices, {}}; } // The results of the loop needs to be merged. 
FailureOr mergeResult = mergeTilingResults( rewriter, op, options.reductionStrategy, reductionDims, loopResults); if (failed(mergeResult)) { return rewriter.notifyMatchFailure( op, "Failed to merge partial results from tiling"); } return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops, mergeResult->replacements, tilingResult->generatedSlices, mergeResult->mergeOps}; } FailureOr mlir::scf::tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef tileSize) { scf::SCFTilingOptions options; options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp); options.setReductionTilingStrategy( ReductionTilingStrategy::PartialReductionOuterReduction); options.setTileSizes(tileSize); SmallVector reductionDims; for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes())) if (iteratorType == utils::IteratorType::reduction) reductionDims.push_back(index); options.setReductionDims(reductionDims); return tileUsingSCF(b, op, options); } //===----------------------------------------------------------------------===// // tileConsumerAndFuseProducersUsingSCF implementation. //===----------------------------------------------------------------------===// /// Return the untiled producer whose slice is used in a tiled consumer. The /// method traverses the tile loop nest (`loops`) if needed, and returns the /// `iter_args` of the outer most that is encountered. Traversing the /// iter_args indicates that this is a destination operand of the consumer. If /// there was no loop traversal needed, the second value of the returned tuple /// is empty. 
static std::tuple> getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef loops) { std::optional destinationIterArg; assert(!loops.empty() && "expected non empty loops container"); auto loopIt = loops.rbegin(); while (loopIt != loops.rend() && isa(source->get())) { auto iterArg = cast(source->get()); auto loop = *loopIt; if (iterArg.getOwner()->getParentOp() != loop) break; source = loop.getTiedLoopInit(iterArg); loopIt++; } if (loopIt == loops.rend()) destinationIterArg = source; return {dyn_cast(source->get()), destinationIterArg}; } /// Implementation of fusing producer of a single slice by computing the /// slice of the producer in-place. std::optional mlir::scf::tileAndFuseProducerOfSlice( RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef loops) { // 1. Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) auto [fusableProducer, destinationInitArg] = getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), loops); if (!fusableProducer) return std::nullopt; unsigned resultNumber = fusableProducer.getResultNumber(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(candidateSliceOp); // 2. Clone the fused producer // 2a. Compute the destination operands to use for the cloned operation. SmallVector origDestinationTensors, clonedOpDestinationTensors; Operation *fusableProducerOp = fusableProducer.getOwner(); if (isa(fusableProducerOp) && failed(tensor::getOrCreateDestinations( rewriter, fusableProducerOp->getLoc(), fusableProducerOp, origDestinationTensors))) return std::nullopt; clonedOpDestinationTensors = origDestinationTensors; if (destinationInitArg && isa(fusableProducerOp)) { // 2b. If the producer is also destination style, then to maintain the // destination passing style, update the destination of the producer to be // the source of the slice. clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource(); } // 2c. 
Clone the fused producer. Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs( rewriter, fusableProducerOp, clonedOpDestinationTensors); // 2d. Update the source of the candidateSlice to be the cloned producer. // Easier to just clone the slice with different source since // replacements and DCE of cloned ops becomes easier SmallVector candidateSliceOpOperands = llvm::to_vector(candidateSliceOp->getOperands()); candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber); tensor::ExtractSliceOp clonedCandidateSliceOp = mlir::clone(rewriter, candidateSliceOp, candidateSliceOp->getResultTypes(), candidateSliceOpOperands); // 3. Generate the tiled implementation of the producer of the source FailureOr tileAndFuseResult = tensor::replaceExtractSliceWithTiledProducer( rewriter, clonedCandidateSliceOp, clonedProducerOp->getResult(resultNumber)); if (failed(tileAndFuseResult)) return std::nullopt; // Note: Do not delete the candidateSliceOp, since its passed in from the // caller. rewriter.replaceAllUsesWith(candidateSliceOp, tileAndFuseResult->tiledValues[0]); rewriter.eraseOp(clonedCandidateSliceOp); rewriter.eraseOp(clonedProducerOp); // 3. If the slice is for a destination operand, for example, // // ```mlir // %0 = linalg.init // %1 = linalg.fill .. outs(%0 : ) // %2 = scf.for .. iter_args(%arg0 = %1) { // %3 = scf.for .. iter_args(%arg1 = %arg0) { // %4 = tensor.extract_slice %arg1 [..] // .. = linalg.matmul .. outs(%4 : ) // } // } // ``` // // the IR is currently // // ``` // %0 = linalg.init // %1 = linalg.fill // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { // %3 = scf.for .. iter_args(%arg1 = %arg0) { // %4 = tensor.extract_slice %arg1[..] // %5 = linalg.fill .. outs(%4 : ) // .. = linalg.matmul .. outs(%5 : ) // } // } // ``` // // The untiled `linalg.fill` is still used as the `init_value` since it // was originally a destination operand of the untiled `linalg.matmul`. 
// When fusing an operand that is a destination operand, the iter_arg of // the outer most loop should be changed to use the destination of the // fused operation. With this the IR will be. // // ``` // %0 = linalg.init // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { // %2 = scf.for .. iter_args(%arg1 = %arg0) { // %3 = tensor.extract_slice %arg1[..] // %4 = linalg.fill .. outs(%3 : ) // .. = linalg.matmul .. outs(%4 : ) // } // } // ``` if (destinationInitArg && isa(fusableProducerOp) && !loops.empty()) { loops.front() ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] .set(origDestinationTensors[resultNumber]); } return scf::SCFFuseProducerOfSliceResult{ fusableProducer, tileAndFuseResult->tiledValues[0], tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; } /// Reconstruct the fused producer from within the tiled-and-fused code. FailureOr> mlir::scf::yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef loops, ArrayRef yieldResultNumber) { if (loops.empty()) return success(); Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(), *tiledOwner = fusedProducerInfo.tiledOps[0]; Location loc = originalOwner->getLoc(); // a. collect all init Value to be appended SmallVector initNumberList = yieldResultNumber.empty() ? 
llvm::to_vector(llvm::seq( 0, originalOwner->getNumResults())) : llvm::to_vector(yieldResultNumber); SmallVector initValueList; for (const auto &resultNumber : initNumberList) { FailureOr initValue = tensor::getOrCreateDestination( rewriter, loc, originalOwner->getResult(resultNumber)); if (succeeded(initValue)) { initValueList.push_back(initValue.value()); } else { return failure(); } } SmallVector generatedSlices; YieldTiledValuesFn newYieldValuesFn = [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, ValueRange newRegionIterArgs, SmallVector &tiledResult, SmallVector> &tiledOffset, SmallVector> &tiledSizes) -> LogicalResult { OpBuilder::InsertionGuard g(innerRewriter); // get sliceOp tile information SmallVector sliceOffset = sliceOp.getMixedOffsets(), sliceSizes = sliceOp.getMixedSizes(); // expect all strides of sliceOp being 1 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) return failure(); unsigned sliceResultNumber = fusedProducerInfo.origProducer.getResultNumber(); auto tilableOp = cast(originalOwner); // b. get iterDomain Offset and Sizes based on sliceOp tile SmallVector iterDomainOffset, iterDomainSizes; // skip tensor.pack/unpack/pad, which expects single opResult if (tilableOp->getNumResults() > 1 && failed(tilableOp.getIterationDomainTileFromResultTile( rewriter, sliceResultNumber, sliceOffset, sliceSizes, iterDomainOffset, iterDomainSizes))) { // In theory, it is unnecessary to raise an error here. Actually // although it fails to reconstruct the result tensor, it should not // broke current fusion anyway. The reason why we must return failure // currently is that the callback function `newYieldValuesFn` will be // called after new init operand(s) has already been appended. It will // take more refactoring to make sure the init operands are added // consistently in the future. For more details, please refer to: // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814 return failure(); } // c. 
calculate offsets and sizes info of all OpResults respectively based // on iteration Domain Tile SmallVector> offsetList, sizesList; for (const auto &resultNumber : initNumberList) { if (resultNumber == sliceResultNumber) { offsetList.push_back(sliceOffset); sizesList.push_back(sliceSizes); } else { assert(!iterDomainOffset.empty() && !iterDomainSizes.empty()); // infer result tile according to the iteration domain tile SmallVector offset, sizes; if (failed(tilableOp.getResultTilePosition( rewriter, resultNumber, iterDomainOffset, iterDomainSizes, offset, sizes))) { return failure(); } offsetList.push_back(offset); sizesList.push_back(sizes); } } // d. create `extract_slice` for `iter_args` for DPS operation if // necessary if (auto tiledDestStyleOp = dyn_cast(tiledOwner)) { rewriter.setInsertionPoint(tiledDestStyleOp); for (const auto &&[index, newRegionArg] : llvm::enumerate(newRegionIterArgs)) { auto destSlice = tensor::ExtractSliceOp::create( rewriter, loc, newRegionArg, offsetList[index], sizesList[index], SmallVector(offsetList[index].size(), rewriter.getIndexAttr(1))); generatedSlices.push_back(destSlice); unsigned resultNumber = initNumberList[index]; rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); }); } } // e. 
prepare tiled offset and sizes for later `insert_slice` creation by // caller Block *block = rewriter.getInsertionPoint()->getBlock(); rewriter.setInsertionPoint(block->getTerminator()); for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) { tiledResult.push_back(tiledOwner->getResult(resultNumber)); tiledOffset.emplace_back(offsetList[index]); tiledSizes.emplace_back(sizesList[index]); } return success(); }; if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, newYieldValuesFn))) { return failure(); } return generatedSlices; } namespace { //===----------------------------------------------------------------------===// // SliceTrackingListener //===----------------------------------------------------------------------===// /// This class is a listener for tracking the insertion and removal of /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy /// fusion algorithm to apply cleanup patterns in between fusion steps. class SliceTrackingListener : public RewriterBase::Listener { public: explicit SliceTrackingListener( std::optional patterns); SliceTrackingListener() = default; /// Adds the given list of operations to the worklist, and if present, /// applies the list of `patterns` to the newly added operations. This only /// processes the given operations and any newly inserted ones by the /// pattern set. LogicalResult insertAndApplyPatterns(ArrayRef newOps); /// Add to the new operation worklist if it is an extract_slice. void notifyOperationInserted(Operation *op, OpBuilder::InsertPoint previous) override; /// Shared helper for operation removal from the worklist. void removeOp(Operation *op); /// Remove the operation from the worklist. void notifyOperationErased(Operation *op) override; /// Remove the operation from the worklist. void notifyOperationReplaced(Operation *op, ValueRange replacement) override; /// The worklist for this transformation keeps track of the slices to visit /// next for fusion. 
std::deque worklist; private: /// Optional pattern set to apply when adding new operations to the /// worklist. std::optional patterns = std::nullopt; }; SliceTrackingListener::SliceTrackingListener( std::optional p) { patterns = std::move(p); } LogicalResult SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { for (Operation *op : ops) { if (auto slice = dyn_cast(op)) worklist.push_back(slice); } if (!patterns) return success(); return applyOpPatternsGreedily( ops, patterns.value(), GreedyRewriteConfig().setListener(this).setStrictness( GreedyRewriteStrictness::ExistingAndNewOps)); } void SliceTrackingListener::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { auto slice = dyn_cast(op); if (!slice) return; worklist.push_back(slice); } // Scan the worklist for the given op and remove it if present. The // expectation is for the worklist to be small and for removal to be // relatively rare. void SliceTrackingListener::removeOp(Operation *op) { if (!isa(op)) return; auto iter = worklist.begin(); while (iter != worklist.end()) { if (*iter == op) break; iter++; } if (iter == worklist.end()) return; worklist.erase(iter); } void SliceTrackingListener::notifyOperationErased(Operation *op) { removeOp(op); } void SliceTrackingListener::notifyOperationReplaced(Operation *op, ValueRange replacement) { removeOp(op); } //===----------------------------------------------------------------------===// // ReplacementListener //===----------------------------------------------------------------------===// /// Listener that tracks updates replacements for values which can be mutated. /// This listener runs on top of the existing listener for the rewriter, /// to make sure external users can still run listeners. 
class ReplacementListener : public RewriterBase::ForwardingListener {
  // Listener that forwards every notification to the wrapped listener and, in
  // addition, keeps the externally owned `replacements` map up to date: any
  // map entry whose mapped value is a result of a replaced operation is
  // redirected to the corresponding replacement value.
public:
  /// `replacements` is held by reference and mutated in place; `listener` is
  /// the listener that notifications are forwarded to.
  ReplacementListener(DenseMap &replacements,
                      OpBuilder::Listener *listener)
      : ForwardingListener(listener), replacements(replacements) {}

  /// For every entry of the map whose mapped value equals one of
  /// `origValues`, overwrite it with the value at the same position in
  /// `replaceValues`. The two ranges are zipped with `zip_equal`, so they are
  /// expected to have the same length.
  void updateReplacementValues(ValueRange origValues,
                               ValueRange replaceValues) {
    // This can probably be written better, but just iterates over the map
    // and the new replacements for now.
    for (auto &[key, val] : replacements) {
      for (auto [orig, replace] :
           llvm::zip_equal(origValues, replaceValues)) {
        if (val == orig) {
          val = replace;
        }
      }
    }
  }

  /// Op-by-op replacement: forward first, then remap using the results of
  /// `newOp`.
  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
    ForwardingListener::notifyOperationReplaced(op, newOp);
    updateReplacementValues(op->getResults(), newOp->getResults());
  }

  /// Op-by-values replacement: forward first, then remap using `values`.
  void notifyOperationReplaced(Operation *op, ValueRange values) override {
    ForwardingListener::notifyOperationReplaced(op, values);
    updateReplacementValues(op->getResults(), values);
  }

private:
  // Map owned by the caller; updated as replacements are observed.
  DenseMap &replacements;
};

} // namespace

/// Implementation of tile consumer and fuse producer greedily.
///
/// Tiles `consumer` with `options.tilingOptions`, then repeatedly pops
/// candidate slices from a worklist and tries to tile and fuse the producer
/// of each slice into the generated loops. Returns failure if the consumer
/// has no results, if tiling the consumer fails, if the cleanup patterns
/// fail, or if yielding a replacement for a fused producer fails. On success
/// returns the fused producers, the tiled-and-fused ops, the loops, and the
/// map from original values to their replacements.
FailureOr mlir::scf::tileConsumerAndFuseProducersUsingSCF(
    RewriterBase &rewriter, TilingInterface consumer,
    const scf::SCFTileAndFuseOptions &options) {
  // This transformation is only valid for ops that return values (i.e. not
  // valid to use with operations that have memref operands).
  if (!consumer->getNumResults()) {
    return rewriter.notifyMatchFailure(
        consumer, "invalid pattern for op with no results");
  }

  // 1. First tile the consumer.
  SetVector fusedProducers, tiledAndFusedOps;

  FailureOr tilingResult =
      tileUsingSCF(rewriter, consumer, options.tilingOptions);

  if (failed(tilingResult))
    return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
  tiledAndFusedOps.insert_range(tilingResult->tiledOps);

  // Seed the replacement map with the consumer results and the values that
  // tiling produced for them.
  DenseMap replacements;
  for (auto [origVal, replacement] :
       llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
    replacements[origVal] = replacement;
  }

  // If there are no loops generated, fusion is immaterial.
  auto &loops = tilingResult->loops;
  if (loops.empty()) {
    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                     replacements};
  }

  // Since the loop gets potentially replaced during fusion, we need to track
  // the mutation of replacement values. To do this, we attach a listener to
  // update the replacements as they happen. The scope-exit object reinstalls
  // the previous listener on every return path below.
  OpBuilder::Listener *previousListener = rewriter.getListener();
  auto resetListener =
      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
  ReplacementListener replaceListener(replacements, previousListener);
  rewriter.setListener(&replaceListener);

  // 2. Typically, the operands of the tiled operation are slices of the
  //    operands of the untiled operation. These are expressed in IR using
  //    `tensor.extract_slice` operations with source being the operands of
  //    the untiled operation. Create a worklist of these
  //    `tensor.extract_slice` operations. If the producers of the source of
  //    the `tensor.extract_slice` can be tiled such that the tiled value is
  //    generated in-place, that effectively tiles + fuses the operations.
  struct WorklistItem {
    tensor::ExtractSliceOp candidateSlice;
    SCFTileAndFuseOptions::ControlFnResult controlFnResult;
  };

  SliceTrackingListener sliceTracker =
      SliceTrackingListener(options.cleanupPatterns);

  if (failed(
          sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
    return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
  }
  OpBuilder::InsertionGuard g(rewriter);
  while (!sliceTracker.worklist.empty()) {
    auto candidateSlice = sliceTracker.worklist.front();
    sliceTracker.worklist.pop_front();

    // Skip slices whose source has no fusable producer.
    auto [fusableProducer, destinationInitArg] =
        getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
                                          loops);
    if (!fusableProducer)
      continue;

    // Let the user-provided control function veto fusion of this producer.
    std::optional controlFnResult =
        options.fusionControlFn(candidateSlice, fusableProducer,
                                destinationInitArg.has_value());
    if (!controlFnResult)
      continue;

    WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these operations to the worklist.
    std::optional fusedResult =
        tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
                                   loops);
    if (!fusedResult)
      continue;

    SmallVector worklistCandidates = fusedResult->generatedSlices;

    if (worklistItem.controlFnResult.yieldProducerReplacement) {
      // Reconstruct and yield all opResult of fusableProducerOp by default.
      // The caller can specify which one to yield by designating the optional
      // argument named `yieldResultNumber` of
      // `yieldReplacementForFusedProducer`.
      Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
      FailureOr> newSlices =
          yieldReplacementForFusedProducer(rewriter,
                                           worklistItem.candidateSlice,
                                           fusedResult.value(), loops);
      // NOTE(review): the diagnostic below reads as if a verb is missing
      // ("failed to <yield> replacement value ..."); left untouched here.
      if (failed(newSlices)) {
        return rewriter.notifyMatchFailure(
            fusableProducerOp, "failed to replacement value for this "
                               "operation from within the tiled loop");
      }
      worklistCandidates.append(newSlices.value());
      // The replacements for the producer results are read from the trailing
      // results of the outermost loop, one per producer result.
      for (auto [index, result] :
           llvm::enumerate(fusableProducerOp->getResults())) {
        replacements[result] = loops.front()->getResult(
            loops.front()->getNumResults() -
            fusableProducerOp->getNumResults() + index);
      }
    }
    // Record the fusion only when the tiled producer value is defined by an
    // operation.
    if (Operation *tiledAndFusedOp =
            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
      tiledAndFusedOps.insert(tiledAndFusedOp);
    }

    if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
      return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
    }
  }

  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                   replacements};
}

//===----------------------------------------------------------------------===//
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//

/// A utility function that checks whether the only use of the result of a
/// tensor.insert_slice op is in a scf.yield op.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) { Value result = candidateSliceOp.getResult(); Value::use_range uses = result.getUses(); if (!llvm::hasSingleElement(uses)) { LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n"); return failure(); } OpOperand &operandUse = (*uses.begin()); Operation *userOp = operandUse.getOwner(); if (!isa(userOp)) { LLVM_DEBUG(llvm::dbgs() << "Expected scf.yield to be the only user, but got -> " << (*userOp)); return failure(); } if (result.getDefiningOp()->getBlock() != userOp->getBlock()) { LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to " "be in the same block\n"); return failure(); } return success(); } /// An utility to get the first user of the given loopOp. If any of user stay /// in different block of loopOp, return failure. static FailureOr getFirstUserOfLoop(Operation *loopOp) { if (!isa(loopOp)) return failure(); Operation *firstUserOfLoop = nullptr; for (Operation *userOp : loopOp->getUsers()) { // `ParallelInsertSlice` located inside `InParallelOp` has no same parent // block with any other types of operation. Thus, just redirecting to its // parent `InParallelOp`. E.g. // // ``` // %1 = scf.for { // ... // } // %2 = consumerOp ins(%1, ...) // scf.forall.in_parallel { // tensor.parallel_insert_slice %1 // } // ``` // where `InParallelOp` but not `ParallelInsertSlice` stays in the same // same block with `consumerOp`. if (isa(userOp)) userOp = userOp->getParentOfType(); if (loopOp->getBlock() != userOp->getBlock()) return failure(); if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop)) firstUserOfLoop = userOp; } return firstUserOfLoop; } /// This utility currently checks whether the first userOp of loop is NOT /// before the last defineOp of consumer operand. Because that we need to move /// the whole loop structure right before the `firstUserOfLoop`. 
/// This utility thus helps ensure that no invalid IR is formed, i.e. no
/// backward slice of consumerOp is dominated by the `firstUserOfLoop`. That
/// is, given:
///
/// ```
/// %0 = scf.for() {
///   ...
/// }
/// ...
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ...
/// %3 = consumerOp(%2)
/// ```
///
/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
/// a.k.a. use-def chain violation:
///
/// ```
/// %0:2 = scf.for() {
///    // use before define error
///    %3 = tiledConsumerOp(%2)
/// }
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ```
///
/// @param loopOp: loop operation
/// @param consumerOp: consumer operation
/// @param reorderOperations: the flag controls whether to reorder the
/// backward slice w.r.t. the defineOp of `consumerOp` operands. When false,
/// any non-empty slice results in failure.
/// @return: computed backward slice of consumerOp, but excluding those
/// that already dominate `firstUserOfLoop`. Failure is returned when the
/// first user of the loop cannot be determined, or when the slice is
/// non-empty and either `loopOp` was reached while slicing or
/// `reorderOperations` is false.
static FailureOr>
checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
                       bool reorderOperations) {
  FailureOr firstUserOfLoop = getFirstUserOfLoop(loopOp);
  if (failed(firstUserOfLoop))
    return failure();

  BackwardSliceOptions options;
  DominanceInfo dominanceInfo;
  options.inclusive = true;
  options.omitBlockArguments = true;
  // Set by the filter below when the traversal reaches `loopOp` itself.
  bool includeLoopOp = false;
  options.filter = [&](Operation *op) {
    if (op == loopOp) {
      includeLoopOp = true;
      return false;
    }
    // Cut off the slice to not include any operation that already dominates
    // firstUserOfLoop.
    return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
  };
  // Accumulate the backward slices of all consumer operands into one set.
  llvm::SetVector slice;
  for (auto operand : consumerOp->getOperands()) {
    LogicalResult result = getBackwardSlice(operand, &slice, options);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
  }

  if (!slice.empty()) {
    // If consumerOp has one producer, which is also the user of loopOp.
    // E.g.
    // ```
    //  %0 = %loopOp
    //  %1 = consumerOp1 ins(%0)
    //  %2 = consumerOp2 ins(%0, %1)
    // ```
    // We can not fuse consumerOp2 into loopOp due to UD chain, unless
    // consumerOp1 has already been fused into loopOp before.
    if (includeLoopOp || !reorderOperations)
      return failure();
  }

  return slice;
}

/// Fetches the OpOperand of the first valid user (and use) of the value `val`
/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
/// Returns failure otherwise.
///
/// `val` is result number `resultNumber` of `loopOp`. As a side effect, when
/// the accepted consumer has a non-empty backward slice (as computed by
/// `checkAssumptionForLoop`), the operations of that slice are moved before
/// the first user of the loop.
static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter,
                                         Operation *loopOp,
                                         unsigned resultNumber) {
  if (!isa(loopOp))
    return failure();
  Value val = loopOp->getResult(resultNumber);
  Block *loopBlock = loopOp->getBlock();
  for (OpOperand &opOperand : val.getUses()) {
    Operation *consumerOp = opOperand.getOwner();
    // Step 1. Check if the user is tilable.
    if (!isa(consumerOp) ||
        !isa(consumerOp)) {
      // TODO: We have to init result of consumer before scf.for, use
      // DestinationStyleOpInterface to get result shape from init for now.
      // Add support for other op such as op has InferTypeOpInterface.
      continue;
    }
    // Step 2. Check if user stay in the same block.
    if (loopBlock != consumerOp->getBlock())
      continue;
    // Step 3. Check if user has succeeding user. Otherwise, it usually
    // represents already tiled.
    if (consumerOp->use_empty())
      continue;
    // Step 4. Check assumption for loop with `reorderOperations` enabled.
    FailureOr> slice =
        checkAssumptionForLoop(loopOp, consumerOp, true);
    if (failed(slice))
      continue;
    // Step 5. If backward slice is not empty, move them before
    // firstUserOfLoop. The slice is sorted topologically first so that the
    // moved operations keep a valid def-before-use order.
    if (!slice->empty()) {
      mlir::topologicalSort(*slice);
      FailureOr firstUserOfLoop = getFirstUserOfLoop(loopOp);
      assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
      for (auto op : *slice) {
        rewriter.moveOpBefore(op, *firstUserOfLoop);
      }
    }
    // The first use that passes every check is returned.
    return &opOperand;
  }
  return failure();
}

/// Fetch the untiled consumer of the outermost scf.for's result which is
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
/// makes the following assumptions :
/// 1. tensor.insert_slice has scf.yield as its only user.
/// 2. scf.for's corresponding result has only one use.
/// 3. The `loops` passed in are perfectly nested `scf.for` operations.
static FailureOr
getUntiledConsumerFromSlice(RewriterBase &rewriter,
                            tensor::InsertSliceOp candidateSliceOp,
                            MutableArrayRef loops) {
  assert(!loops.empty() && "unexpected loops to be empty");
  // 1. Expect slice to be part of the body of the inner most loop.
  Operation *containingOp = candidateSliceOp->getParentOp();
  if (containingOp != loops.back()) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp,
        "expected slice to be within body of inner-most loop");
  }

  // 2. Check that the loop is perfectly nested.
  if (!isPerfectlyNestedForLoops(loops)) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "expected passed loops to be perfectly nested.");
  }

  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
    return failure();
  Value sliceResult = candidateSliceOp.getResult();

  // 3. Fetch the corresponding output. The operand position of the single
  // use of the slice result is taken as the result number of the outermost
  // loop.
  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
  unsigned resultNumber = yieldOpOperand.getOperandNumber();

  scf::ForOp topLevelForOp = cast(loops.front().getOperation());

  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
}

/// Fetch the first untiled consumer of a scf.forall's result which is yielded
/// by a tensor.parallel_insert_slice.
static FailureOr
getUntiledConsumerFromSlice(RewriterBase &rewriter,
                            tensor::ParallelInsertSliceOp candidateSliceOp,
                            MutableArrayRef loops) {
  assert(!loops.empty() && "unexpected loops to be empty");

  // Both loop-shape checks below report the same diagnostic; build it in one
  // place.
  auto rejectLoopNest = [&]() {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "expected single surrounding scf.forall");
  };

  // 1. The surrounding nest must be exactly one loop, and that loop must
  // cast successfully to the expected op.
  if (loops.size() != 1)
    return rejectLoopNest();
  auto forallOp = dyn_cast(loops.front().getOperation());
  if (!forallOp)
    return rejectLoopNest();

  // 2. The destination of the slice has to be a block argument whose owning
  // block belongs directly to that loop; bail out silently otherwise.
  Value sliceDest = candidateSliceOp.getDest();
  auto iterArg = dyn_cast(sliceDest);
  if (!iterArg || iterArg.getOwner()->getParentOp() != forallOp)
    return failure();

  // 3. Walk from the block argument to its tied operand, and from there to
  // the tied loop result, whose number selects the loop output to inspect.
  auto tiedOperand = forallOp.getTiedOpOperand(iterArg);
  unsigned resultNumber =
      forallOp.getTiedOpResult(tiedOperand).getResultNumber();

  return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
}

/// A utility to fetch an untiled consumer of
/// tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr> getUntiledConsumerOperandsFromSlices( RewriterBase &rewriter, ArrayRef sliceOps, MutableArrayRef loops) { assert(!loops.empty() && "unexpected empty loops"); assert(!sliceOps.empty() && "unexpected empty list of candidate slices"); SmallVector fusedOperands; for (auto sliceOp : sliceOps) { FailureOr fusedOperand = TypeSwitch>(sliceOp) .Case( [&](auto op) { return getUntiledConsumerFromSlice(rewriter, op, loops); }) .Default([&](Operation *op) { return rewriter.notifyMatchFailure(op, "unhandled slice type"); }); if (failed(fusedOperand)) { return failure(); } if (!fusedOperands.empty() && fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) { return rewriter.notifyMatchFailure( fusedOperand.value()->getOwner(), "all candidate slices must be to the same consumer"); } fusedOperands.push_back(fusedOperand.value()); } return fusedOperands; } template static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp); template <> tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, tensor::InsertSliceOp insertSliceOp) { return cast( rewriter.clone(*insertSliceOp.getOperation())); } template <> tensor::InsertSliceOp cloneAsInsertSlice( RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) { return tensor::InsertSliceOp::create( rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(), insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); } static SmallVector cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef candidateSlices) { assert(!candidateSlices.empty() && "unexpected empty list of slices to clone"); SmallVector clonedSlices; for (auto sliceOp : candidateSlices) { TypeSwitch(sliceOp) .Case( [&](auto op) { auto clonedOp = cloneAsInsertSlice(rewriter, op); clonedSlices.push_back(clonedOp); }) .Default([&](Operation *op) { // Assert here assuming this has already been checked. 
assert(0 && "unexpected slice type while cloning as insert slice"); }); } return clonedSlices; } /// Implementation of fusing consumer of a single slice by computing the /// slice of the consumer in-place for scf loop. FailureOr mlir::scf::tileAndFuseConsumerOfSlices( RewriterBase &rewriter, ArrayRef candidateSlices, MutableArrayRef loops) { if (candidateSlices.empty()) { return rewriter.notifyMatchFailure( rewriter.getUnknownLoc(), "no candidate slices provided for consumer fusion"); } // Return if `loops` is empty, return an error for now. Caller is expected // to handle this case. if (loops.empty()) { return rewriter.notifyMatchFailure( candidateSlices.front(), "cannot call tile and fuse consumer with an empty loop nest"); } if (!(llvm::all_of(candidateSlices, llvm::IsaPred) || llvm::all_of(candidateSlices, llvm::IsaPred))) { return rewriter.notifyMatchFailure( candidateSlices.front(), "candidates slices need to be all `tensor.extract_slice`s or " "`tensor.parallel_insert_slice`s"); } // 1. Get the consumer of scf.for for the result yielded by // tensor.insert_slice/parallel_insert_slice. SmallVector consumerOpOperands; Operation *consumerOp; { FailureOr> maybeConsumerOpOperand = getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); if (failed(maybeConsumerOpOperand)) { return rewriter.notifyMatchFailure(candidateSlices.front(), "could not fetch consumer to fuse"); } std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); consumerOp = consumerOpOperands.front()->getOwner(); } LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface innerMostLoop = loops.back(); // Check assumption for loop with `reorderOperations` disabled. if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { return rewriter.notifyMatchFailure( outerMostLoop, "the first user of loop should not dominate any define " "of consumer operand(s)"); } OpBuilder::InsertionGuard g(rewriter); // 2. 
Check consumer is not using scf loop's output as init. auto dstOp = dyn_cast(consumerOp); if (!dstOp) return rewriter.notifyMatchFailure(consumerOp, "consumer op is not DPS operation"); if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) { return dstOp.isDpsInit(opOperand); })) { return rewriter.notifyMatchFailure( consumerOp, "consumer op taking the result of scf.for as init is not supported"); } SmallVector newInits = llvm::to_vector(dstOp.getDpsInits()); // 3. Move the whole loop structure right before firstUserOfLoop, the // dominance should be already ensured by `checkAssumptionForLoop`. FailureOr firstUserOfLoop = getFirstUserOfLoop(outerMostLoop); if (failed(firstUserOfLoop)) { return rewriter.notifyMatchFailure( outerMostLoop, "could not find the first user of outer most loop"); } rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop); // 4. Set insertion point before terminator op of the loop and create a new // tensor.insert_slice. In the scf.for case this is a clone of the // candidateSliceOp whereas in the scf.forall case this is created from the // operands of tensor.parallel_insert_slice. if (auto sliceOp = dyn_cast(candidateSlices.front())) { auto newForallOp = cast(innerMostLoop.getOperation()); rewriter.setInsertionPoint(newForallOp.getTerminator()); } else { rewriter.setInsertionPoint(candidateSlices.front()); } // 5.a. Clone all the candidate slices as equivalent insert slice ops. SmallVector clonedInsertSlices = cloneAsInsertSlices(rewriter, candidateSlices); // 5.b. Clone consumer op. auto clonedConsumerOp = cast(rewriter.clone(*consumerOp)); SmallVector operandNumbers = llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) { return opOperand->getOperandNumber(); }); SmallVector clonedOpFusedOperandsList = llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &clonedConsumerOp->getOpOperand(operandNum); }); // 5.c. Replace all uses of the loop result with the result of the cloned // tensor.insert_slice. 
rewriter.modifyOpInPlace(clonedConsumerOp, [&]() { for (auto [operandToReplace, clonedSliceOp] : llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) { operandToReplace->set(clonedSliceOp.getResult()); } }); // 6. Perform tiling of the cloned consumer and replace the operand at // `operandNumber` with the source of the cloned tensor.insert_slice op. FailureOr tileAndFuseResult = tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices, clonedOpFusedOperandsList); if (failed(tileAndFuseResult)) { return failure(); } auto tiledConsumerOp = cast(tileAndFuseResult->tiledOps[0]); for (auto [operandNum, clonedSliceOp] : llvm::zip_equal(operandNumbers, clonedInsertSlices)) { rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum), clonedSliceOp.getSource()); } // 7. Reconstruct [nested] loop with new inits. YieldTiledValuesFn newYieldValuesFn = [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, ValueRange newRegionIterArgs, SmallVector &tiledResult, SmallVector> &tiledOffset, SmallVector> &tiledSizes) -> LogicalResult { OpBuilder::InsertionGuard g(innerRewriter); // 8. Set inner insertPoint right before tiled consumer op. innerRewriter.setInsertionPoint(tiledConsumerOp); SmallVector> allOffsets, allSizes; for (auto candidateSliceOp : clonedInsertSlices) { SmallVector offsets = candidateSliceOp.getMixedOffsets(); SmallVector sizes = candidateSliceOp.getMixedSizes(); SmallVector strides = candidateSliceOp.getMixedStrides(); // 9. Check all insert stride is 1. if (!llvm::all_of(strides, isOneInteger)) { return rewriter.notifyMatchFailure( candidateSliceOp, "containingOp's result yield with stride"); } allOffsets.emplace_back(std::move(offsets)); allSizes.emplace_back(std::move(sizes)); } // 10. Try to get iter domain position from input position. Use // clonedConsumerOp instead of tiledConsumerOp, because the iteration // domain may require index computation based on the result size. 
The // sizes and offsets should be the same either way, but using // tiledConsumerOp could lead to some chained unnecessary extra index // computation. SmallVector iterDomainOffsets, iterDomainSizes; if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles( rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets, iterDomainSizes))) { return rewriter.notifyMatchFailure( clonedConsumerOp, "can't get iter domain position from input position"); } // 11. Try to fetch the offset and size for all results of the cloned // consumer. This would then be used to form the corresponding // tensor.insert_slice/parallel_insert_slice later. unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults(); SmallVector> resultOffsets( totalNumResultsOfConsumer); SmallVector> resultSizes( totalNumResultsOfConsumer); for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) { if (failed(tiledConsumerOp.getResultTilePosition( rewriter, idx, iterDomainOffsets, iterDomainSizes, resultOffsets[idx], resultSizes[idx]))) { return rewriter.notifyMatchFailure( tiledConsumerOp, "can't get result domain position from iter domain position"); } } // 12. Create `extract_slice` for `iter_args` for DPS operation if // necessary. if (auto tiledDestStyleOp = dyn_cast( tiledConsumerOp.getOperation())) { rewriter.setInsertionPoint(tiledDestStyleOp); for (const auto &&[index, newRegionArg] : llvm::enumerate(newRegionIterArgs)) { auto destSlice = tensor::ExtractSliceOp::create( rewriter, loc, newRegionArg, resultOffsets[index], resultSizes[index], SmallVector(resultOffsets[index].size(), rewriter.getIndexAttr(1))); // Make a copy of index to avoid a capturing structured binding, which // is a C++20 extension. auto dstNumber = index; rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice); }); } } // 13. Prepare tiled offset and sizes for later `insert_slice` creation by // caller. 
Block *block = rewriter.getInsertionPoint()->getBlock(); rewriter.setInsertionPoint(block->getTerminator()); for (const auto &&[index, result] : llvm::enumerate(tiledConsumerOp->getResults())) { tiledResult.push_back(result); tiledOffset.emplace_back(resultOffsets[index]); tiledSizes.emplace_back(resultSizes[index]); } return success(); }; // 14. Add new inits to [nested] loops. if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits, newYieldValuesFn))) { return rewriter.notifyMatchFailure(tiledConsumerOp, "unable to add new inits to nest loop"); } // 15. Replace the result of scf loop and consumer op with new loop's // results. for (auto &&[oldResult, newResult] : llvm::zip(consumerOp->getResults(), loops.front()->getResults().take_back(newInits.size()))) { rewriter.replaceAllUsesWith(oldResult, newResult); } // 16. Need to erase the old scf loop and the cloned consumer op. rewriter.eraseOp(clonedConsumerOp); SmallVector tiledAndFusedOpOperands = llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); }); return scf::SCFFuseConsumerOfSliceResult{ std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), std::move(tileAndFuseResult->tiledOps)}; } //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// FailureOr> mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op) { // TODO: Handle cases where the op has results if needed. 
if (op->getNumResults() > 0) { return rewriter.notifyMatchFailure( op, "unable to lower to loops operations with return values"); } SmallVector domain = op.getIterationDomain(rewriter); SmallVector ivs; SmallVector loops; Location loc = op.getLoc(); for (auto loopRange : domain) { Value offsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); Value strideVal = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal, strideVal, ValueRange{}); loops.push_back(loop); ivs.push_back(loop.getInductionVar()); rewriter.setInsertionPoint(loop.getBody()->getTerminator()); } if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { return failure(); } return loops; }