//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/Utils.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir::arith { #define GEN_PASS_DEF_ARITHINTRANGEOPTS #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" #define GEN_PASS_DEF_ARITHINTRANGENARROWING #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" } // namespace mlir::arith using namespace mlir; using namespace mlir::arith; using namespace mlir::dataflow; static std::optional getMaybeConstantValue(DataFlowSolver &solver, Value value) { auto *maybeInferredRange = solver.lookupState(value); if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) return std::nullopt; const ConstantIntRanges &inferredRange = maybeInferredRange->getValue().getValue(); return inferredRange.getConstantValue(); } static void copyIntegerRange(DataFlowSolver &solver, Value oldVal, Value newVal) { auto *oldState = solver.lookupState(oldVal); if (!oldState) return; (void)solver.getOrCreateState(newVal)->join( *oldState); } namespace mlir::dataflow { /// Patterned after SCCP LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, RewriterBase &rewriter, Value value) { if (value.use_empty()) return failure(); std::optional maybeConstValue = getMaybeConstantValue(solver, value); if (!maybeConstValue.has_value()) return failure(); Type type = value.getType(); Location loc = value.getLoc(); Operation *maybeDefiningOp = value.getDefiningOp(); Dialect *valueDialect = maybeDefiningOp ? maybeDefiningOp->getDialect() : value.getParentRegion()->getParentOp()->getDialect(); Attribute constAttr; if (auto shaped = dyn_cast(type)) { constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue); } else { constAttr = rewriter.getIntegerAttr(type, *maybeConstValue); } Operation *constOp = valueDialect->materializeConstant(rewriter, constAttr, type, loc); // Fall back to arith.constant if the dialect materializer doesn't know what // to do with an integer constant. if (!constOp) constOp = rewriter.getContext() ->getLoadedDialect() ->materializeConstant(rewriter, constAttr, type, loc); if (!constOp) return failure(); OpResult res = constOp->getResult(0); if (solver.lookupState(res)) solver.eraseState(res); copyIntegerRange(solver, value, res); rewriter.replaceAllUsesWith(value, res); return success(); } } // namespace mlir::dataflow namespace { class DataFlowListener : public RewriterBase::Listener { public: DataFlowListener(DataFlowSolver &s) : s(s) {} protected: void notifyOperationErased(Operation *op) override { s.eraseState(s.getProgramPointAfter(op)); for (Value res : op->getResults()) s.eraseState(res); } DataFlowSolver &s; }; /// Rewrite any results of `op` that were inferred to be constant integers to /// and replace their uses with that constant. Return success() if all results /// where thus replaced and the operation is erased. Also replace any block /// arguments with their constant values. struct MaterializeKnownConstantValues : public RewritePattern { MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s) : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context), solver(s) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (matchPattern(op, m_Constant())) return failure(); auto needsReplacing = [&](Value v) { return getMaybeConstantValue(solver, v).has_value() && !v.use_empty(); }; bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing); if (op->getNumRegions() == 0) if (!hasConstantResults) return failure(); bool hasConstantRegionArgs = false; for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { hasConstantRegionArgs |= llvm::any_of(block.getArguments(), needsReplacing); } } if (!hasConstantResults && !hasConstantRegionArgs) return failure(); bool replacedAll = (op->getNumResults() != 0); for (Value v : op->getResults()) replacedAll &= (succeeded(maybeReplaceWithConstant(solver, rewriter, v)) || v.use_empty()); if (replacedAll && isOpTriviallyDead(op)) { rewriter.eraseOp(op); return success(); } PatternRewriter::InsertionGuard guard(rewriter); for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { rewriter.setInsertionPointToStart(&block); for (BlockArgument &arg : block.getArguments()) { (void)maybeReplaceWithConstant(solver, rewriter, arg); } } } return success(); } private: DataFlowSolver &solver; }; template struct DeleteTrivialRem : public OpRewritePattern { DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s) : OpRewritePattern(context), solver(s) {} LogicalResult matchAndRewrite(RemOp op, PatternRewriter &rewriter) const override { Value lhs = op.getOperand(0); Value rhs = op.getOperand(1); auto maybeModulus = getConstantIntValue(rhs); if (!maybeModulus.has_value()) return failure(); int64_t modulus = *maybeModulus; if (modulus <= 0) return failure(); auto *maybeLhsRange = solver.lookupState(lhs); if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized()) return failure(); const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue(); const APInt &min = isa(op) ? lhsRange.umin() : lhsRange.smin(); const APInt &max = isa(op) ? lhsRange.umax() : lhsRange.smax(); // The minima and maxima here are given as closed ranges, we must be // strictly less than the modulus. if (min.isNegative() || min.uge(modulus)) return failure(); if (max.isNegative() || max.uge(modulus)) return failure(); if (!min.ule(max)) return failure(); // With all those conditions out of the way, we know thas this invocation of // a remainder is a noop because the input is strictly within the range // [0, modulus), so get rid of it. rewriter.replaceOp(op, ValueRange{lhs}); return success(); } private: DataFlowSolver &solver; }; /// Gather ranges for all the values in `values`. Appends to the existing /// vector. static LogicalResult collectRanges(DataFlowSolver &solver, ValueRange values, SmallVectorImpl &ranges) { for (Value val : values) { auto *maybeInferredRange = solver.lookupState(val); if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) return failure(); const ConstantIntRanges &inferredRange = maybeInferredRange->getValue().getValue(); ranges.push_back(inferredRange); } return success(); } /// Return int type truncated to `targetBitwidth`. If `srcType` is shaped, /// return shaped type as well. static Type getTargetType(Type srcType, unsigned targetBitwidth) { auto dstType = IntegerType::get(srcType.getContext(), targetBitwidth); if (auto shaped = dyn_cast(srcType)) return shaped.clone(dstType); assert(srcType.isIntOrIndex() && "Invalid src type"); return dstType; } namespace { // Enum for tracking which type of truncation should be performed // to narrow an operation, if any. enum class CastKind : uint8_t { None, Signed, Unsigned, Both }; } // namespace /// If the values within `range` can be represented using only `width` bits, /// return the kind of truncation needed to preserve that property. /// /// This check relies on the fact that the signed and unsigned ranges are both /// always correct, but that one might be an approximation of the other, /// so we want to use the correct truncation operation. static CastKind checkTruncatability(const ConstantIntRanges &range, unsigned targetWidth) { unsigned srcWidth = range.smin().getBitWidth(); if (srcWidth <= targetWidth) return CastKind::None; unsigned removedWidth = srcWidth - targetWidth; // The sign bits need to extend into the sign bit of the target width. For // example, if we're truncating 64 bits to 32, we need 64 - 32 + 1 = 33 sign // bits. bool canTruncateSigned = range.smin().getNumSignBits() >= (removedWidth + 1) && range.smax().getNumSignBits() >= (removedWidth + 1); bool canTruncateUnsigned = range.umin().countLeadingZeros() >= removedWidth && range.umax().countLeadingZeros() >= removedWidth; if (canTruncateSigned && canTruncateUnsigned) return CastKind::Both; if (canTruncateSigned) return CastKind::Signed; if (canTruncateUnsigned) return CastKind::Unsigned; return CastKind::None; } static CastKind mergeCastKinds(CastKind lhs, CastKind rhs) { if (lhs == CastKind::None || rhs == CastKind::None) return CastKind::None; if (lhs == CastKind::Both) return rhs; if (rhs == CastKind::Both) return lhs; if (lhs == rhs) return lhs; return CastKind::None; } static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType, CastKind castKind) { Type srcType = src.getType(); assert(isa(srcType) == isa(dstType) && "Mixing vector and non-vector types"); assert(castKind != CastKind::None && "Can't cast when casting isn't allowed"); Type srcElemType = getElementTypeOrSelf(srcType); Type dstElemType = getElementTypeOrSelf(dstType); assert(srcElemType.isIntOrIndex() && "Invalid src type"); assert(dstElemType.isIntOrIndex() && "Invalid dst type"); if (srcType == dstType) return src; if (isa(srcElemType) || isa(dstElemType)) { if (castKind == CastKind::Signed) return arith::IndexCastOp::create(builder, loc, dstType, src); return arith::IndexCastUIOp::create(builder, loc, dstType, src); } auto srcInt = cast(srcElemType); auto dstInt = cast(dstElemType); if (dstInt.getWidth() < srcInt.getWidth()) return arith::TruncIOp::create(builder, loc, dstType, src); if (castKind == CastKind::Signed) return arith::ExtSIOp::create(builder, loc, dstType, src); return arith::ExtUIOp::create(builder, loc, dstType, src); } struct NarrowElementwise final : OpTraitRewritePattern { NarrowElementwise(MLIRContext *context, DataFlowSolver &s, ArrayRef target) : OpTraitRewritePattern(context), solver(s), targetBitwidths(target) {} using OpTraitRewritePattern::OpTraitRewritePattern; LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getNumResults() == 0) return rewriter.notifyMatchFailure(op, "can't narrow resultless op"); SmallVector ranges; if (failed(collectRanges(solver, op->getOperands(), ranges))) return rewriter.notifyMatchFailure(op, "input without specified range"); if (failed(collectRanges(solver, op->getResults(), ranges))) return rewriter.notifyMatchFailure(op, "output without specified range"); Type srcType = op->getResult(0).getType(); if (!llvm::all_equal(op->getResultTypes())) return rewriter.notifyMatchFailure(op, "mismatched result types"); if (op->getNumOperands() == 0 || !llvm::all_of(op->getOperandTypes(), [=](Type t) { return t == srcType; })) return rewriter.notifyMatchFailure( op, "no operands or operand types don't match result type"); for (unsigned targetBitwidth : targetBitwidths) { CastKind castKind = CastKind::Both; for (const ConstantIntRanges &range : ranges) { castKind = mergeCastKinds(castKind, checkTruncatability(range, targetBitwidth)); if (castKind == CastKind::None) break; } if (castKind == CastKind::None) continue; Type targetType = getTargetType(srcType, targetBitwidth); if (targetType == srcType) continue; Location loc = op->getLoc(); IRMapping mapping; for (auto [arg, argRange] : llvm::zip_first(op->getOperands(), ranges)) { CastKind argCastKind = castKind; // When dealing with `index` values, preserve non-negativity in the // index_casts since we can't recover this in unsigned when equivalent. if (argCastKind == CastKind::Signed && argRange.smin().isNonNegative()) argCastKind = CastKind::Both; Value newArg = doCast(rewriter, loc, arg, targetType, argCastKind); mapping.map(arg, newArg); } Operation *newOp = rewriter.clone(*op, mapping); rewriter.modifyOpInPlace(newOp, [&]() { for (OpResult res : newOp->getResults()) { res.setType(targetType); } }); SmallVector newResults; for (auto [newRes, oldRes] : llvm::zip_equal(newOp->getResults(), op->getResults())) { Value castBack = doCast(rewriter, loc, newRes, srcType, castKind); copyIntegerRange(solver, oldRes, castBack); newResults.push_back(castBack); } rewriter.replaceOp(op, newResults); return success(); } return failure(); } private: DataFlowSolver &solver; SmallVector targetBitwidths; }; struct NarrowCmpI final : OpRewritePattern { NarrowCmpI(MLIRContext *context, DataFlowSolver &s, ArrayRef target) : OpRewritePattern(context), solver(s), targetBitwidths(target) {} LogicalResult matchAndRewrite(arith::CmpIOp op, PatternRewriter &rewriter) const override { Value lhs = op.getLhs(); Value rhs = op.getRhs(); SmallVector ranges; if (failed(collectRanges(solver, op.getOperands(), ranges))) return failure(); const ConstantIntRanges &lhsRange = ranges[0]; const ConstantIntRanges &rhsRange = ranges[1]; Type srcType = lhs.getType(); for (unsigned targetBitwidth : targetBitwidths) { CastKind lhsCastKind = checkTruncatability(lhsRange, targetBitwidth); CastKind rhsCastKind = checkTruncatability(rhsRange, targetBitwidth); CastKind castKind = mergeCastKinds(lhsCastKind, rhsCastKind); // Note: this includes target width > src width. if (castKind == CastKind::None) continue; Type targetType = getTargetType(srcType, targetBitwidth); if (targetType == srcType) continue; Location loc = op->getLoc(); IRMapping mapping; Value lhsCast = doCast(rewriter, loc, lhs, targetType, lhsCastKind); Value rhsCast = doCast(rewriter, loc, rhs, targetType, rhsCastKind); mapping.map(lhs, lhsCast); mapping.map(rhs, rhsCast); Operation *newOp = rewriter.clone(*op, mapping); copyIntegerRange(solver, op.getResult(), newOp->getResult(0)); rewriter.replaceOp(op, newOp->getResults()); return success(); } return failure(); } private: DataFlowSolver &solver; SmallVector targetBitwidths; }; /// Fold index_cast(index_cast(%arg: i8, index), i8) -> %arg /// This pattern assumes all passed `targetBitwidths` are not wider than index /// type. template struct FoldIndexCastChain final : OpRewritePattern { FoldIndexCastChain(MLIRContext *context, ArrayRef target) : OpRewritePattern(context), targetBitwidths(target) {} LogicalResult matchAndRewrite(CastOp op, PatternRewriter &rewriter) const override { auto srcOp = op.getIn().template getDefiningOp(); if (!srcOp) return rewriter.notifyMatchFailure(op, "doesn't come from an index cast"); Value src = srcOp.getIn(); if (src.getType() != op.getType()) return rewriter.notifyMatchFailure(op, "outer types don't match"); if (!srcOp.getType().isIndex()) return rewriter.notifyMatchFailure(op, "intermediate type isn't index"); auto intType = dyn_cast(op.getType()); if (!intType || !llvm::is_contained(targetBitwidths, intType.getWidth())) return failure(); rewriter.replaceOp(op, src); return success(); } private: SmallVector targetBitwidths; }; struct NarrowLoopBounds final : OpInterfaceRewritePattern { NarrowLoopBounds(MLIRContext *context, DataFlowSolver &s, ArrayRef target) : OpInterfaceRewritePattern(context), solver(s), targetBitwidths(target), boundsNarrowingFailedAttr( StringAttr::get(context, "arith.bounds_narrowing_failed")) {} LogicalResult matchAndRewrite(LoopLikeOpInterface loopLike, PatternRewriter &rewriter) const override { // Skip ops where bounds narrowing previously failed. if (loopLike->hasAttr(boundsNarrowingFailedAttr)) return rewriter.notifyMatchFailure(loopLike, "bounds narrowing previously failed"); std::optional> inductionVars = loopLike.getLoopInductionVars(); if (!inductionVars.has_value() || inductionVars->empty()) return rewriter.notifyMatchFailure(loopLike, "no induction variables"); std::optional> lowerBounds = loopLike.getLoopLowerBounds(); std::optional> upperBounds = loopLike.getLoopUpperBounds(); std::optional> steps = loopLike.getLoopSteps(); if (!lowerBounds.has_value() || !upperBounds.has_value() || !steps.has_value()) return rewriter.notifyMatchFailure(loopLike, "no loop bounds or steps"); if (lowerBounds->size() != inductionVars->size() || upperBounds->size() != inductionVars->size() || steps->size() != inductionVars->size()) return rewriter.notifyMatchFailure(loopLike, "mismatched bounds/steps count"); Location loc = loopLike->getLoc(); SmallVector newLowerBounds(*lowerBounds); SmallVector newUpperBounds(*upperBounds); SmallVector newSteps(*steps); SmallVector> narrowings; // Check each (indVar, lb, ub, step) tuple. for (auto [idx, indVar, lbOFR, ubOFR, stepOFR] : llvm::enumerate(*inductionVars, *lowerBounds, *upperBounds, *steps)) { // Only process value operands, skip attributes. auto maybeLb = dyn_cast(lbOFR); auto maybeUb = dyn_cast(ubOFR); auto maybeStep = dyn_cast(stepOFR); if (!maybeLb || !maybeUb || !maybeStep) continue; // Collect ranges for (lb, ub, step, indVar). SmallVector ranges; if (failed(collectRanges( solver, ValueRange{maybeLb, maybeUb, maybeStep, indVar}, ranges))) continue; const ConstantIntRanges &stepRange = ranges[2]; const ConstantIntRanges &indVarRange = ranges[3]; Type srcType = maybeLb.getType(); // Try each target bitwidth. for (unsigned targetBitwidth : targetBitwidths) { Type targetType = getTargetType(srcType, targetBitwidth); if (targetType == srcType) continue; // Check if the target type is valid for this loop's induction // variables. if (!loopLike.isValidInductionVarType(targetType)) continue; // Check if all values in this tuple can be truncated. CastKind castKind = CastKind::Both; for (const ConstantIntRanges &range : ranges) { castKind = mergeCastKinds(castKind, checkTruncatability(range, targetBitwidth)); if (castKind == CastKind::None) break; } if (castKind == CastKind::None) continue; // Check if indVar + step fits in the narrowed type. // This is critical for loop correctness: the loop computes // iv_next = iv_current + step in the narrowed type, then compares // iv_next < ub. If iv_current + step overflows, the comparison may // produce incorrect results and break loop termination. // Both signed and unsigned interpretations must fit because loop // semantics are unknown (integer types are signless). ConstantIntRanges indVarPlusStepRange( indVarRange.smin().sadd_sat(stepRange.smin()), indVarRange.smax().sadd_sat(stepRange.smax()), indVarRange.umin().uadd_sat(stepRange.umin()), indVarRange.umax().uadd_sat(stepRange.umax())); if (checkTruncatability(indVarPlusStepRange, targetBitwidth) != CastKind::Both) continue; // Narrow the bounds and step values. Value newLb = doCast(rewriter, loc, maybeLb, targetType, castKind); Value newUb = doCast(rewriter, loc, maybeUb, targetType, castKind); Value newStep = doCast(rewriter, loc, maybeStep, targetType, castKind); newLowerBounds[idx] = newLb; newUpperBounds[idx] = newUb; newSteps[idx] = newStep; narrowings.push_back({idx, targetType, castKind}); break; } } if (narrowings.empty()) return rewriter.notifyMatchFailure(loopLike, "no narrowings found"); // Save original types before modifying. SmallVector origTypes; for (auto [idx, targetType, castKind] : narrowings) { Value indVar = (*inductionVars)[idx]; origTypes.push_back(indVar.getType()); } // Attempt to update bounds and induction variable types. // If this fails, mark the op so we don't try again. bool updateFailed = false; rewriter.modifyOpInPlace(loopLike, [&]() { // Update the loop bounds and steps. if (failed(loopLike.setLoopLowerBounds(newLowerBounds)) || failed(loopLike.setLoopUpperBounds(newUpperBounds)) || failed(loopLike.setLoopSteps(newSteps))) { // Mark op to prevent future attempts. IR was modified (attribute // added), so we must return success() from the pattern. loopLike->setAttr(boundsNarrowingFailedAttr, rewriter.getUnitAttr()); updateFailed = true; return; } // Update induction variable types. for (auto [idx, targetType, castKind] : narrowings) { Value indVar = (*inductionVars)[idx]; auto blockArg = cast(indVar); // Change the block argument type. blockArg.setType(targetType); } }); if (updateFailed) return success(); // Insert casts back to original type for uses. for (auto [narrowingIdx, narrowingInfo] : llvm::enumerate(narrowings)) { auto [idx, targetType, castKind] = narrowingInfo; Value indVar = (*inductionVars)[idx]; auto blockArg = cast(indVar); Type origType = origTypes[narrowingIdx]; OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(blockArg.getOwner()); Value casted = doCast(rewriter, loc, blockArg, origType, castKind); copyIntegerRange(solver, blockArg, casted); // Replace all uses of the narrowed indVar with the casted value. rewriter.replaceAllUsesExcept(blockArg, casted, casted.getDefiningOp()); } return success(); } private: DataFlowSolver &solver; SmallVector targetBitwidths; StringAttr boundsNarrowingFailedAttr; }; struct IntRangeOptimizationsPass final : arith::impl::ArithIntRangeOptsBase { void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); DataFlowSolver solver; loadBaselineAnalyses(solver); solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); DataFlowListener listener(solver); RewritePatternSet patterns(ctx); populateIntRangeOptimizationsPatterns(patterns, solver); if (failed(applyPatternsGreedily( op, std::move(patterns), GreedyRewriteConfig().setListener(&listener)))) signalPassFailure(); } }; struct IntRangeNarrowingPass final : arith::impl::ArithIntRangeNarrowingBase { using ArithIntRangeNarrowingBase::ArithIntRangeNarrowingBase; void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); DataFlowSolver solver; loadBaselineAnalyses(solver); solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); DataFlowListener listener(solver); RewritePatternSet patterns(ctx); populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported); populateControlFlowValuesNarrowingPatterns(patterns, solver, bitwidthsSupported); // We specifically need bottom-up traversal as cmpi pattern needs range // data, attached to its original argument values. if (failed(applyPatternsGreedily( op, std::move(patterns), GreedyRewriteConfig().setUseTopDownTraversal(false).setListener( &listener)))) signalPassFailure(); } }; } // namespace void mlir::arith::populateIntRangeOptimizationsPatterns( RewritePatternSet &patterns, DataFlowSolver &solver) { patterns.add, DeleteTrivialRem>(patterns.getContext(), solver); } void mlir::arith::populateIntRangeNarrowingPatterns( RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef bitwidthsSupported) { patterns.add(patterns.getContext(), solver, bitwidthsSupported); patterns.add, FoldIndexCastChain>(patterns.getContext(), bitwidthsSupported); } void mlir::arith::populateControlFlowValuesNarrowingPatterns( RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef bitwidthsSupported) { patterns.add(patterns.getContext(), solver, bitwidthsSupported); } std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() { return std::make_unique(); }