//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::tensor; namespace { /// Rewrite tensor.generate with arith.constant if the yielded value is a /// constant and the tensor type is static. struct GenerateToConstant : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenerateOp generateOp, PatternRewriter &rewriter) const override { auto tensorType = llvm::cast(generateOp.getResult().getType()); if (!tensorType.hasStaticShape()) return failure(); auto terminatorOp = cast(generateOp.getBody().front().getTerminator()); Attribute attr; if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr))) return failure(); Operation *constantOp = rewriter.getContext() ->getLoadedDialect() ->materializeConstant(rewriter, DenseElementsAttr::get(tensorType, attr), tensorType, generateOp->getLoc()); if (!constantOp) return failure(); rewriter.replaceOp(generateOp, constantOp->getResults()); return success(); } }; /// Transform a linear index from one indexing space to another given: /// /// - the shape of the source indexing space, /// - the strides of the target indexing space, /// - a linear index into the source indexing space. /// /// This function is logically a sequence of linearize/delinearize over /// different bases but avoids allocating intermediate SmallVectors. int64_t transformIndexSpace(ArrayRef inputShape, ArrayRef outputStrides, int64_t srcLinearIndex) { assert(inputShape.size() == outputStrides.size()); int64_t dstLinearIndex = 0; for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) { // Compute the index into the current dimension of the source tensor. // `quotient` is the remaining linear index after accounting for the // current dimension. // // `remainder` is the index into the source tensor for the current // dimension. auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]); srcLinearIndex = quotient; // Add the contribution of the current dimension to the output using the // permutation map. dstLinearIndex += outputStrides[dim] * remainder; } return dstLinearIndex; } template Value constantFoldPadOp(PatternRewriter &rewriter, Location loc, DenseElementsAttr input, AttrType padValue, ArrayRef padLow, ArrayRef padHigh) { auto inputValues = input.tryGetValues(); if (failed(inputValues)) return nullptr; auto oldShape = input.getType().getShape(); // Compute the output shape of the new value. auto newShape = llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh), [](std::tuple pack) { auto [old, low, high] = pack; return old + low + high; }); int64_t outputSize = computeProduct(newShape); // Fully initialize the vector with the padding value. // The non-padded area will then be copied. SmallVector values(outputSize, padValue.getValue()); // Strides for input and output are used to transform between the indexing // space of the input and output tensors. SmallVector outputStrides = computeStrides(newShape); // The contribution of the low padding to the offset in the output tensor. // This is the starting position of the source tensor within the padding // tensor. int64_t startingOffset = linearize(padLow, outputStrides); // Copy values from the input tensor to the corresponding sub-region // of the output tensor. for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) { auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex); values[outputIndex + startingOffset] = inputValue; } // Create an attribute for the folded value. auto newType = input.getType().clone(newShape); auto newAttr = DenseElementsAttr::get(newType, values); Operation *constantOp = rewriter.getContext() ->getLoadedDialect() ->materializeConstant(rewriter, newAttr, newType, loc); return constantOp ? constantOp->getResult(0) : nullptr; } struct PadOpToConstant final : public OpRewritePattern { PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFn{controlFn} {} LogicalResult matchAndRewrite(PadOp padTensorOp, PatternRewriter &rewriter) const override { if (padTensorOp.getNofold()) return rewriter.notifyMatchFailure( padTensorOp, "refusing to fold nofold pad operation"); TypedValue input = padTensorOp.getSource(); RankedTensorType resultType = padTensorOp.getResult().getType(); DenseElementsAttr inputAttr = nullptr; if (!matchPattern(input, m_Constant(&inputAttr))) return failure(); Value paddingValue = padTensorOp.getConstantPaddingValue(); // Extract the constant value used for padding or bail out. Attribute paddingAttr = nullptr; if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr))) return rewriter.notifyMatchFailure(padTensorOp, "unable to get constant value"); // Try to extract the constant values of the low and high padding. auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad()); auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad()); // If the padding cannot be extracted, bail out. if (!lowPad || !highPad) return rewriter.notifyMatchFailure(padTensorOp, "unable to extract constant padding"); // We have a potential candidate, consult the control function to // determine if the op should fold. if (!controlFn(&padTensorOp.getSourceMutable())) return rewriter.notifyMatchFailure(padTensorOp, "not folding due to cost function"); Location loc = padTensorOp.getLoc(); // Try constant folding the supported cases of integer and float values. Value newOp = llvm::TypeSwitch(paddingAttr) .Case([&](FloatAttr floatAttr) { return constantFoldPadOp( rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad); }) .Case([&](IntegerAttr integerAttr) { return constantFoldPadOp( rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad); }) .Default(Value()); if (!newOp) return rewriter.notifyMatchFailure(padTensorOp, "tensor type not supported"); if (newOp.getType() != resultType) newOp = tensor::CastOp::create(rewriter, loc, resultType, newOp); rewriter.replaceOp(padTensorOp, newOp); return success(); } private: ControlFoldFn controlFn; }; } // namespace void mlir::tensor::populateRewriteAsConstantPatterns( RewritePatternSet &patterns, const ControlFoldFn &controlFn) { patterns.add(patterns.getContext()); patterns.add(patterns.getContext(), controlFn); }