//===- PaddingTilingInterface.cpp - Padding of TilingInterface ops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "pad-tiling-interface"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::tensor;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

/// Form a "full-rank" padding specification so that the application is easy.
static SmallVector<OpFoldResult>
getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
                        const PadTilingInterfaceOptions &options) {
  SmallVector<OpFoldResult> paddingSizes;
  // Complete the padding specification to specify all dimensions.
  for (size_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
    // Complete to zero if needed.
    paddingSizes.push_back(options.paddingSizes.size() > idx
                               ? options.paddingSizes[idx]
                               : b.getIndexAttr(0));
    // If a dimension is zero (either specified or completed), replace it by:
    //   - 1 if we are padding to the next multiple of some value;
    //   - indexingSizes[idx] otherwise.
    if (isZeroInteger(paddingSizes[idx])) {
      paddingSizes[idx] =
          options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[idx];
    }
    LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << paddingSizes[idx]
                      << "\n");
  }
  return paddingSizes;
}

/// Extract the constant multiplier from an affine expression of the form
/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
/// AffineConstantExpr. Return 1 if the expression is not a simple
/// multiplication of a dimension and a constant.
static int64_t extractConstantMultiplier(AffineExpr expr) {
  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
    if (binOp.getKind() == AffineExprKind::Mul) {
      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (lhsD && rhsC) {
        return rhsC.getValue();
      }
      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
      if (lhsC && rhsD) {
        return lhsC.getValue();
      }
    }
  }
  return 1;
}

/// Compute the padded shape of the given value `v` of `RankedTensorType`,
/// given:
///   - `indexingSizes`, a list of OpFoldResult;
///   - an `indexingMap` that encodes how the shape of `v` varies with
///     increases in `indexingSizes`.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementation below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
/// The padded shape is computed by evaluating the maximum accessed index per
/// dimension, which may involve multiplying by constant factors derived from
/// the affine indexing expressions.
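/// For example, with `indexingMap = affine_map<(d0, d1) -> (d0 * 3 + d1)>`,
/// padding sizes `[4, 2]`, and no `padToMultipleOf`, the maximum accessed
/// index is `3 * (4 - 1) + (2 - 1) = 10`, so the padded size for that
/// dimension is `11`.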
/// Currently, only a limited set of projected permutation indexing maps are
/// supported, such as:
///   - affine_map<(d0, d1, d2) -> (d0, d1)>
///   - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
///   - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
    const PadTilingInterfaceOptions &options) {
  Location loc = v.getLoc();
  SmallVector<OpFoldResult> paddedShape;
  auto tensorType = cast<RankedTensorType>(v.getType());
  paddedShape.resize_for_overwrite(tensorType.getRank());
  assert(tensorType.getRank() == indexingMap.getNumResults() &&
         "expect the number of results of the affine map to match the tensor "
         "rank");

  // "Full-rank" padding specification.
  SmallVector<OpFoldResult> paddingSizes =
      getFullRankPaddingSizes(rewriter, indexingSizes, options);

  // For each dimension in the operand's shape, iterate over indexingSizes and
  // add the various term contributions.
  for (const auto &enResults : enumerate(indexingMap.getResults())) {
    int64_t resultIndex = enResults.index();
    AffineMap partialIndexingMap = indexingMap.getSubMap(
        ArrayRef<unsigned>{static_cast<unsigned>(resultIndex)});
    LLVM_DEBUG(DBGS() << "----resultIndex: " << resultIndex
                      << " with partialIndexingMap: " << partialIndexingMap
                      << "\n");

    // Find all padding dimensions that contribute to this operand dimension
    // and compute the padded term contribution to the final padded shape.
    SmallVector<OpFoldResult> terms;
    for (size_t paddingDim = 0, e = paddingSizes.size(); paddingDim != e;
         ++paddingDim) {
      OpFoldResult paddingSize = paddingSizes[paddingDim];
      LLVM_DEBUG(DBGS() << "------try apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");
      if (!enResults.value().isFunctionOfDim(paddingDim))
        continue;

      LLVM_DEBUG(DBGS() << "------apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");

      // Project out the non-'paddingDim' dimensions and compress the result.
      llvm::SmallBitVector projectedDims(partialIndexingMap.getNumDims(),
                                         true);
      projectedDims.flip(paddingDim);
      AffineMap projectedMap =
          mlir::projectDims(partialIndexingMap, projectedDims,
                            /*compressDimsFlag=*/true);

      // If we are padding to the next multiple, compose with
      // ceil(size / multiple) * multiple.
      OpFoldResult paddingDimOfr;
      if (options.padToMultipleOf) {
        AffineExpr d0, s0;
        bindDims(rewriter.getContext(), d0);
        bindSymbols(rewriter.getContext(), s0);
        AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
        AffineMap composedMap = projectedMap.compose(ceilMap);
        paddingDimOfr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, composedMap,
            {indexingSizes[paddingDim], paddingSize},
            /*composeAffineMin=*/true);
      } else {
        // Otherwise just set to paddingSize.
        paddingDimOfr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, projectedMap, paddingSize);
      }

      // Adjust for the maximum accessed index, which is
      // (paddingSize - 1) * multiplier.
      AffineExpr d0;
      bindDims(rewriter.getContext(), d0);
      int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
      AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
      OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
          rewriter, loc, subtractMap, {paddingDimOfr});
      terms.push_back(maxAccessIdx);

      LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
    }

    // If there are no terms, just return the dim.
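    // (Terms can be empty when this operand dimension is not a function of
    // any padding dimension, e.g. for a constant result expression in the
    // indexing map.)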
    if (terms.empty()) {
      paddedShape[resultIndex] =
          createFoldedDimOp(rewriter, loc, v, resultIndex);
      continue;
    }

    // Sum individual terms' contributions.
    SmallVector<AffineExpr> dims(terms.size());
    bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
    AffineExpr sumExpr = dims.front();
    for (unsigned i = 1; i < dims.size(); ++i)
      sumExpr = sumExpr + dims[i];
    // Add 1 to the maximum accessed index to get the final padded size.
    OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sumExpr + 1, terms);
    paddedShape[resultIndex] = paddedDimOfr;
  }

  return paddedShape;
}

FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
    RewriterBase &rewriter, OpOperand &operandToPad,
    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
  auto transferOp =
      llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
  if (!transferOp)
    return failure();

  // clang-format off
  assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
    return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
           r.stride == OpFoldResult(rewriter.getIndexAttr(1));
  }) && "expected 0-offset 1-stride loop ranges");
  // clang-format on
  SmallVector<OpFoldResult> loopUpperBounds;
  loopUpperBounds.reserve(iterationDomain.size());
  for (const Range &range : iterationDomain)
    loopUpperBounds.push_back(range.size);

  AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
  return computePaddedShape(
      rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
      indexingMap, loopUpperBounds, options);
}

/// Pad a single operand to `paddedShape` using `paddingValueAttr` as the
/// padding value.
static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
                        TypedValue<RankedTensorType> v,
                        ArrayRef<OpFoldResult> paddedShape,
                        Attribute paddingValueAttr) {
  Value paddingValue;
  if (auto complexTy =
          dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
    if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
                                                 complexTy, complexAttr);
    }
  } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
                                        getElementTypeOrSelf(v.getType()));
  } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
    paddingValue =
        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
  }
  assert(paddingValue && "failed to create value from padding attribute");

  // Pad the operand to the bounding box defined by `paddedShape`.
  SmallVector<int64_t> tensorShape;
  SmallVector<Value> dynDims;
  for (OpFoldResult ofr : paddedShape) {
    std::optional<int64_t> cst = getConstantIntValue(ofr);
    tensorShape.push_back(cst.has_value() ? *cst : ShapedType::kDynamic);
    if (!cst.has_value())
      dynDims.push_back(ofr.dyn_cast<Value>());
  }
  // TODO: use dispatchIndexOpFoldResults(paddedShape, dynDims, paddedShape);

  auto paddedTensorType =
      RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType);
  return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
                               paddingValue, /*nofold=*/false, dynDims);
}

FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
    RewriterBase &rewriter, TilingInterface opToPad,
    const PadTilingInterfaceOptions &constOptions,
    SmallVector<tensor::PadOp> &padOps,
    const PadSizeComputationFunction &computePaddingSizeFun) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
  Location loc = opToPad.getLoc();
  PadTilingInterfaceOptions options(constOptions);

  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
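  // (E.g., a zero padding value preserves the semantics of ops that combine
  // partial results with additions, but may be incorrect for ops combining
  // with minima or maxima.)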
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }

  if (llvm::any_of(opToPad->getOperands(), [](Value v) {
        return isa<MemRefType>(v.getType());
      })) {
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");
  }

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after opToPad because we also take the dims of opToPad's output.
  rewriter.setInsertionPointAfter(opToPad);

  // 1. Get the loopUpperBounds from the TilingInterface.
  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);

  // 2. For each operand.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    Value operand = opOperand.get();
    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");

    // 2.a. Skip scalar-like operands.
    Type operandType = operand.getType();
    if (!isa<RankedTensorType>(operandType)) {
      assert((!isa<ShapedType>(operandType) || isa<VectorType>(operandType)) &&
             "Unexpected non-vector ShapedType");
      newOperands.push_back(operand);
      continue;
    }
    // 2.b. Compute padded shape.
    FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
    if (failed(maybePaddedShape)) {
      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
    }

    // 2.c. Expect proper `paddingValues`.
    // TODO: we may want to allow garbage padding in the future, in which case
    // we would just not assert.
    if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
      return rewriter.notifyMatchFailure(opToPad,
                                         "--no padding value specified");
    }
    Attribute paddingValueAttr =
        options.paddingValues[opOperand.getOperandNumber()];

    // 2.d. Perform actual padding.
    Value paddedOperand = padOperand(
        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
        *maybePaddedShape, paddingValueAttr);
    LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");

    // 2.e. Record the new operand and, if present, its defining pad op.
    newOperands.push_back(paddedOperand);
    if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

  // 3. Form the resulting tensor::ExtractSliceOp.
  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
  // clone **should** properly notify the rewriter.
  TilingInterface paddedOp =
      clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // opToPad around because it uses the dims of the original results.
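  // E.g., a result padded from tensor<?x?xf32> to tensor<16x32xf32> is sliced
  // back to its original sizes using the reified shapes of `opToPad`'s
  // results.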
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
        rewriter, loc, paddedResult, offsets,
        reifiedResultShapes[resultNumber], strides));
  }

  rewriter.replaceOp(opToPad, paddedSubtensorResults);
  return paddedOp;
}