Diffstat (limited to 'llvm/lib/Transforms/Scalar')
28 files changed, 1586 insertions, 619 deletions
diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp index e84ca81..2f1f59c 100644 --- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp @@ -625,8 +625,12 @@ private: NewPath.setDeterminator(PhiBB); NewPath.setExitValue(C); // Don't add SwitchBlock at the start, this is handled later. - if (IncomingBB != SwitchBlock) + if (IncomingBB != SwitchBlock) { + // Don't add a cycle to the path. + if (VB.contains(IncomingBB)) + continue; NewPath.push_back(IncomingBB); + } NewPath.push_back(PhiBB); Res.push_back(NewPath); continue; @@ -815,7 +819,12 @@ private: std::vector<ThreadingPath> TempList; for (const ThreadingPath &Path : PathsToPhiDef) { + SmallPtrSet<BasicBlock *, 32> PathSet(Path.getPath().begin(), + Path.getPath().end()); for (const PathType &PathToSw : PathsToSwitchBB) { + if (any_of(llvm::drop_begin(PathToSw), + [&](const BasicBlock *BB) { return PathSet.contains(BB); })) + continue; ThreadingPath PathCopy(Path); PathCopy.appendExcludingFirst(PathToSw); TempList.push_back(PathCopy); diff --git a/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp b/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp index 89980d5..4a7144f 100644 --- a/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp +++ b/llvm/lib/Transforms/Scalar/DropUnnecessaryAssumes.cpp @@ -78,11 +78,16 @@ DropUnnecessaryAssumesPass::run(Function &F, FunctionAnalysisManager &FAM) { SmallVector<OperandBundleDef> KeptBundles; unsigned NumBundles = Assume->getNumOperandBundles(); for (unsigned I = 0; I != NumBundles; ++I) { - auto IsDead = [](OperandBundleUse Bundle) { + auto IsDead = [&](OperandBundleUse Bundle) { // "ignore" operand bundles are always dead. if (Bundle.getTagName() == "ignore") return true; + // "dereferenceable" operand bundles are only dropped if requested + // (e.g., after loop vectorization has run). + if (Bundle.getTagName() == "dereferenceable") + return DropDereferenceable; + // Bundles without arguments do not affect any specific values. // Always keep them for now. if (Bundle.Inputs.empty()) @@ -122,7 +127,8 @@ DropUnnecessaryAssumesPass::run(Function &F, FunctionAnalysisManager &FAM) { Value *Cond = Assume->getArgOperand(0); // Don't drop type tests, which have special semantics. - if (match(Cond, m_Intrinsic<Intrinsic::type_test>())) + if (match(Cond, m_Intrinsic<Intrinsic::type_test>()) || + match(Cond, m_Intrinsic<Intrinsic::public_type_test>())) continue; SmallVector<Value *> Affected; diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp index 14686ce..37822cf 100644 --- a/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -237,10 +237,14 @@ std::optional<ConstantRange> Float2IntPass::calcRange(Instruction *I) { // OK, it's representable. Now get it. APSInt Int(MaxIntegerBW+1, false); bool Exact; - CF->getValueAPF().convertToInteger(Int, - APFloat::rmNearestTiesToEven, - &Exact); - OpRanges.push_back(ConstantRange(Int)); + APFloat::opStatus Status = CF->getValueAPF().convertToInteger( + Int, APFloat::rmNearestTiesToEven, &Exact); + // Although the round above is loseless, we still need to check if the + // floating-point value can be represented in the integer type. 
+ if (Status == APFloat::opOK || Status == APFloat::opInexact) + OpRanges.push_back(ConstantRange(Int)); + else + return badRange(); } else { llvm_unreachable("Should have already marked this as badRange!"); } diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp index a06f832..4dddb01 100644 --- a/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -142,7 +142,7 @@ public: for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)}); - auto ComesBefore = [BlockOrder](OpsType O1, OpsType O2) { + auto ComesBefore = [&](OpsType O1, OpsType O2) { return BlockOrder.lookup(O1.first) < BlockOrder.lookup(O2.first); }; // Sort in a deterministic order. @@ -167,8 +167,8 @@ public: verifyModelledPHI(const DenseMap<const BasicBlock *, unsigned> &BlockOrder) { assert(Values.size() > 1 && Blocks.size() > 1 && "Modelling PHI with less than 2 values"); - auto ComesBefore = [BlockOrder](const BasicBlock *BB1, - const BasicBlock *BB2) { + [[maybe_unused]] auto ComesBefore = [&](const BasicBlock *BB1, + const BasicBlock *BB2) { return BlockOrder.lookup(BB1) < BlockOrder.lookup(BB2); }; assert(llvm::is_sorted(Blocks, ComesBefore)); @@ -514,7 +514,7 @@ public: class GVNSink { public: - GVNSink() {} + GVNSink() = default; bool run(Function &F) { LLVM_DEBUG(dbgs() << "GVNSink: running on function @" << F.getName() diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 4ba4ba3..400cb1e 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -162,6 +162,8 @@ class IndVarSimplify { const SCEV *ExitCount, PHINode *IndVar, SCEVExpander &Rewriter); + bool sinkUnusedInvariants(Loop *L); + public: IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, const DataLayout &DL, TargetLibraryInfo *TLI, @@ -196,181 +198,267 @@ static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) { return true; } -/// If the loop has floating induction variable then insert corresponding -/// integer induction variable if possible. -/// For example, -/// for(double i = 0; i < 10000; ++i) -/// bar(i) -/// is converted into -/// for(int i = 0; i < 10000; ++i) -/// bar((double)i); -bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { +/// Ensure we stay within the bounds of fp values that can be represented as +/// integers without gaps, which are 2^24 and 2^53 for IEEE-754 single and +/// double precision respectively (both on negative and positive side). +static bool isRepresentableAsExactInteger(const APFloat &FPVal, + int64_t IntVal) { + const auto &FltSema = FPVal.getSemantics(); + if (!APFloat::isIEEELikeFP(FltSema)) + return false; + return isUIntN(APFloat::semanticsPrecision(FltSema), AbsoluteValue(IntVal)); +} + +/// Represents a floating-point induction variable pattern that may be +/// convertible to integer form. +struct FloatingPointIV { + APFloat InitValue; + APFloat IncrValue; + APFloat ExitValue; + FCmpInst *Compare; + BinaryOperator *Add; + + FloatingPointIV(APFloat Init, APFloat Incr, APFloat Exit, FCmpInst *Compare, + BinaryOperator *Add) + : InitValue(std::move(Init)), IncrValue(std::move(Incr)), + ExitValue(std::move(Exit)), Compare(Compare), Add(Add) {} +}; + +/// Represents the integer values for a converted IV. 
+struct IntegerIV { + int64_t InitValue; + int64_t IncrValue; + int64_t ExitValue; + CmpInst::Predicate NewPred; +}; + +static CmpInst::Predicate getIntegerPredicate(CmpInst::Predicate FPPred) { + switch (FPPred) { + case CmpInst::FCMP_OEQ: + case CmpInst::FCMP_UEQ: + return CmpInst::ICMP_EQ; + case CmpInst::FCMP_ONE: + case CmpInst::FCMP_UNE: + return CmpInst::ICMP_NE; + case CmpInst::FCMP_OGT: + case CmpInst::FCMP_UGT: + return CmpInst::ICMP_SGT; + case CmpInst::FCMP_OGE: + case CmpInst::FCMP_UGE: + return CmpInst::ICMP_SGE; + case CmpInst::FCMP_OLT: + case CmpInst::FCMP_ULT: + return CmpInst::ICMP_SLT; + case CmpInst::FCMP_OLE: + case CmpInst::FCMP_ULE: + return CmpInst::ICMP_SLE; + default: + return CmpInst::BAD_ICMP_PREDICATE; + } +} + +/// Analyze a PN to determine whether it represents a simple floating-point +/// induction variable, with constant fp init, increment, and exit values. +/// +/// Returns a FloatingPointIV struct if matched, std::nullopt otherwise. +static std::optional<FloatingPointIV> +maybeFloatingPointRecurrence(Loop *L, PHINode *PN) { + // Identify incoming and backedge for the PN. unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0)); - unsigned BackEdge = IncomingEdge^1; + unsigned BackEdge = IncomingEdge ^ 1; // Check incoming value. auto *InitValueVal = dyn_cast<ConstantFP>(PN->getIncomingValue(IncomingEdge)); - - int64_t InitValue; - if (!InitValueVal || !ConvertToSInt(InitValueVal->getValueAPF(), InitValue)) - return false; + if (!InitValueVal) + return std::nullopt; // Check IV increment. Reject this PN if increment operation is not // an add or increment value can not be represented by an integer. auto *Incr = dyn_cast<BinaryOperator>(PN->getIncomingValue(BackEdge)); - if (Incr == nullptr || Incr->getOpcode() != Instruction::FAdd) return false; + if (!Incr || Incr->getOpcode() != Instruction::FAdd) + return std::nullopt; // If this is not an add of the PHI with a constantfp, or if the constant fp // is not an integer, bail out. - ConstantFP *IncValueVal = dyn_cast<ConstantFP>(Incr->getOperand(1)); - int64_t IncValue; - if (IncValueVal == nullptr || Incr->getOperand(0) != PN || - !ConvertToSInt(IncValueVal->getValueAPF(), IncValue)) - return false; + auto *IncValueVal = dyn_cast<ConstantFP>(Incr->getOperand(1)); + if (!IncValueVal || Incr->getOperand(0) != PN) + return std::nullopt; // Check Incr uses. One user is PN and the other user is an exit condition // used by the conditional terminator. - Value::user_iterator IncrUse = Incr->user_begin(); - Instruction *U1 = cast<Instruction>(*IncrUse++); - if (IncrUse == Incr->user_end()) return false; - Instruction *U2 = cast<Instruction>(*IncrUse++); - if (IncrUse != Incr->user_end()) return false; + // TODO: Should relax this, so as to allow any `fpext` that may occur. + if (!Incr->hasNUses(2)) + return std::nullopt; // Find exit condition, which is an fcmp. If it doesn't exist, or if it isn't // only used by a branch, we can't transform it. 
- FCmpInst *Compare = dyn_cast<FCmpInst>(U1); - if (!Compare) - Compare = dyn_cast<FCmpInst>(U2); - if (!Compare || !Compare->hasOneUse() || - !isa<BranchInst>(Compare->user_back())) - return false; + auto It = llvm::find_if(Incr->users(), + [](const User *U) { return isa<FCmpInst>(U); }); + if (It == Incr->users().end()) + return std::nullopt; - BranchInst *TheBr = cast<BranchInst>(Compare->user_back()); + FCmpInst *Compare = cast<FCmpInst>(*It); + if (!Compare->hasOneUse()) + return std::nullopt; // We need to verify that the branch actually controls the iteration count // of the loop. If not, the new IV can overflow and no one will notice. // The branch block must be in the loop and one of the successors must be out // of the loop. - assert(TheBr->isConditional() && "Can't use fcmp if not conditional"); - if (!L->contains(TheBr->getParent()) || - (L->contains(TheBr->getSuccessor(0)) && - L->contains(TheBr->getSuccessor(1)))) - return false; + auto *BI = dyn_cast<BranchInst>(Compare->user_back()); + if (!BI) + return std::nullopt; + + assert(BI->isConditional() && "Can't use fcmp if not conditional"); + if (!L->contains(BI->getParent()) || + (L->contains(BI->getSuccessor(0)) && L->contains(BI->getSuccessor(1)))) + return std::nullopt; // If it isn't a comparison with an integer-as-fp (the exit value), we can't // transform it. - ConstantFP *ExitValueVal = dyn_cast<ConstantFP>(Compare->getOperand(1)); - int64_t ExitValue; - if (ExitValueVal == nullptr || - !ConvertToSInt(ExitValueVal->getValueAPF(), ExitValue)) - return false; + auto *ExitValueVal = dyn_cast<ConstantFP>(Compare->getOperand(1)); + if (!ExitValueVal) + return std::nullopt; - // Find new predicate for integer comparison. - CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE; - switch (Compare->getPredicate()) { - default: return false; // Unknown comparison. - case CmpInst::FCMP_OEQ: - case CmpInst::FCMP_UEQ: NewPred = CmpInst::ICMP_EQ; break; - case CmpInst::FCMP_ONE: - case CmpInst::FCMP_UNE: NewPred = CmpInst::ICMP_NE; break; - case CmpInst::FCMP_OGT: - case CmpInst::FCMP_UGT: NewPred = CmpInst::ICMP_SGT; break; - case CmpInst::FCMP_OGE: - case CmpInst::FCMP_UGE: NewPred = CmpInst::ICMP_SGE; break; - case CmpInst::FCMP_OLT: - case CmpInst::FCMP_ULT: NewPred = CmpInst::ICMP_SLT; break; - case CmpInst::FCMP_OLE: - case CmpInst::FCMP_ULE: NewPred = CmpInst::ICMP_SLE; break; - } + return FloatingPointIV(InitValueVal->getValueAPF(), + IncValueVal->getValueAPF(), + ExitValueVal->getValueAPF(), Compare, Incr); +} + +/// Ensure that the floating-point IV can be converted to a semantics-preserving +/// signed 32-bit integer IV. +/// +/// Returns a IntegerIV struct if possible, std::nullopt otherwise. +static std::optional<IntegerIV> +tryConvertToIntegerIV(const FloatingPointIV &FPIV) { + // Convert floating-point predicate to integer. + auto NewPred = getIntegerPredicate(FPIV.Compare->getPredicate()); + if (NewPred == CmpInst::BAD_ICMP_PREDICATE) + return std::nullopt; + + // Convert APFloat values to signed integers. + int64_t InitValue, IncrValue, ExitValue; + if (!ConvertToSInt(FPIV.InitValue, InitValue) || + !ConvertToSInt(FPIV.IncrValue, IncrValue) || + !ConvertToSInt(FPIV.ExitValue, ExitValue)) + return std::nullopt; + + // Bail out if integers cannot be represented exactly. + if (!isRepresentableAsExactInteger(FPIV.InitValue, InitValue) || + !isRepresentableAsExactInteger(FPIV.ExitValue, ExitValue)) + return std::nullopt; // We convert the floating point induction variable to a signed i32 value if - // we can. 
This is only safe if the comparison will not overflow in a way - // that won't be trapped by the integer equivalent operations. Check for this - // now. + // we can. This is only safe if the comparison will not overflow in a way that + // won't be trapped by the integer equivalent operations. Check for this now. // TODO: We could use i64 if it is native and the range requires it. // The start/stride/exit values must all fit in signed i32. - if (!isInt<32>(InitValue) || !isInt<32>(IncValue) || !isInt<32>(ExitValue)) - return false; + if (!isInt<32>(InitValue) || !isInt<32>(IncrValue) || !isInt<32>(ExitValue)) + return std::nullopt; // If not actually striding (add x, 0.0), avoid touching the code. - if (IncValue == 0) - return false; + if (IncrValue == 0) + return std::nullopt; // Positive and negative strides have different safety conditions. - if (IncValue > 0) { + if (IncrValue > 0) { // If we have a positive stride, we require the init to be less than the // exit value. if (InitValue >= ExitValue) - return false; + return std::nullopt; - uint32_t Range = uint32_t(ExitValue-InitValue); + uint32_t Range = uint32_t(ExitValue - InitValue); // Check for infinite loop, either: // while (i <= Exit) or until (i > Exit) if (NewPred == CmpInst::ICMP_SLE || NewPred == CmpInst::ICMP_SGT) { - if (++Range == 0) return false; // Range overflows. + if (++Range == 0) + return std::nullopt; // Range overflows. } - unsigned Leftover = Range % uint32_t(IncValue); + unsigned Leftover = Range % uint32_t(IncrValue); // If this is an equality comparison, we require that the strided value // exactly land on the exit value, otherwise the IV condition will wrap // around and do things the fp IV wouldn't. if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && Leftover != 0) - return false; + return std::nullopt; // If the stride would wrap around the i32 before exiting, we can't // transform the IV. - if (Leftover != 0 && int32_t(ExitValue+IncValue) < ExitValue) - return false; + if (Leftover != 0 && int32_t(ExitValue + IncrValue) < ExitValue) + return std::nullopt; } else { // If we have a negative stride, we require the init to be greater than the // exit value. if (InitValue <= ExitValue) - return false; + return std::nullopt; - uint32_t Range = uint32_t(InitValue-ExitValue); + uint32_t Range = uint32_t(InitValue - ExitValue); // Check for infinite loop, either: // while (i >= Exit) or until (i < Exit) if (NewPred == CmpInst::ICMP_SGE || NewPred == CmpInst::ICMP_SLT) { - if (++Range == 0) return false; // Range overflows. + if (++Range == 0) + return std::nullopt; // Range overflows. } - unsigned Leftover = Range % uint32_t(-IncValue); + unsigned Leftover = Range % uint32_t(-IncrValue); // If this is an equality comparison, we require that the strided value // exactly land on the exit value, otherwise the IV condition will wrap // around and do things the fp IV wouldn't. if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && Leftover != 0) - return false; + return std::nullopt; // If the stride would wrap around the i32 before exiting, we can't // transform the IV. - if (Leftover != 0 && int32_t(ExitValue+IncValue) > ExitValue) - return false; + if (Leftover != 0 && int32_t(ExitValue + IncrValue) > ExitValue) + return std::nullopt; } + return IntegerIV{InitValue, IncrValue, ExitValue, NewPred}; +} + +/// Rewrite the floating-point IV as an integer IV. 
+static void canonicalizeToIntegerIV(Loop *L, PHINode *PN, + const FloatingPointIV &FPIV, + const IntegerIV &IIV, + const TargetLibraryInfo *TLI, + std::unique_ptr<MemorySSAUpdater> &MSSAU) { + unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0)); + unsigned BackEdge = IncomingEdge ^ 1; + IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext()); + auto *Incr = cast<BinaryOperator>(PN->getIncomingValue(BackEdge)); + auto *BI = cast<BranchInst>(FPIV.Compare->user_back()); + + LLVM_DEBUG(dbgs() << "INDVARS: Rewriting floating-point IV to integer IV:\n" + << " Init: " << IIV.InitValue << "\n" + << " Incr: " << IIV.IncrValue << "\n" + << " Exit: " << IIV.ExitValue << "\n" + << " Pred: " << CmpInst::getPredicateName(IIV.NewPred) + << "\n" + << " Original PN: " << *PN << "\n"); // Insert new integer induction variable. PHINode *NewPHI = PHINode::Create(Int32Ty, 2, PN->getName() + ".int", PN->getIterator()); - NewPHI->addIncoming(ConstantInt::getSigned(Int32Ty, InitValue), + NewPHI->addIncoming(ConstantInt::getSigned(Int32Ty, IIV.InitValue), PN->getIncomingBlock(IncomingEdge)); NewPHI->setDebugLoc(PN->getDebugLoc()); Instruction *NewAdd = BinaryOperator::CreateAdd( - NewPHI, ConstantInt::getSigned(Int32Ty, IncValue), + NewPHI, ConstantInt::getSigned(Int32Ty, IIV.IncrValue), Incr->getName() + ".int", Incr->getIterator()); NewAdd->setDebugLoc(Incr->getDebugLoc()); NewPHI->addIncoming(NewAdd, PN->getIncomingBlock(BackEdge)); ICmpInst *NewCompare = new ICmpInst( - TheBr->getIterator(), NewPred, NewAdd, - ConstantInt::getSigned(Int32Ty, ExitValue), Compare->getName()); - NewCompare->setDebugLoc(Compare->getDebugLoc()); + BI->getIterator(), IIV.NewPred, NewAdd, + ConstantInt::getSigned(Int32Ty, IIV.ExitValue), FPIV.Compare->getName()); + NewCompare->setDebugLoc(FPIV.Compare->getDebugLoc()); // In the following deletions, PN may become dead and may be deleted. // Use a WeakTrackingVH to observe whether this happens. @@ -378,9 +466,9 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { // Delete the old floating point exit comparison. The branch starts using the // new comparison. - NewCompare->takeName(Compare); - Compare->replaceAllUsesWith(NewCompare); - RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI, MSSAU.get()); + NewCompare->takeName(FPIV.Compare); + FPIV.Compare->replaceAllUsesWith(NewCompare); + RecursivelyDeleteTriviallyDeadInstructions(FPIV.Compare, TLI, MSSAU.get()); // Delete the old floating point increment. Incr->replaceAllUsesWith(PoisonValue::get(Incr->getType())); @@ -400,6 +488,28 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { PN->replaceAllUsesWith(Conv); RecursivelyDeleteTriviallyDeadInstructions(PN, TLI, MSSAU.get()); } +} + +/// If the loop has a floating induction variable, then insert corresponding +/// integer induction variable if possible. For example, the following: +/// for(double i = 0; i < 10000; ++i) +/// bar(i) +/// is converted into +/// for(int i = 0; i < 10000; ++i) +/// bar((double)i); +bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { + // See if the PN matches a floating-point IV pattern. + auto FPIV = maybeFloatingPointRecurrence(L, PN); + if (!FPIV) + return false; + + // Can we safely convert the floating-point values to integer ones? + auto IIV = tryConvertToIntegerIV(*FPIV); + if (!IIV) + return false; + + // Perform the rewriting. 
+ canonicalizeToIntegerIV(L, PN, *FPIV, *IIV, TLI, MSSAU); return true; } @@ -1077,6 +1187,85 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, return true; } +//===----------------------------------------------------------------------===// +// sinkUnusedInvariants. A late subpass to cleanup loop preheaders. +//===----------------------------------------------------------------------===// + +/// If there's a single exit block, sink any loop-invariant values that +/// were defined in the preheader but not used inside the loop into the +/// exit block to reduce register pressure in the loop. +bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { + BasicBlock *ExitBlock = L->getExitBlock(); + if (!ExitBlock) return false; + + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) return false; + + bool MadeAnyChanges = false; + for (Instruction &I : llvm::make_early_inc_range(llvm::reverse(*Preheader))) { + + // Skip BB Terminator. + if (Preheader->getTerminator() == &I) + continue; + + // New instructions were inserted at the end of the preheader. + if (isa<PHINode>(I)) + break; + + // Don't move instructions which might have side effects, since the side + // effects need to complete before instructions inside the loop. Also don't + // move instructions which might read memory, since the loop may modify + // memory. Note that it's okay if the instruction might have undefined + // behavior: LoopSimplify guarantees that the preheader dominates the exit + // block. + if (I.mayHaveSideEffects() || I.mayReadFromMemory()) + continue; + + // Skip debug or pseudo instructions. + if (I.isDebugOrPseudoInst()) + continue; + + // Skip eh pad instructions. + if (I.isEHPad()) + continue; + + // Don't sink alloca: we never want to sink static alloca's out of the + // entry block, and correctly sinking dynamic alloca's requires + // checks for stacksave/stackrestore intrinsics. + // FIXME: Refactor this check somehow? + if (isa<AllocaInst>(&I)) + continue; + + // Determine if there is a use in or before the loop (direct or + // otherwise). + bool UsedInLoop = false; + for (Use &U : I.uses()) { + Instruction *User = cast<Instruction>(U.getUser()); + BasicBlock *UseBB = User->getParent(); + if (PHINode *P = dyn_cast<PHINode>(User)) { + unsigned i = + PHINode::getIncomingValueNumForOperand(U.getOperandNo()); + UseBB = P->getIncomingBlock(i); + } + if (UseBB == Preheader || L->contains(UseBB)) { + UsedInLoop = true; + break; + } + } + + // If there is, the def must remain in the preheader. + if (UsedInLoop) + continue; + + // Otherwise, sink it to the exit block. + I.moveBefore(ExitBlock->getFirstInsertionPt()); + SE->forgetValue(&I); + MadeAnyChanges = true; + } + + return MadeAnyChanges; +} + static void replaceExitCond(BranchInst *BI, Value *NewCond, SmallVectorImpl<WeakTrackingVH> &DeadInsts) { auto *OldCond = BI->getCondition(); @@ -1760,7 +1949,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // is that enough for *all* side effects? bool HasThreadLocalSideEffects = false; for (BasicBlock *BB : L->blocks()) - for (auto &I : *BB) + for (auto &I : *BB) { // TODO:isGuaranteedToTransfer if (I.mayHaveSideEffects()) { if (!LoopPredicationTraps) @@ -1778,6 +1967,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { } } + // Skip if the loop has tokens referenced outside the loop to avoid + // changing convergence behavior. 
+ if (I.getType()->isTokenTy()) { + for (User *U : I.users()) { + Instruction *UserInst = dyn_cast<Instruction>(U); + if (UserInst && !L->contains(UserInst)) { + return false; + } + } + } + } + bool Changed = false; // Finally, do the actual predication for all predicatable blocks. A couple // of notes here: @@ -1984,6 +2185,10 @@ bool IndVarSimplify::run(Loop *L) { // The Rewriter may not be used from this point on. + // Loop-invariant instructions in the preheader that aren't used in the + // loop may be sunk below the loop to reduce register pressure. + Changed |= sinkUnusedInvariants(L); + // rewriteFirstIterationLoopExitValues does not rely on the computation of // trip count and therefore can further simplify exit values in addition to // rewriteLoopExitValues. diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index 3ad8754..352a1b3 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -617,6 +617,41 @@ InferAddressSpacesImpl::collectFlatAddressExpressions(Function &F) const { return Postorder; } +// Inserts an addrspacecast for a phi node operand, handling the proper +// insertion position based on the operand type. +static Value *phiNodeOperandWithNewAddressSpace(AddrSpaceCastInst *NewI, + Value *Operand) { + auto InsertBefore = [NewI](auto It) { + NewI->insertBefore(It); + NewI->setDebugLoc(It->getDebugLoc()); + return NewI; + }; + + if (auto *Arg = dyn_cast<Argument>(Operand)) { + // For arguments, insert the cast at the beginning of entry block. + // Consider inserting at the dominating block for better placement. + Function *F = Arg->getParent(); + auto InsertI = F->getEntryBlock().getFirstNonPHIIt(); + return InsertBefore(InsertI); + } + + // No check for Constant here, as constants are already handled. + assert(isa<Instruction>(Operand)); + + Instruction *OpInst = cast<Instruction>(Operand); + if (LLVM_UNLIKELY(OpInst->getOpcode() == Instruction::PHI)) { + // If the operand is defined by another PHI node, insert after the first + // non-PHI instruction at the corresponding basic block. + auto InsertI = OpInst->getParent()->getFirstNonPHIIt(); + return InsertBefore(InsertI); + } + + // Otherwise, insert immediately after the operand definition. + NewI->insertAfter(OpInst->getIterator()); + NewI->setDebugLoc(OpInst->getDebugLoc()); + return NewI; +} + // A helper function for cloneInstructionWithNewAddressSpace. Returns the clone // of OperandUse.get() in the new address space. If the clone is not ready yet, // returns poison in the new address space as a placeholder. 
@@ -642,6 +677,10 @@ static Value *operandWithNewAddressSpaceOrCreatePoison( unsigned NewAS = I->second; Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS); auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy); + + if (LLVM_UNLIKELY(Inst->getOpcode() == Instruction::PHI)) + return phiNodeOperandWithNewAddressSpace(NewI, Operand); + NewI->insertBefore(Inst->getIterator()); NewI->setDebugLoc(Inst->getDebugLoc()); return NewI; diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp index 39751c0..93de58f 100644 --- a/llvm/lib/Transforms/Scalar/InferAlignment.cpp +++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp @@ -12,15 +12,20 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/InferAlignment.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +using namespace llvm::PatternMatch; static bool tryToImproveAlign( const DataLayout &DL, Instruction *I, @@ -37,6 +42,21 @@ static bool tryToImproveAlign( } } + Value *PtrOp; + const APInt *Const; + if (match(I, m_And(m_PtrToIntOrAddr(m_Value(PtrOp)), m_APInt(Const)))) { + Align ActualAlign = Fn(PtrOp, Align(1), Align(1)); + if (Const->ult(ActualAlign.value())) { + I->replaceAllUsesWith(Constant::getNullValue(I->getType())); + return true; + } + if (Const->uge( + APInt::getBitsSetFrom(Const->getBitWidth(), Log2(ActualAlign)))) { + I->replaceAllUsesWith(I->getOperand(0)); + return true; + } + } + IntrinsicInst *II = dyn_cast<IntrinsicInst>(I); if (!II) return false; diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp index c7d71eb..aa8c80a 100644 --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -1920,6 +1920,13 @@ bool JumpThreadingPass::maybeMergeBasicBlockIntoOnlyPred(BasicBlock *BB) { if (Unreachable.count(SinglePred)) return false; + // Don't merge if both the basic block and the predecessor contain loop or + // entry convergent intrinsics, since there may only be one convergence token + // per block. + if (HasLoopOrEntryConvergenceToken(BB) && + HasLoopOrEntryConvergenceToken(SinglePred)) + return false; + // If SinglePred was a loop header, BB becomes one. 
if (LoopHeaders.erase(SinglePred)) LoopHeaders.insert(BB); diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index d13b990..b2c526b 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -211,15 +211,9 @@ static Instruction *cloneInstructionInExitBlock( static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, MemorySSAUpdater &MSSAU); -static void moveInstructionBefore( - Instruction &I, BasicBlock::iterator Dest, ICFLoopSafetyInfo &SafetyInfo, - MemorySSAUpdater &MSSAU, ScalarEvolution *SE, - MemorySSA::InsertionPlace Point = MemorySSA::BeforeTerminator); - -static bool sinkUnusedInvariantsFromPreheaderToExit( - Loop *L, AAResults *AA, ICFLoopSafetyInfo *SafetyInfo, - MemorySSAUpdater &MSSAU, ScalarEvolution *SE, DominatorTree *DT, - SinkAndHoistLICMFlags &SinkFlags, OptimizationRemarkEmitter *ORE); +static void moveInstructionBefore(Instruction &I, BasicBlock::iterator Dest, + ICFLoopSafetyInfo &SafetyInfo, + MemorySSAUpdater &MSSAU, ScalarEvolution *SE); static void foreachMemoryAccess(MemorySSA *MSSA, Loop *L, function_ref<void(Instruction *)> Fn); @@ -477,12 +471,6 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, TLI, TTI, L, MSSAU, &SafetyInfo, Flags, ORE) : sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, TTI, L, MSSAU, &SafetyInfo, Flags, ORE); - - // sink pre-header defs that are unused in-loop into the unique exit to reduce - // pressure. - Changed |= sinkUnusedInvariantsFromPreheaderToExit(L, AA, &SafetyInfo, MSSAU, - SE, DT, Flags, ORE); - Flags.setIsSink(false); if (Preheader) Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, AC, TLI, L, @@ -1468,80 +1456,19 @@ static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, static void moveInstructionBefore(Instruction &I, BasicBlock::iterator Dest, ICFLoopSafetyInfo &SafetyInfo, - MemorySSAUpdater &MSSAU, ScalarEvolution *SE, - MemorySSA::InsertionPlace Point) { + MemorySSAUpdater &MSSAU, + ScalarEvolution *SE) { SafetyInfo.removeInstruction(&I); SafetyInfo.insertInstructionTo(&I, Dest->getParent()); I.moveBefore(*Dest->getParent(), Dest); if (MemoryUseOrDef *OldMemAcc = cast_or_null<MemoryUseOrDef>( MSSAU.getMemorySSA()->getMemoryAccess(&I))) - MSSAU.moveToPlace(OldMemAcc, Dest->getParent(), Point); + MSSAU.moveToPlace(OldMemAcc, Dest->getParent(), + MemorySSA::BeforeTerminator); if (SE) SE->forgetBlockAndLoopDispositions(&I); } -// If there's a single exit block, sink any loop-invariant values that were -// defined in the preheader but not used inside the loop into the exit block -// to reduce register pressure in the loop. -static bool sinkUnusedInvariantsFromPreheaderToExit( - Loop *L, AAResults *AA, ICFLoopSafetyInfo *SafetyInfo, - MemorySSAUpdater &MSSAU, ScalarEvolution *SE, DominatorTree *DT, - SinkAndHoistLICMFlags &SinkFlags, OptimizationRemarkEmitter *ORE) { - BasicBlock *ExitBlock = L->getExitBlock(); - if (!ExitBlock) - return false; - - BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) - return false; - - bool MadeAnyChanges = false; - - for (Instruction &I : llvm::make_early_inc_range(llvm::reverse(*Preheader))) { - - // Skip terminator. - if (Preheader->getTerminator() == &I) - continue; - - // New instructions were inserted at the end of the preheader. - if (isa<PHINode>(I)) - break; - - // Don't move instructions which might have side effects, since the side - // effects need to complete before instructions inside the loop. 
Note that - // it's okay if the instruction might have undefined behavior: LoopSimplify - // guarantees that the preheader dominates the exit block. - if (I.mayHaveSideEffects()) - continue; - - if (!canSinkOrHoistInst(I, AA, DT, L, MSSAU, true, SinkFlags, nullptr)) - continue; - - // Determine if there is a use in or before the loop (direct or - // otherwise). - bool UsedInLoopOrPreheader = false; - for (Use &U : I.uses()) { - auto *UserI = cast<Instruction>(U.getUser()); - BasicBlock *UseBB = UserI->getParent(); - if (auto *PN = dyn_cast<PHINode>(UserI)) { - UseBB = PN->getIncomingBlock(U); - } - if (UseBB == Preheader || L->contains(UseBB)) { - UsedInLoopOrPreheader = true; - break; - } - } - if (UsedInLoopOrPreheader) - continue; - - moveInstructionBefore(I, ExitBlock->getFirstInsertionPt(), *SafetyInfo, - MSSAU, SE, MemorySSA::Beginning); - MadeAnyChanges = true; - } - - return MadeAnyChanges; -} - static Instruction *sinkThroughTriviallyReplaceablePHI( PHINode *TPN, Instruction *I, LoopInfo *LI, SmallDenseMap<BasicBlock *, Instruction *, 32> &SunkCopies, diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 1099aa3..ab292e8 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -65,7 +65,6 @@ #include <cassert> #include <list> #include <tuple> -#include <utility> using namespace llvm; @@ -521,7 +520,7 @@ public: // -1 means belonging to multiple partitions. else if (Partition == -1) break; - else if (Partition != (int)ThisPartition) + else if (Partition != ThisPartition) Partition = -1; } assert(Partition != -2 && "Pointer not belonging to any partition"); diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp index 19eccb9..9ffa602 100644 --- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -1796,14 +1796,16 @@ private: // mergeLatch may remove the only block in FC1. SE.forgetLoop(FC1.L); SE.forgetLoop(FC0.L); - // Forget block dispositions as well, so that there are no dangling - // pointers to erased/free'ed blocks. - SE.forgetBlockAndLoopDispositions(); // Move instructions from FC0.Latch to FC1.Latch. // Note: mergeLatch requires an updated DT. mergeLatch(FC0, FC1); + // Forget block dispositions as well, so that there are no dangling + // pointers to erased/free'ed blocks. It should be done after mergeLatch() + // since merging the latches may affect the dispositions. + SE.forgetBlockAndLoopDispositions(); + // Merge the loops. SmallVector<BasicBlock *, 8> Blocks(FC1.L->blocks()); for (BasicBlock *BB : Blocks) { @@ -2092,14 +2094,16 @@ private: // mergeLatch may remove the only block in FC1. SE.forgetLoop(FC1.L); SE.forgetLoop(FC0.L); - // Forget block dispositions as well, so that there are no dangling - // pointers to erased/free'ed blocks. - SE.forgetBlockAndLoopDispositions(); // Move instructions from FC0.Latch to FC1.Latch. // Note: mergeLatch requires an updated DT. mergeLatch(FC0, FC1); + // Forget block dispositions as well, so that there are no dangling + // pointers to erased/free'ed blocks. It should be done after mergeLatch() + // since merging the latches may affect the dispositions. + SE.forgetBlockAndLoopDispositions(); + // Merge the loops. 
SmallVector<BasicBlock *, 8> Blocks(FC1.L->blocks()); for (BasicBlock *BB : Blocks) { diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 019536ca..1730ec0 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -72,6 +72,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -105,6 +106,7 @@ STATISTIC( STATISTIC(NumShiftUntilZero, "Number of uncountable loops recognized as 'shift until zero' idiom"); +namespace llvm { bool DisableLIRP::All; static cl::opt<bool, true> DisableLIRPAll("disable-" DEBUG_TYPE "-all", @@ -163,6 +165,10 @@ static cl::opt<bool> ForceMemsetPatternIntrinsic( cl::desc("Use memset.pattern intrinsic whenever possible"), cl::init(false), cl::Hidden); +extern cl::opt<bool> ProfcheckDisableMetadataFixes; + +} // namespace llvm + namespace { class LoopIdiomRecognize { @@ -297,8 +303,6 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, // but ORE cannot be preserved (see comment before the pass definition). OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); - std::optional<PolynomialInfo> HR; - LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, AR.MSSA, DL, ORE); if (!LIR.runOnLoop(&L)) @@ -3199,7 +3203,21 @@ bool LoopIdiomRecognize::recognizeShiftUntilBitTest() { // The loop trip count check. auto *IVCheck = Builder.CreateICmpEQ(IVNext, LoopTripCount, CurLoop->getName() + ".ivcheck"); - Builder.CreateCondBr(IVCheck, SuccessorBB, LoopHeaderBB); + SmallVector<uint32_t> BranchWeights; + const bool HasBranchWeights = + !ProfcheckDisableMetadataFixes && + extractBranchWeights(*LoopHeaderBB->getTerminator(), BranchWeights); + + auto *BI = Builder.CreateCondBr(IVCheck, SuccessorBB, LoopHeaderBB); + if (HasBranchWeights) { + if (SuccessorBB == LoopHeaderBB->getTerminator()->getSuccessor(1)) + std::swap(BranchWeights[0], BranchWeights[1]); + // We're not changing the loop profile, so we can reuse the original loop's + // profile. + setBranchWeights(*BI, BranchWeights, + /*IsExpected=*/false); + } + LoopHeaderBB->getTerminator()->eraseFromParent(); // Populate the IV PHI. @@ -3368,10 +3386,10 @@ static bool detectShiftUntilZeroIdiom(Loop *CurLoop, ScalarEvolution *SE, /// %start = <...> /// %extraoffset = <...> /// <...> -/// br label %for.cond +/// br label %loop /// /// loop: -/// %iv = phi i8 [ %start, %entry ], [ %iv.next, %for.cond ] +/// %iv = phi i8 [ %start, %entry ], [ %iv.next, %loop ] /// %nbits = add nsw i8 %iv, %extraoffset /// %val.shifted = {{l,a}shr,shl} i8 %val, %nbits /// %val.shifted.iszero = icmp eq i8 %val.shifted, 0 @@ -3533,7 +3551,19 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() { // The loop terminator. Builder.SetInsertPoint(LoopHeaderBB->getTerminator()); - Builder.CreateCondBr(CIVCheck, SuccessorBB, LoopHeaderBB); + SmallVector<uint32_t> BranchWeights; + const bool HasBranchWeights = + !ProfcheckDisableMetadataFixes && + extractBranchWeights(*LoopHeaderBB->getTerminator(), BranchWeights); + + auto *BI = Builder.CreateCondBr(CIVCheck, SuccessorBB, LoopHeaderBB); + if (HasBranchWeights) { + if (InvertedCond) + std::swap(BranchWeights[0], BranchWeights[1]); + // We're not changing the loop profile, so we can reuse the original loop's + // profile. 
+ setBranchWeights(*BI, BranchWeights, /*IsExpected=*/false); + } LoopHeaderBB->getTerminator()->eraseFromParent(); // Populate the IV PHI. diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 9aaf6a5..330b4ab 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1462,6 +1462,24 @@ std::optional<bool> LoopInterchangeProfitability::isProfitableForVectorization( bool LoopInterchangeProfitability::isProfitable( const Loop *InnerLoop, const Loop *OuterLoop, unsigned InnerLoopId, unsigned OuterLoopId, CharMatrix &DepMatrix, CacheCostManager &CCM) { + // Do not consider loops with a backedge that isn't taken, e.g. an + // unconditional branch true/false, as candidates for interchange. + // TODO: when interchange is forced, we should probably also allow + // interchange for these loops, and thus this logic should be moved just + // below the cost-model ignore check below. But this check is done first + // to avoid the issue in #163954. + const SCEV *InnerBTC = SE->getBackedgeTakenCount(InnerLoop); + const SCEV *OuterBTC = SE->getBackedgeTakenCount(OuterLoop); + if (InnerBTC && InnerBTC->isZero()) { + LLVM_DEBUG(dbgs() << "Inner loop back-edge isn't taken, rejecting " + "single iteration loop\n"); + return false; + } + if (OuterBTC && OuterBTC->isZero()) { + LLVM_DEBUG(dbgs() << "Outer loop back-edge isn't taken, rejecting " + "single iteration loop\n"); + return false; + } // Return true if interchange is forced and the cost-model ignored. if (Profitabilities.size() == 1 && Profitabilities[0] == RuleTy::Ignore) diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index a883998..1b770be 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -89,8 +89,8 @@ struct StoreToLoadForwardingCandidate { /// Return true if the dependence from the store to the load has an /// absolute distance of one. /// E.g. A[i+1] = A[i] (or A[i-1] = A[i] for descending loop) - bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, - Loop *L) const { + bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, Loop *L, + const DominatorTree &DT) const { Value *LoadPtr = Load->getPointerOperand(); Value *StorePtr = Store->getPointerOperand(); Type *LoadType = getLoadStoreType(Load); @@ -102,8 +102,10 @@ struct StoreToLoadForwardingCandidate { DL.getTypeSizeInBits(getLoadStoreType(Store)) && "Should be a known dependence"); - int64_t StrideLoad = getPtrStride(PSE, LoadType, LoadPtr, L).value_or(0); - int64_t StrideStore = getPtrStride(PSE, LoadType, StorePtr, L).value_or(0); + int64_t StrideLoad = + getPtrStride(PSE, LoadType, LoadPtr, L, DT).value_or(0); + int64_t StrideStore = + getPtrStride(PSE, LoadType, StorePtr, L, DT).value_or(0); if (!StrideLoad || !StrideStore || StrideLoad != StrideStore) return false; @@ -287,8 +289,8 @@ public: // so deciding which one forwards is easy. The later one forwards as // long as they both have a dependence distance of one to the load. if (Cand.Store->getParent() == OtherCand->Store->getParent() && - Cand.isDependenceDistanceOfOne(PSE, L) && - OtherCand->isDependenceDistanceOfOne(PSE, L)) { + Cand.isDependenceDistanceOfOne(PSE, L, *DT) && + OtherCand->isDependenceDistanceOfOne(PSE, L, *DT)) { // They are in the same block, the later one will forward to the load. 
if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store)) OtherCand = &Cand; @@ -538,7 +540,7 @@ public: // Check whether the SCEV difference is the same as the induction step, // thus we load the value in the next iteration. - if (!Cand.isDependenceDistanceOfOne(PSE, L)) + if (!Cand.isDependenceDistanceOfOne(PSE, L, *DT)) continue; assert(isa<SCEVAddRecExpr>(PSE.getSCEV(Cand.Load->getPointerOperand())) && diff --git a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index b9546c5..e902b71 100644 --- a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" @@ -393,6 +394,17 @@ private: DTUpdates.push_back({DominatorTree::Insert, Preheader, BB}); ++NumLoopExitsDeleted; } + // We don't really need to add branch weights to DummySwitch, because all + // but one branches are just a temporary artifact - see the comment on top + // of this function. But, it's easy to estimate the weights, and it helps + // maintain a property of the overall compiler - that the branch weights + // don't "just get dropped" accidentally (i.e. profcheck) + if (DummySwitch->getParent()->getParent()->hasProfileData()) { + SmallVector<uint32_t> DummyBranchWeights(1 + DummySwitch->getNumCases()); + // default. 100% probability, the rest are dead. + DummyBranchWeights[0] = 1; + setBranchWeights(*DummySwitch, DummyBranchWeights, /*IsExpected=*/false); + } assert(L.getLoopPreheader() == NewPreheader && "Malformed CFG?"); if (Loop *OuterLoop = LI.getLoopFor(Preheader)) { diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 001215a..68cffd4 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -2195,8 +2195,8 @@ class LSRInstance { SmallSetVector<Instruction *, 4> InsertedNonLCSSAInsts; void OptimizeShadowIV(); - bool FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse); - ICmpInst *OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse); + bool FindIVUserForCond(Instruction *Cond, IVStrideUse *&CondUse); + Instruction *OptimizeMax(ICmpInst *Cond, IVStrideUse *&CondUse); void OptimizeLoopTermCond(); void ChainInstruction(Instruction *UserInst, Instruction *IVOper, @@ -2431,7 +2431,7 @@ void LSRInstance::OptimizeShadowIV() { /// If Cond has an operand that is an expression of an IV, set the IV user and /// stride information and return true, otherwise return false. -bool LSRInstance::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse) { +bool LSRInstance::FindIVUserForCond(Instruction *Cond, IVStrideUse *&CondUse) { for (IVStrideUse &U : IU) if (U.getUser() == Cond) { // NOTE: we could handle setcc instructions with multiple uses here, but @@ -2491,7 +2491,7 @@ bool LSRInstance::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse) { /// This function solves this problem by detecting this type of loop and /// rewriting their conditions from ICMP_NE back to ICMP_SLT, and deleting /// the instructions for the maximum computation. 
-ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) { +Instruction *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse *&CondUse) { // Check that the loop matches the pattern we're looking for. if (Cond->getPredicate() != CmpInst::ICMP_EQ && Cond->getPredicate() != CmpInst::ICMP_NE) @@ -2635,15 +2635,22 @@ LSRInstance::OptimizeLoopTermCond() { // one register value. BranchInst *TermBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator()); - if (!TermBr) + if (!TermBr || TermBr->isUnconditional()) continue; - // FIXME: Overly conservative, termination condition could be an 'or' etc.. - if (TermBr->isUnconditional() || !isa<ICmpInst>(TermBr->getCondition())) + + Instruction *Cond = dyn_cast<Instruction>(TermBr->getCondition()); + // If the argument to TermBr is an extractelement, then the source of that + // instruction is what's generated the condition. + auto *Extract = dyn_cast_or_null<ExtractElementInst>(Cond); + if (Extract) + Cond = dyn_cast<Instruction>(Extract->getVectorOperand()); + // FIXME: We could do more here, like handling logical operations where one + // side is a cmp that uses an induction variable. + if (!Cond) continue; // Search IVUsesByStride to find Cond's IVUse if there is one. IVStrideUse *CondUse = nullptr; - ICmpInst *Cond = cast<ICmpInst>(TermBr->getCondition()); if (!FindIVUserForCond(Cond, CondUse)) continue; @@ -2653,7 +2660,8 @@ LSRInstance::OptimizeLoopTermCond() { // One consequence of doing this now is that it disrupts the count-down // optimization. That's not always a bad thing though, because in such // cases it may still be worthwhile to avoid a max. - Cond = OptimizeMax(Cond, CondUse); + if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) + Cond = OptimizeMax(Cmp, CondUse); // If this exiting block dominates the latch block, it may also use // the post-inc value if it won't be shared with other uses. @@ -2718,13 +2726,14 @@ LSRInstance::OptimizeLoopTermCond() { // It's possible for the setcc instruction to be anywhere in the loop, and // possible for it to have multiple users. If it is not immediately before // the exiting block branch, move it. - if (Cond->getNextNode() != TermBr) { + if (isa_and_nonnull<CmpInst>(Cond) && Cond->getNextNode() != TermBr && + !Extract) { if (Cond->hasOneUse()) { Cond->moveBefore(TermBr->getIterator()); } else { // Clone the terminating condition and insert into the loopend. - ICmpInst *OldCond = Cond; - Cond = cast<ICmpInst>(Cond->clone()); + Instruction *OldCond = Cond; + Cond = Cond->clone(); Cond->setName(L->getHeader()->getName() + ".termcond"); Cond->insertInto(ExitingBlock, TermBr->getIterator()); @@ -5796,7 +5805,7 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF, // negated immediate. if (!ICmpScaledV) ICmpScaledV = - ConstantInt::get(IntTy, -(uint64_t)Offset.getFixedValue()); + ConstantInt::getSigned(IntTy, -(uint64_t)Offset.getFixedValue()); else { Ops.push_back(SE.getUnknown(ICmpScaledV)); ICmpScaledV = ConstantInt::get(IntTy, Offset.getFixedValue()); @@ -6024,33 +6033,34 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, DeadInsts.emplace_back(OperandIsInstr); } -// Trying to hoist the IVInc to loop header if all IVInc users are in -// the loop header. It will help backend to generate post index load/store -// when the latch block is different from loop header block. 
-static bool canHoistIVInc(const TargetTransformInfo &TTI, const LSRFixup &Fixup, - const LSRUse &LU, Instruction *IVIncInsertPos, - Loop *L) { +// Determine where to insert the transformed IV increment instruction for this +// fixup. By default this is the default insert position, but if this is a +// postincrement opportunity then we try to insert it in the same block as the +// fixup user instruction, as this is needed for a postincrement instruction to +// be generated. +static Instruction *getFixupInsertPos(const TargetTransformInfo &TTI, + const LSRFixup &Fixup, const LSRUse &LU, + Instruction *IVIncInsertPos, + DominatorTree &DT) { + // Only address uses can be postincremented if (LU.Kind != LSRUse::Address) - return false; - - // For now this code do the conservative optimization, only work for - // the header block. Later we can hoist the IVInc to the block post - // dominate all users. - BasicBlock *LHeader = L->getHeader(); - if (IVIncInsertPos->getParent() == LHeader) - return false; - - if (!Fixup.OperandValToReplace || - any_of(Fixup.OperandValToReplace->users(), [&LHeader](User *U) { - Instruction *UI = cast<Instruction>(U); - return UI->getParent() != LHeader; - })) - return false; + return IVIncInsertPos; + // Don't try to postincrement if it's not legal Instruction *I = Fixup.UserInst; Type *Ty = I->getType(); - return (isa<LoadInst>(I) && TTI.isIndexedLoadLegal(TTI.MIM_PostInc, Ty)) || - (isa<StoreInst>(I) && TTI.isIndexedStoreLegal(TTI.MIM_PostInc, Ty)); + if (!(isa<LoadInst>(I) && TTI.isIndexedLoadLegal(TTI.MIM_PostInc, Ty)) && + !(isa<StoreInst>(I) && TTI.isIndexedStoreLegal(TTI.MIM_PostInc, Ty))) + return IVIncInsertPos; + + // It's only legal to hoist to the user block if it dominates the default + // insert position. + BasicBlock *HoistBlock = I->getParent(); + BasicBlock *IVIncBlock = IVIncInsertPos->getParent(); + if (!DT.dominates(I, IVIncBlock)) + return IVIncInsertPos; + + return HoistBlock->getTerminator(); } /// Rewrite all the fixup locations with new values, following the chosen @@ -6071,9 +6081,7 @@ void LSRInstance::ImplementSolution( for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) { Instruction *InsertPos = - canHoistIVInc(TTI, Fixup, Uses[LUIdx], IVIncInsertPos, L) - ? 
L->getHeader()->getTerminator() - : IVIncInsertPos; + getFixupInsertPos(TTI, Fixup, Uses[LUIdx], IVIncInsertPos, DT); Rewriter.setIVIncInsertPos(L, InsertPos); Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], DeadInsts); Changed = true; diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 2bda9d8..802ae4e 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -1327,7 +1327,8 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, } // Do not attempt partial/runtime unrolling in FullLoopUnrolling - if (OnlyFullUnroll && (UP.Count < TripCount || UP.Count < MaxTripCount)) { + if (OnlyFullUnroll && ((!TripCount && !MaxTripCount) || + UP.Count < TripCount || UP.Count < MaxTripCount)) { LLVM_DEBUG( dbgs() << "Not attempting partial/runtime unroll in FullLoopUnroll.\n"); return LoopUnrollResult::Unmodified; diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index f3e6cbf..3aed643 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -413,7 +413,7 @@ bool LoopVersioningLICM::legalLoopInstructions() { LLVM_DEBUG(dbgs() << " Found a read-only loop!\n"); return false; } - // Profitablity check: + // Profitability check: // Check invariant threshold, should be in limit. if (InvariantCounter * 100 < InvariantThreshold * LoadAndStoreCounter) { LLVM_DEBUG( diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 3487e81..7e70ba2 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -245,11 +245,14 @@ raw_ostream &operator<<(raw_ostream &OS, ShapeInfo SI) { } // namespace -static bool isUniformShape(Value *V) { +static bool isShapePreserving(Value *V) { Instruction *I = dyn_cast<Instruction>(V); if (!I) return true; + if (isa<SelectInst>(I)) + return true; + if (I->isBinaryOp()) return true; @@ -300,6 +303,16 @@ static bool isUniformShape(Value *V) { } } +/// Return an iterator over the operands of \p I that should share shape +/// information with \p I. +static iterator_range<Use *> getShapedOperandsForInst(Instruction *I) { + assert(isShapePreserving(I) && + "Can't retrieve shaped operands for an instruction that does not " + "preserve shape information"); + auto Ops = I->operands(); + return isa<SelectInst>(I) ? drop_begin(Ops) : Ops; +} + /// Return the ShapeInfo for the result of \p I, it it can be determined. static std::optional<ShapeInfo> computeShapeInfoForInst(Instruction *I, @@ -329,9 +342,8 @@ computeShapeInfoForInst(Instruction *I, return OpShape->second; } - if (isUniformShape(I) || isa<SelectInst>(I)) { - auto Ops = I->operands(); - auto ShapedOps = isa<SelectInst>(I) ? drop_begin(Ops) : Ops; + if (isShapePreserving(I)) { + auto ShapedOps = getShapedOperandsForInst(I); // Find the first operand that has a known shape and use that. for (auto &Op : ShapedOps) { auto OpShape = ShapeMap.find(Op.get()); @@ -710,10 +722,9 @@ public: case Intrinsic::matrix_column_major_store: return true; default: - return isUniformShape(II); + break; } - return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V) || - isa<SelectInst>(V); + return isShapePreserving(V) || isa<StoreInst>(V) || isa<LoadInst>(V); } /// Propagate the shape information of instructions to their users. 
@@ -800,9 +811,8 @@ public: } else if (isa<StoreInst>(V)) { // Nothing to do. We forward-propagated to this so we would just // backward propagate to an instruction with an already known shape. - } else if (isUniformShape(V) || isa<SelectInst>(V)) { - auto Ops = cast<Instruction>(V)->operands(); - auto ShapedOps = isa<SelectInst>(V) ? drop_begin(Ops) : Ops; + } else if (isShapePreserving(V)) { + auto ShapedOps = getShapedOperandsForInst(cast<Instruction>(V)); // Propagate to all operands. ShapeInfo Shape = ShapeMap[V]; for (Use &U : ShapedOps) { diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 08be5df..db2afe26 100644 --- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -47,6 +47,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -1366,6 +1367,10 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize); Value *MemsetLen = Builder.CreateSelect( Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); + // FIXME (#167968): we could explore estimating the branch_weights based on + // value profiling data about the 2 sizes. + if (auto *SI = dyn_cast<SelectInst>(MemsetLen)) + setExplicitlyUnknownBranchWeightsIfProfiled(*SI, DEBUG_TYPE); Instruction *NewMemSet = Builder.CreateMemSet(Builder.CreatePtrAdd(Dest, SrcSize), MemSet->getOperand(1), MemsetLen, Alignment); diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp index f273e9d..d4358c1 100644 --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -50,8 +50,9 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/Instruction.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" @@ -66,6 +67,9 @@ using namespace llvm; #define DEBUG_TYPE "mergeicmps" +namespace llvm { +extern cl::opt<bool> ProfcheckDisableMetadataFixes; +} // namespace llvm namespace { // A BCE atom "Binary Compare Expression Atom" represents an integer load @@ -607,6 +611,37 @@ private: }; } // namespace +/// Determine the branch weights for the resulting conditional branch, resulting +/// after merging \p Comparisons. +static std::optional<SmallVector<uint32_t, 2>> +computeMergedBranchWeights(ArrayRef<BCECmpBlock> Comparisons) { + assert(!Comparisons.empty()); + if (ProfcheckDisableMetadataFixes) + return std::nullopt; + if (Comparisons.size() == 1) { + SmallVector<uint32_t, 2> Weights; + if (!extractBranchWeights(*Comparisons[0].BB->getTerminator(), Weights)) + return std::nullopt; + return Weights; + } + // The probability to go to the phi block is the disjunction of the + // probability to go to the phi block from the individual Comparisons. We'll + // swap the weights because `getDisjunctionWeights` computes the disjunction + // for the "true" branch, then swap back. + SmallVector<uint64_t, 2> Weights{0, 1}; + // At this point, Weights encodes "0-probability" for the "true" side. 
+ for (const auto &C : Comparisons) { + SmallVector<uint32_t, 2> W; + if (!extractBranchWeights(*C.BB->getTerminator(), W)) + return std::nullopt; + + std::swap(W[0], W[1]); + Weights = getDisjunctionWeights(Weights, W); + } + std::swap(Weights[0], Weights[1]); + return fitWeights(Weights); +} + // Merges the given contiguous comparison blocks into one memcmp block. static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, BasicBlock *const InsertBefore, @@ -640,7 +675,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, // If there is one block that requires splitting, we do it now, i.e. // just before we know we will collapse the chain. The instructions // can be executed before any of the instructions in the chain. - const auto ToSplit = llvm::find_if( + const auto *ToSplit = llvm::find_if( Comparisons, [](const BCECmpBlock &B) { return B.RequireSplit; }); if (ToSplit != Comparisons.end()) { LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n"); @@ -655,6 +690,7 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, LhsLoad->replaceUsesOfWith(LhsLoad->getOperand(0), Lhs); RhsLoad->replaceUsesOfWith(RhsLoad->getOperand(0), Rhs); // There are no blocks to merge, just do the comparison. + // If we condition on this IsEqual, we already have its probabilities. IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); } else { const unsigned TotalSizeBits = std::accumulate( @@ -684,7 +720,9 @@ static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}}); } else { // Continue to next block if equal, exit to phi else. - Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB); + auto *BI = Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB); + if (auto BranchWeights = computeMergedBranchWeights(Comparisons)) + setBranchWeights(*BI, BranchWeights.value(), /*IsExpected=*/false); Phi.addIncoming(ConstantInt::getFalse(Context), BB); DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock}, {DominatorTree::Insert, BB, PhiBB}}); diff --git a/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 2b50ccd..bcedc1a 100644 --- a/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/InitializePasses.h" #include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" @@ -27,6 +28,10 @@ using namespace llvm; +namespace llvm { +extern cl::opt<bool> ProfcheckDisableMetadataFixes; +} // namespace llvm + #define DEBUG_TYPE "partially-inline-libcalls" DEBUG_COUNTER(PILCounter, "partially-inline-libcalls-transform", @@ -94,7 +99,14 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, : Builder.CreateFCmpOGE(Call->getOperand(0), ConstantFP::get(Ty, 0.0)); CurrBBTerm->setCondition(FCmp); - + if (!ProfcheckDisableMetadataFixes && + CurrBBTerm->getFunction()->getEntryCount()) { + // Presume the quick path - where we don't call the library call - is the + // frequent one + MDBuilder MDB(CurrBBTerm->getContext()); + CurrBBTerm->setMetadata(LLVMContext::MD_prof, + MDB.createLikelyBranchWeights()); + } // Add phi operands. 
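For context on the weights added above: the transform computes sqrt with the fast inline form, falls back to the library call only for arguments that fail the non-negativity check, and now presumes the fast path is the frequent one. A rough scalar sketch of that control flow (fastSqrt is a placeholder for the target's native instruction, not the pass's output):

#include <cassert>
#include <cmath>

// Placeholder for the target's native sqrt instruction.
static double fastSqrt(double X) { return std::sqrt(X); }

static double partiallyInlinedSqrt(double X) {
  double R = fastSqrt(X); // speculatively use the cheap instruction
  if (!(X >= 0.0))        // rare path: negative input or NaN, let the
    R = std::sqrt(X);     //   library call handle errno/NaN semantics
  return R;               // corresponds to the phi of the two values
}

int main() {
  assert(partiallyInlinedSqrt(9.0) == 3.0);
  assert(std::isnan(partiallyInlinedSqrt(-1.0)));
  return 0;
}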
Phi->addIncoming(Call, &CurrBB); Phi->addIncoming(LibCall, LibCallBB); diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 5c60fad..3a70830 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -2178,35 +2178,6 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, return true; } -/// Test whether a vector type is viable for promotion. -/// -/// This implements the necessary checking for \c checkVectorTypesForPromotion -/// (and thus isVectorPromotionViable) over all slices of the alloca for the -/// given VectorType. -static bool checkVectorTypeForPromotion(Partition &P, VectorType *VTy, - const DataLayout &DL, unsigned VScale) { - uint64_t ElementSize = - DL.getTypeSizeInBits(VTy->getElementType()).getFixedValue(); - - // While the definition of LLVM vectors is bitpacked, we don't support sizes - // that aren't byte sized. - if (ElementSize % 8) - return false; - assert((DL.getTypeSizeInBits(VTy).getFixedValue() % 8) == 0 && - "vector size not a multiple of element size?"); - ElementSize /= 8; - - for (const Slice &S : P) - if (!isVectorPromotionViableForSlice(P, S, VTy, ElementSize, DL, VScale)) - return false; - - for (const Slice *S : P.splitSliceTails()) - if (!isVectorPromotionViableForSlice(P, *S, VTy, ElementSize, DL, VScale)) - return false; - - return true; -} - /// Test whether any vector type in \p CandidateTys is viable for promotion. /// /// This implements the necessary checking for \c isVectorPromotionViable over @@ -2291,11 +2262,30 @@ checkVectorTypesForPromotion(Partition &P, const DataLayout &DL, std::numeric_limits<unsigned short>::max(); }); - for (VectorType *VTy : CandidateTys) - if (checkVectorTypeForPromotion(P, VTy, DL, VScale)) - return VTy; + // Find a vector type viable for promotion by iterating over all slices. + auto *VTy = llvm::find_if(CandidateTys, [&](VectorType *VTy) -> bool { + uint64_t ElementSize = + DL.getTypeSizeInBits(VTy->getElementType()).getFixedValue(); - return nullptr; + // While the definition of LLVM vectors is bitpacked, we don't support sizes + // that aren't byte sized. + if (ElementSize % 8) + return false; + assert((DL.getTypeSizeInBits(VTy).getFixedValue() % 8) == 0 && + "vector size not a multiple of element size?"); + ElementSize /= 8; + + for (const Slice &S : P) + if (!isVectorPromotionViableForSlice(P, S, VTy, ElementSize, DL, VScale)) + return false; + + for (const Slice *S : P.splitSliceTails()) + if (!isVectorPromotionViableForSlice(P, *S, VTy, ElementSize, DL, VScale)) + return false; + + return true; + }); + return VTy != CandidateTys.end() ? *VTy : nullptr; } static VectorType *createAndCheckVectorTypesForPromotion( @@ -3150,7 +3140,6 @@ private: assert(IsSplit || BeginOffset == NewBeginOffset); uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; -#ifndef NDEBUG StringRef OldName = OldPtr->getName(); // Skip through the last '.sroa.' component of the name. size_t LastSROAPrefix = OldName.rfind(".sroa."); @@ -3169,17 +3158,10 @@ private: } // Strip any SROA suffixes as well. OldName = OldName.substr(0, OldName.find(".sroa_")); -#endif return getAdjustedPtr(IRB, DL, &NewAI, APInt(DL.getIndexTypeSizeInBits(PointerTy), Offset), - PointerTy, -#ifndef NDEBUG - Twine(OldName) + "." 
-#else - Twine() -#endif - ); + PointerTy, Twine(OldName) + "."); } /// Compute suitable alignment to access this slice of the *new* @@ -5213,7 +5195,6 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // won't always succeed, in which case we fall back to a legal integer type // or an i8 array of an appropriate size. Type *SliceTy = nullptr; - VectorType *SliceVecTy = nullptr; const DataLayout &DL = AI.getDataLayout(); unsigned VScale = AI.getFunction()->getVScaleValue(); @@ -5222,10 +5203,8 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // Do all uses operate on the same type? if (CommonUseTy.first) { TypeSize CommonUseSize = DL.getTypeAllocSize(CommonUseTy.first); - if (CommonUseSize.isFixed() && CommonUseSize.getFixedValue() >= P.size()) { + if (CommonUseSize.isFixed() && CommonUseSize.getFixedValue() >= P.size()) SliceTy = CommonUseTy.first; - SliceVecTy = dyn_cast<VectorType>(SliceTy); - } } // If not, can we find an appropriate subtype in the original allocated type? if (!SliceTy) @@ -5235,27 +5214,14 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, // If still not, can we use the largest bitwidth integer type used? if (!SliceTy && CommonUseTy.second) - if (DL.getTypeAllocSize(CommonUseTy.second).getFixedValue() >= P.size()) { + if (DL.getTypeAllocSize(CommonUseTy.second).getFixedValue() >= P.size()) SliceTy = CommonUseTy.second; - SliceVecTy = dyn_cast<VectorType>(SliceTy); - } if ((!SliceTy || (SliceTy->isArrayTy() && SliceTy->getArrayElementType()->isIntegerTy())) && DL.isLegalInteger(P.size() * 8)) { SliceTy = Type::getIntNTy(*C, P.size() * 8); } - // If the common use types are not viable for promotion then attempt to find - // another type that is viable. - if (SliceVecTy && !checkVectorTypeForPromotion(P, SliceVecTy, DL, VScale)) - if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(), - P.beginOffset(), P.size())) { - VectorType *TypePartitionVecTy = dyn_cast<VectorType>(TypePartitionTy); - if (TypePartitionVecTy && - checkVectorTypeForPromotion(P, TypePartitionVecTy, DL, VScale)) - SliceTy = TypePartitionTy; - } - if (!SliceTy) SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size()); assert(DL.getTypeAllocSize(SliceTy).getFixedValue() >= P.size()); diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index 146e7d1..b7b08ae 100644 --- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -1123,7 +1123,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, if (TTI.isLegalMaskedLoad( CI->getType(), CI->getParamAlign(0).valueOrOne(), cast<PointerType>(CI->getArgOperand(0)->getType()) - ->getAddressSpace())) + ->getAddressSpace(), + isConstantIntVector(CI->getArgOperand(1)) + ? TTI::MaskKind::ConstantMask + : TTI::MaskKind::VariableOrConstantMask)) return false; scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT); return true; @@ -1132,7 +1135,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, CI->getArgOperand(0)->getType(), CI->getParamAlign(1).valueOrOne(), cast<PointerType>(CI->getArgOperand(1)->getType()) - ->getAddressSpace())) + ->getAddressSpace(), + isConstantIntVector(CI->getArgOperand(2)) + ? 
TTI::MaskKind::ConstantMask + : TTI::MaskKind::VariableOrConstantMask)) return false; scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT); return true; diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 25a531c..46f92c3 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -252,7 +252,14 @@ static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments, Res = Builder.CreateInsertElement(Res, Fragment, I * VS.NumPacked, Name + ".upto" + Twine(I)); } else { - Fragment = Builder.CreateShuffleVector(Fragment, Fragment, ExtendMask); + if (NumPacked < VS.NumPacked) { + // If last pack of remained bits not match current ExtendMask size. + ExtendMask.truncate(NumPacked); + ExtendMask.resize(NumElements, -1); + } + + Fragment = Builder.CreateShuffleVector( + Fragment, PoisonValue::get(Fragment->getType()), ExtendMask); if (I == 0) { Res = Fragment; } else { diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index bb6c879..7e8cc03 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -40,6 +40,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ProfDataUtils.h" @@ -329,15 +330,14 @@ static void buildPartialUnswitchConditionalBranch( HasBranchWeights ? ComputeProfFrom.getMetadata(LLVMContext::MD_prof) : nullptr); if (!HasBranchWeights) - setExplicitlyUnknownBranchWeightsIfProfiled( - *BR, *BR->getParent()->getParent(), DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*BR, DEBUG_TYPE); } /// Copy a set of loop invariant values, and conditionally branch on them. static void buildPartialInvariantUnswitchConditionalBranch( BasicBlock &BB, ArrayRef<Value *> ToDuplicate, bool Direction, BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, Loop &L, - MemorySSAUpdater *MSSAU) { + MemorySSAUpdater *MSSAU, const BranchInst &OriginalBranch) { ValueToValueMapTy VMap; for (auto *Val : reverse(ToDuplicate)) { Instruction *Inst = cast<Instruction>(Val); @@ -377,8 +377,18 @@ static void buildPartialInvariantUnswitchConditionalBranch( IRBuilder<> IRB(&BB); IRB.SetCurrentDebugLocation(DebugLoc::getCompilerGenerated()); Value *Cond = VMap[ToDuplicate[0]]; - IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, - Direction ? &NormalSucc : &UnswitchedSucc); + // The expectation is that ToDuplicate[0] is the condition used by the + // OriginalBranch, case in which we can clone the profile metadata from there. + auto *ProfData = + !ProfcheckDisableMetadataFixes && + ToDuplicate[0] == skipTrivialSelect(OriginalBranch.getCondition()) + ? OriginalBranch.getMetadata(LLVMContext::MD_prof) + : nullptr; + auto *BR = + IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, + Direction ? &NormalSucc : &UnswitchedSucc, ProfData); + if (!ProfData) + setExplicitlyUnknownBranchWeightsIfProfiled(*BR, DEBUG_TYPE); } /// Rewrite the PHI nodes in an unswitched loop exit basic block. @@ -2515,7 +2525,7 @@ static void unswitchNontrivialInvariants( // the branch in the split block. 
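Before the hunk continues: the branch being rebuilt below implements unswitching on a loop-invariant condition, which in source terms hoists the test out of the loop so each version contains only one side. A compilable sketch of the shape of the transform (not the pass's code):

#include <cassert>
#include <vector>

// Original loop: re-tests the invariant flag on every iteration.
static int sumLoop(const std::vector<int> &V, bool UseDouble) {
  int Sum = 0;
  for (int X : V)
    Sum += UseDouble ? 2 * X : X;
  return Sum;
}

// Unswitched form: the invariant test runs once and selects a specialized loop.
static int sumUnswitched(const std::vector<int> &V, bool UseDouble) {
  int Sum = 0;
  if (UseDouble)
    for (int X : V)
      Sum += 2 * X;
  else
    for (int X : V)
      Sum += X;
  return Sum;
}

int main() {
  std::vector<int> V{1, 2, 3};
  assert(sumLoop(V, true) == sumUnswitched(V, true));   // both 12
  assert(sumLoop(V, false) == sumUnswitched(V, false)); // both 6
  return 0;
}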
if (PartiallyInvariant) buildPartialInvariantUnswitchConditionalBranch( - *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, L, MSSAU); + *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, L, MSSAU, *BI); else { buildPartialUnswitchConditionalBranch( *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, @@ -2820,9 +2830,14 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, MSSAU->getMemorySSA()->verifyMemorySSA(); DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); - Instruction *DeoptBlockTerm = - SplitBlockAndInsertIfThen(GI->getArgOperand(0), GI, true, - GI->getMetadata(LLVMContext::MD_prof), &DTU, &LI); + // llvm.experimental.guard doesn't have branch weights. We can assume, + // however, that the deopt path is unlikely. + Instruction *DeoptBlockTerm = SplitBlockAndInsertIfThen( + GI->getArgOperand(0), GI, true, + !ProfcheckDisableMetadataFixes && EstimateProfile + ? MDBuilder(GI->getContext()).createUnlikelyBranchWeights() + : nullptr, + &DTU, &LI); BranchInst *CheckBI = cast<BranchInst>(CheckBB->getTerminator()); // SplitBlockAndInsertIfThen inserts control flow that branches to // DeoptBlockTerm if the condition is true. We want the opposite. @@ -2899,8 +2914,8 @@ static int CalculateUnswitchCostMultiplier( ParentLoopSizeMultiplier = std::max<int>(ParentL->getNumBlocks() / UnswitchParentBlocksDiv, 1); - int SiblingsCount = (ParentL ? ParentL->getSubLoopsVector().size() - : std::distance(LI.begin(), LI.end())); + int SiblingsCount = + (ParentL ? ParentL->getSubLoopsVector().size() : llvm::size(LI)); // Count amount of clones that all the candidates might cause during // unswitching. Branch/guard/select counts as 1, switch counts as log2 of its // cases. @@ -3186,10 +3201,14 @@ injectPendingInvariantConditions(NonTrivialUnswitchCandidate Candidate, Loop &L, Builder.SetInsertPoint(TI); auto *InvariantBr = Builder.CreateCondBr(InjectedCond, InLoopSucc, CheckBlock); + // We don't know anything about the relation between the limits. + setExplicitlyUnknownBranchWeightsIfProfiled(*InvariantBr, DEBUG_TYPE); Builder.SetInsertPoint(CheckBlock); - Builder.CreateCondBr(TI->getCondition(), TI->getSuccessor(0), - TI->getSuccessor(1)); + Builder.CreateCondBr( + TI->getCondition(), TI->getSuccessor(0), TI->getSuccessor(1), + !ProfcheckDisableMetadataFixes ? TI->getMetadata(LLVMContext::MD_prof) + : nullptr); TI->eraseFromParent(); // Fixup phis. diff --git a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index e94ad19..2ad6f7e 100644 --- a/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -12,17 +12,16 @@ // effective in simplifying arithmetic statements derived from an unrolled loop. // It can also simplify the logic of SeparateConstOffsetFromGEP. // -// There are many optimizations we can perform in the domain of SLSR. This file -// for now contains only an initial step. Specifically, we look for strength -// reduction candidates in the following forms: +// There are many optimizations we can perform in the domain of SLSR. +// We look for strength reduction candidates in the following forms: // -// Form 1: B + i * S -// Form 2: (B + i) * S -// Form 3: &B[i * S] +// Form Add: B + i * S +// Form Mul: (B + i) * S +// Form GEP: &B[i * S] // // where S is an integer variable, and i is a constant integer. 
If we found two // candidates S1 and S2 in the same form and S1 dominates S2, we may rewrite S2 -// in a simpler way with respect to S1. For example, +// in a simpler way with respect to S1 (index delta). For example, // // S1: X = B + i * S // S2: Y = B + i' * S => X + (i' - i) * S @@ -35,8 +34,29 @@ // // Note: (i' - i) * S is folded to the extent possible. // +// For Add and GEP forms, we can also rewrite a candidate in a simpler way +// with respect to other dominating candidates if their B or S are different +// but other parts are the same. For example, +// +// Base Delta: +// S1: X = B + i * S +// S2: Y = B' + i * S => X + (B' - B) +// +// S1: X = &B [i * S] +// S2: Y = &B'[i * S] => X + (B' - B) +// +// Stride Delta: +// S1: X = B + i * S +// S2: Y = B + i * S' => X + i * (S' - S) +// +// S1: X = &B[i * S] +// S2: Y = &B[i * S'] => X + i * (S' - S) +// +// PS: Stride delta rewrite on Mul form is usually non-profitable, and Base +// delta rewrite sometimes is profitable, so we do not support them on Mul. +// // This rewriting is in general a good idea. The code patterns we focus on -// usually come from loop unrolling, so (i' - i) * S is likely the same +// usually come from loop unrolling, so the delta is likely the same // across iterations and can be reused. When that happens, the optimized form // takes only one add starting from the second iteration. // @@ -47,19 +67,14 @@ // TODO: // // - Floating point arithmetics when fast math is enabled. -// -// - SLSR may decrease ILP at the architecture level. Targets that are very -// sensitive to ILP may want to disable it. Having SLSR to consider ILP is -// left as future work. -// -// - When (i' - i) is constant but i and i' are not, we could still perform -// SLSR. #include "llvm/Transforms/Scalar/StraightLineStrengthReduce.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" @@ -86,16 +101,24 @@ #include <cstdint> #include <limits> #include <list> +#include <queue> #include <vector> using namespace llvm; using namespace PatternMatch; +#define DEBUG_TYPE "slsr" + static const unsigned UnknownAddressSpace = std::numeric_limits<unsigned>::max(); DEBUG_COUNTER(StraightLineStrengthReduceCounter, "slsr-counter", - "Controls whether rewriteCandidateWithBasis is executed."); + "Controls whether rewriteCandidate is executed."); + +// Only for testing. +static cl::opt<bool> + EnablePoisonReuseGuard("enable-poison-reuse-guard", cl::init(true), + cl::desc("Enable poison-reuse guard")); namespace { @@ -142,15 +165,23 @@ public: GEP, // &B[..][i * S][..] }; + enum DKind { + InvalidDelta, // reserved for the default constructor + IndexDelta, // Delta is a constant from Index + BaseDelta, // Delta is a constant or variable from Base + StrideDelta, // Delta is a constant or variable from Stride + }; + Candidate() = default; Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, - Instruction *I) - : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I) {} + Instruction *I, const SCEV *StrideSCEV) + : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I), + StrideSCEV(StrideSCEV) {} Kind CandidateKind = Invalid; const SCEV *Base = nullptr; - + // TODO: Swap Index and Stride's name. 
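The DKind values above correspond to the three rewrites described in the file header. A plain-integer check of those identities (illustrative only; the pass reasons about IR values and SCEVs, not machine integers):

#include <cassert>
#include <cstdint>

// Worked check of the three delta rewrites:
//   index delta:  B + i'*S  == (B + i*S) + (i' - i)*S
//   base delta:   B' + i*S  == (B + i*S) + (B' - B)
//   stride delta: B + i*S'  == (B + i*S) + i*(S' - S)
int main() {
  int64_t B = 100, Bp = 140, S = 8, Sp = 12, I = 3, Ip = 7;
  int64_t X = B + I * S; // the dominating candidate
  assert(B + Ip * S == X + (Ip - I) * S); // index delta
  assert(Bp + I * S == X + (Bp - B));     // base delta
  assert(B + I * Sp == X + I * (Sp - S)); // stride delta
  return 0;
}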
// Note that Index and Stride of a GEP candidate do not necessarily have the // same integer type. In that case, during rewriting, Stride will be // sign-extended or truncated to Index's type. @@ -177,22 +208,164 @@ public: // Points to the immediate basis of this candidate, or nullptr if we cannot // find any basis for this candidate. Candidate *Basis = nullptr; + + DKind DeltaKind = InvalidDelta; + + // Store SCEV of Stride to compute delta from different strides + const SCEV *StrideSCEV = nullptr; + + // Points to (Y - X) that will be used to rewrite this candidate. + Value *Delta = nullptr; + + /// Cost model: Evaluate the computational efficiency of the candidate. + /// + /// Efficiency levels (higher is better): + /// ZeroInst (5) - [Variable] or [Const] + /// OneInstOneVar (4) - [Variable + Const] or [Variable * Const] + /// OneInstTwoVar (3) - [Variable + Variable] or [Variable * Variable] + /// TwoInstOneVar (2) - [Const + Const * Variable] + /// TwoInstTwoVar (1) - [Variable + Const * Variable] + enum EfficiencyLevel : unsigned { + Unknown = 0, + TwoInstTwoVar = 1, + TwoInstOneVar = 2, + OneInstTwoVar = 3, + OneInstOneVar = 4, + ZeroInst = 5 + }; + + static EfficiencyLevel + getComputationEfficiency(Kind CandidateKind, const ConstantInt *Index, + const Value *Stride, const SCEV *Base = nullptr) { + bool IsConstantBase = false; + bool IsZeroBase = false; + // When evaluating the efficiency of a rewrite, if the Base's SCEV is + // not available, conservatively assume the base is not constant. + if (auto *ConstBase = dyn_cast_or_null<SCEVConstant>(Base)) { + IsConstantBase = true; + IsZeroBase = ConstBase->getValue()->isZero(); + } + + bool IsConstantStride = isa<ConstantInt>(Stride); + bool IsZeroStride = + IsConstantStride && cast<ConstantInt>(Stride)->isZero(); + // All constants + if (IsConstantBase && IsConstantStride) + return ZeroInst; + + // (Base + Index) * Stride + if (CandidateKind == Mul) { + if (IsZeroStride) + return ZeroInst; + if (Index->isZero()) + return (IsConstantStride || IsConstantBase) ? OneInstOneVar + : OneInstTwoVar; + + if (IsConstantBase) + return IsZeroBase && (Index->isOne() || Index->isMinusOne()) + ? ZeroInst + : OneInstOneVar; + + if (IsConstantStride) { + auto *CI = cast<ConstantInt>(Stride); + return (CI->isOne() || CI->isMinusOne()) ? OneInstOneVar + : TwoInstOneVar; + } + return TwoInstTwoVar; + } + + // Base + Index * Stride + assert(CandidateKind == Add || CandidateKind == GEP); + if (Index->isZero() || IsZeroStride) + return ZeroInst; + + bool IsSimpleIndex = Index->isOne() || Index->isMinusOne(); + + if (IsConstantBase) + return IsZeroBase ? (IsSimpleIndex ? ZeroInst : OneInstOneVar) + : (IsSimpleIndex ? OneInstOneVar : TwoInstOneVar); + + if (IsConstantStride) + return IsZeroStride ? ZeroInst : OneInstOneVar; + + if (IsSimpleIndex) + return OneInstTwoVar; + + return TwoInstTwoVar; + } + + // Evaluate if the given delta is profitable to rewrite this candidate. + bool isProfitableRewrite(const Value &Delta, const DKind DeltaKind) const { + // This function cannot accurately evaluate the profit of whole expression + // with context. A candidate (B + I * S) cannot express whether this + // instruction needs to compute on its own (I * S), which may be shared + // with other candidates or may need instructions to compute. + // If the rewritten form has the same strength, still rewrite to + // (X + Delta) since it may expose more CSE opportunities on Delta, as + // unrolled loops usually have identical Delta for each unrolled body. 
+ // + // Note, this function should only be used on Index Delta rewrite. + // Base and Stride delta need context info to evaluate the register + // pressure impact from variable delta. + return getComputationEfficiency(CandidateKind, Index, Stride, Base) <= + getRewriteEfficiency(Delta, DeltaKind); + } + + // Evaluate the rewrite efficiency of this candidate with its Basis + EfficiencyLevel getRewriteEfficiency() const { + return Basis ? getRewriteEfficiency(*Delta, DeltaKind) : Unknown; + } + + // Evaluate the rewrite efficiency of this candidate with a given delta + EfficiencyLevel getRewriteEfficiency(const Value &Delta, + const DKind DeltaKind) const { + switch (DeltaKind) { + case BaseDelta: // [X + Delta] + return getComputationEfficiency( + CandidateKind, + ConstantInt::get(cast<IntegerType>(Delta.getType()), 1), &Delta); + case StrideDelta: // [X + Index * Delta] + return getComputationEfficiency(CandidateKind, Index, &Delta); + case IndexDelta: // [X + Delta * Stride] + return getComputationEfficiency(CandidateKind, + cast<ConstantInt>(&Delta), Stride); + default: + return Unknown; + } + } + + bool isHighEfficiency() const { + return getComputationEfficiency(CandidateKind, Index, Stride, Base) >= + OneInstOneVar; + } + + // Verify that this candidate has valid delta components relative to the + // basis + bool hasValidDelta(const Candidate &Basis) const { + switch (DeltaKind) { + case IndexDelta: + // Index differs, Base and Stride must match + return Base == Basis.Base && StrideSCEV == Basis.StrideSCEV; + case StrideDelta: + // Stride differs, Base and Index must match + return Base == Basis.Base && Index == Basis.Index; + case BaseDelta: + // Base differs, Stride and Index must match + return StrideSCEV == Basis.StrideSCEV && Index == Basis.Index; + default: + return false; + } + } }; bool runOnFunction(Function &F); private: - // Returns true if Basis is a basis for C, i.e., Basis dominates C and they - // share the same base and stride. - bool isBasisFor(const Candidate &Basis, const Candidate &C); - + // Fetch straight-line basis for rewriting C, update C.Basis to point to it, + // and store the delta between C and its Basis in C.Delta. + void setBasisAndDeltaFor(Candidate &C); // Returns whether the candidate can be folded into an addressing mode. - bool isFoldable(const Candidate &C, TargetTransformInfo *TTI, - const DataLayout *DL); - - // Returns true if C is already in a simplest form and not worth being - // rewritten. - bool isSimplestForm(const Candidate &C); + bool isFoldable(const Candidate &C, TargetTransformInfo *TTI); // Checks whether I is in a candidate form. If so, adds all the matching forms // to Candidates, and tries to find the immediate basis for each of them. @@ -216,12 +389,6 @@ private: // Allocate candidates and find bases for GetElementPtr instructions. void allocateCandidatesAndFindBasisForGEP(GetElementPtrInst *GEP); - // A helper function that scales Idx with ElementSize before invoking - // allocateCandidatesAndFindBasis. - void allocateCandidatesAndFindBasisForGEP(const SCEV *B, ConstantInt *Idx, - Value *S, uint64_t ElementSize, - Instruction *I); - // Adds the given form <CT, B, Idx, S> to Candidates, and finds its immediate // basis. void allocateCandidatesAndFindBasis(Candidate::Kind CT, const SCEV *B, @@ -229,13 +396,7 @@ private: Instruction *I); // Rewrites candidate C with respect to Basis. 
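To make the EfficiencyLevel ranking above concrete, here is a minimal sketch of the profitability rule it feeds, applied to the "B + 8*S vs. B + S" example from the comments. The real classification derives the levels from Base, Index and Stride; this sketch takes them as inputs:

#include <cassert>

// Mirrors the patch's EfficiencyLevel ordering (higher is cheaper to compute).
enum EfficiencyLevel {
  Unknown = 0,
  TwoInstTwoVar = 1, // Variable + Const * Variable
  TwoInstOneVar = 2, // Const + Const * Variable
  OneInstTwoVar = 3, // Variable + Variable (or Variable * Variable)
  OneInstOneVar = 4, // Variable + Const (or Variable * Const)
  ZeroInst = 5       // Variable or Const
};

// Rewrite only when the rewritten form is at least as cheap as the original.
static bool profitableToRewrite(EfficiencyLevel Original,
                                EfficiencyLevel Rewritten) {
  return Original <= Rewritten;
}

int main() {
  // With X = B + 8*S dominating: Y = B + S (Variable + Variable) should not
  // become X - 7*S (Variable + Const * Variable), ...
  assert(!profitableToRewrite(OneInstTwoVar, TwoInstTwoVar));
  // ... but Y = B + 9*S can profitably become X + S.
  assert(profitableToRewrite(TwoInstTwoVar, OneInstTwoVar));
  return 0;
}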
- void rewriteCandidateWithBasis(const Candidate &C, const Candidate &Basis); - - // A helper function that factors ArrayIdx to a product of a stride and a - // constant index, and invokes allocateCandidatesAndFindBasis with the - // factorings. - void factorArrayIndex(Value *ArrayIdx, const SCEV *Base, uint64_t ElementSize, - GetElementPtrInst *GEP); + void rewriteCandidate(const Candidate &C); // Emit code that computes the "bump" from Basis to C. static Value *emitBump(const Candidate &Basis, const Candidate &C, @@ -247,12 +408,209 @@ private: TargetTransformInfo *TTI = nullptr; std::list<Candidate> Candidates; - // Temporarily holds all instructions that are unlinked (but not deleted) by - // rewriteCandidateWithBasis. These instructions will be actually removed - // after all rewriting finishes. - std::vector<Instruction *> UnlinkedInstructions; + // Map from SCEV to instructions that represent the value, + // instructions are sorted in depth-first order. + DenseMap<const SCEV *, SmallSetVector<Instruction *, 2>> SCEVToInsts; + + // Record the dependency between instructions. If C.Basis == B, we would have + // {B.Ins -> {C.Ins, ...}}. + MapVector<Instruction *, std::vector<Instruction *>> DependencyGraph; + + // Map between each instruction and its possible candidates. + DenseMap<Instruction *, SmallVector<Candidate *, 3>> RewriteCandidates; + + // All instructions that have candidates sort in topological order based on + // dependency graph, from roots to leaves. + std::vector<Instruction *> SortedCandidateInsts; + + // Record all instructions that are already rewritten and will be removed + // later. + std::vector<Instruction *> DeadInstructions; + + // Classify candidates against Delta kind + class CandidateDictTy { + public: + using CandsTy = SmallVector<Candidate *, 8>; + using BBToCandsTy = DenseMap<const BasicBlock *, CandsTy>; + + private: + // Index delta Basis must have the same (Base, StrideSCEV, Inst.Type) + using IndexDeltaKeyTy = std::tuple<const SCEV *, const SCEV *, Type *>; + DenseMap<IndexDeltaKeyTy, BBToCandsTy> IndexDeltaCandidates; + + // Base delta Basis must have the same (StrideSCEV, Index, Inst.Type) + using BaseDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>; + DenseMap<BaseDeltaKeyTy, BBToCandsTy> BaseDeltaCandidates; + + // Stride delta Basis must have the same (Base, Index, Inst.Type) + using StrideDeltaKeyTy = std::tuple<const SCEV *, ConstantInt *, Type *>; + DenseMap<StrideDeltaKeyTy, BBToCandsTy> StrideDeltaCandidates; + + public: + // TODO: Disable index delta on GEP after we completely move + // from typed GEP to PtrAdd. + const BBToCandsTy *getCandidatesWithDeltaKind(const Candidate &C, + Candidate::DKind K) const { + assert(K != Candidate::InvalidDelta); + if (K == Candidate::IndexDelta) { + IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, C.Ins->getType()); + auto It = IndexDeltaCandidates.find(IndexDeltaKey); + if (It != IndexDeltaCandidates.end()) + return &It->second; + } else if (K == Candidate::BaseDelta) { + BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, C.Ins->getType()); + auto It = BaseDeltaCandidates.find(BaseDeltaKey); + if (It != BaseDeltaCandidates.end()) + return &It->second; + } else { + assert(K == Candidate::StrideDelta); + StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, C.Ins->getType()); + auto It = StrideDeltaCandidates.find(StrideDeltaKey); + if (It != StrideDeltaCandidates.end()) + return &It->second; + } + return nullptr; + } + + // Pointers to C must remain valid until CandidateDict is cleared. 
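The lookup above is the point of the three dictionaries: candidates are bucketed under the components that must be equal for each delta kind, so finding potential bases becomes a keyed lookup plus a dominator walk instead of a scan over every candidate. A toy sketch with tuple keys standing in for the (SCEV, ConstantInt, Type) pointers:

#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <vector>

using Key = std::tuple<std::string, std::string, std::string>;

int main() {
  std::map<Key, std::vector<std::string>> IndexDeltaBuckets;
  Key BS{"B", "S", "i64"};   // bucket shared by X and Y: same Base and Stride
  Key B2S{"B2", "S", "i64"}; // different Base -> different bucket
  // X = B + 1*S and Y = B + 5*S land in the same bucket; they differ only in
  // Index, so either can serve as an index-delta basis for the other.
  IndexDeltaBuckets[BS].push_back("X");
  IndexDeltaBuckets[BS].push_back("Y");
  IndexDeltaBuckets[B2S].push_back("Z");

  assert(IndexDeltaBuckets[BS].size() == 2);
  assert(IndexDeltaBuckets.count(B2S) == 1);
  return 0;
}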
+ void add(Candidate &C) { + Type *ValueType = C.Ins->getType(); + BasicBlock *BB = C.Ins->getParent(); + IndexDeltaKeyTy IndexDeltaKey(C.Base, C.StrideSCEV, ValueType); + BaseDeltaKeyTy BaseDeltaKey(C.StrideSCEV, C.Index, ValueType); + StrideDeltaKeyTy StrideDeltaKey(C.Base, C.Index, ValueType); + IndexDeltaCandidates[IndexDeltaKey][BB].push_back(&C); + BaseDeltaCandidates[BaseDeltaKey][BB].push_back(&C); + StrideDeltaCandidates[StrideDeltaKey][BB].push_back(&C); + } + // Remove all mappings from set + void clear() { + IndexDeltaCandidates.clear(); + BaseDeltaCandidates.clear(); + StrideDeltaCandidates.clear(); + } + } CandidateDict; + + const SCEV *getAndRecordSCEV(Value *V) { + auto *S = SE->getSCEV(V); + if (isa<Instruction>(V) && !(isa<SCEVCouldNotCompute>(S) || + isa<SCEVUnknown>(S) || isa<SCEVConstant>(S))) + SCEVToInsts[S].insert(cast<Instruction>(V)); + + return S; + } + + bool candidatePredicate(Candidate *Basis, Candidate &C, Candidate::DKind K); + + bool searchFrom(const CandidateDictTy::BBToCandsTy &BBToCands, Candidate &C, + Candidate::DKind K); + + // Get the nearest instruction before CI that represents the value of S, + // return nullptr if no instruction is associated with S or S is not a + // reusable expression. + Value *getNearestValueOfSCEV(const SCEV *S, const Instruction *CI) const { + if (isa<SCEVCouldNotCompute>(S)) + return nullptr; + + if (auto *SU = dyn_cast<SCEVUnknown>(S)) + return SU->getValue(); + if (auto *SC = dyn_cast<SCEVConstant>(S)) + return SC->getValue(); + + auto It = SCEVToInsts.find(S); + if (It == SCEVToInsts.end()) + return nullptr; + + // Instructions are sorted in depth-first order, so search for the nearest + // instruction by walking the list in reverse order. + for (Instruction *I : reverse(It->second)) + if (DT->dominates(I, CI)) + return I; + + return nullptr; + } + + struct DeltaInfo { + Candidate *Cand; + Candidate::DKind DeltaKind; + Value *Delta; + + DeltaInfo() + : Cand(nullptr), DeltaKind(Candidate::InvalidDelta), Delta(nullptr) {} + DeltaInfo(Candidate *Cand, Candidate::DKind DeltaKind, Value *Delta) + : Cand(Cand), DeltaKind(DeltaKind), Delta(Delta) {} + operator bool() const { return Cand != nullptr; } + }; + + friend raw_ostream &operator<<(raw_ostream &OS, const DeltaInfo &DI); + + DeltaInfo compressPath(Candidate &C, Candidate *Basis) const; + + Candidate *pickRewriteCandidate(Instruction *I) const; + void sortCandidateInstructions(); + Value *getDelta(const Candidate &C, const Candidate &Basis, + Candidate::DKind K) const; + static bool isSimilar(Candidate &C, Candidate &Basis, Candidate::DKind K); + + // Add Basis -> C in DependencyGraph and propagate + // C.Stride and C.Delta's dependency to C + void addDependency(Candidate &C, Candidate *Basis) { + if (Basis) + DependencyGraph[Basis->Ins].emplace_back(C.Ins); + + // If any candidate of Inst has a basis, then Inst will be rewritten, + // C must be rewritten after rewriting Inst, so we need to propagate + // the dependency to C + auto PropagateDependency = [&](Instruction *Inst) { + if (auto CandsIt = RewriteCandidates.find(Inst); + CandsIt != RewriteCandidates.end() && + llvm::any_of(CandsIt->second, + [](Candidate *Cand) { return Cand->Basis; })) + DependencyGraph[Inst].emplace_back(C.Ins); + }; + + // If C has a variable delta and the delta is a candidate, + // propagate its dependency to C + if (auto *DeltaInst = dyn_cast_or_null<Instruction>(C.Delta)) + PropagateDependency(DeltaInst); + + // If the stride is a candidate, propagate its dependency to C + if (auto 
*StrideInst = dyn_cast<Instruction>(C.Stride)) + PropagateDependency(StrideInst); + }; }; +inline raw_ostream &operator<<(raw_ostream &OS, + const StraightLineStrengthReduce::Candidate &C) { + OS << "Ins: " << *C.Ins << "\n Base: " << *C.Base + << "\n Index: " << *C.Index << "\n Stride: " << *C.Stride + << "\n StrideSCEV: " << *C.StrideSCEV; + if (C.Basis) + OS << "\n Delta: " << *C.Delta << "\n Basis: \n [ " << *C.Basis << " ]"; + return OS; +} + +[[maybe_unused]] LLVM_DUMP_METHOD inline raw_ostream & +operator<<(raw_ostream &OS, const StraightLineStrengthReduce::DeltaInfo &DI) { + OS << "Cand: " << *DI.Cand << "\n"; + OS << "Delta Kind: "; + switch (DI.DeltaKind) { + case StraightLineStrengthReduce::Candidate::IndexDelta: + OS << "Index"; + break; + case StraightLineStrengthReduce::Candidate::BaseDelta: + OS << "Base"; + break; + case StraightLineStrengthReduce::Candidate::StrideDelta: + OS << "Stride"; + break; + default: + break; + } + OS << "\nDelta: " << *DI.Delta; + return OS; +} + } // end anonymous namespace char StraightLineStrengthReduceLegacyPass::ID = 0; @@ -269,17 +627,301 @@ FunctionPass *llvm::createStraightLineStrengthReducePass() { return new StraightLineStrengthReduceLegacyPass(); } -bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis, - const Candidate &C) { - return (Basis.Ins != C.Ins && // skip the same instruction - // They must have the same type too. Basis.Base == C.Base - // doesn't guarantee their types are the same (PR23975). - Basis.Ins->getType() == C.Ins->getType() && - // Basis must dominate C in order to rewrite C with respect to Basis. - DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) && - // They share the same base, stride, and candidate kind. - Basis.Base == C.Base && Basis.Stride == C.Stride && - Basis.CandidateKind == C.CandidateKind); +// A helper function that unifies the bitwidth of A and B. +static void unifyBitWidth(APInt &A, APInt &B) { + if (A.getBitWidth() < B.getBitWidth()) + A = A.sext(B.getBitWidth()); + else if (A.getBitWidth() > B.getBitWidth()) + B = B.sext(A.getBitWidth()); +} + +Value *StraightLineStrengthReduce::getDelta(const Candidate &C, + const Candidate &Basis, + Candidate::DKind K) const { + if (K == Candidate::IndexDelta) { + APInt Idx = C.Index->getValue(); + APInt BasisIdx = Basis.Index->getValue(); + unifyBitWidth(Idx, BasisIdx); + APInt IndexDelta = Idx - BasisIdx; + IntegerType *DeltaType = + IntegerType::get(C.Ins->getContext(), IndexDelta.getBitWidth()); + return ConstantInt::get(DeltaType, IndexDelta); + } else if (K == Candidate::BaseDelta || K == Candidate::StrideDelta) { + const SCEV *BasisPart = + (K == Candidate::BaseDelta) ? Basis.Base : Basis.StrideSCEV; + const SCEV *CandPart = (K == Candidate::BaseDelta) ? C.Base : C.StrideSCEV; + const SCEV *Diff = SE->getMinusSCEV(CandPart, BasisPart); + return getNearestValueOfSCEV(Diff, C.Ins); + } + return nullptr; +} + +bool StraightLineStrengthReduce::isSimilar(Candidate &C, Candidate &Basis, + Candidate::DKind K) { + bool SameType = false; + switch (K) { + case Candidate::StrideDelta: + SameType = C.StrideSCEV->getType() == Basis.StrideSCEV->getType(); + break; + case Candidate::BaseDelta: + SameType = C.Base->getType() == Basis.Base->getType(); + break; + case Candidate::IndexDelta: + SameType = true; + break; + default:; + } + return SameType && Basis.Ins != C.Ins && + Basis.CandidateKind == C.CandidateKind; +} + +// Try to find a Delta that C can reuse Basis to rewrite. +// Set C.Delta, C.Basis, and C.DeltaKind if found. 
+// Return true if found a constant delta. +// Return false if not found or the delta is not a constant. +bool StraightLineStrengthReduce::candidatePredicate(Candidate *Basis, + Candidate &C, + Candidate::DKind K) { + SmallVector<Instruction *> DropPoisonGeneratingInsts; + // Ensure the IR of Basis->Ins is not more poisonous than its SCEV. + if (!isSimilar(C, *Basis, K) || + (EnablePoisonReuseGuard && + !SE->canReuseInstruction(SE->getSCEV(Basis->Ins), Basis->Ins, + DropPoisonGeneratingInsts))) + return false; + + assert(DT->dominates(Basis->Ins, C.Ins)); + Value *Delta = getDelta(C, *Basis, K); + if (!Delta) + return false; + + // IndexDelta rewrite is not always profitable, e.g., + // X = B + 8 * S + // Y = B + S, + // rewriting Y to X - 7 * S is probably a bad idea. + // So, we need to check if the rewrite form's computation efficiency + // is better than the original form. + if (K == Candidate::IndexDelta && + !C.isProfitableRewrite(*Delta, Candidate::IndexDelta)) + return false; + + // If there is a Delta that we can reuse Basis to rewrite C, + // clean up DropPoisonGeneratingInsts returned by successful + // SE->canReuseInstruction() + for (Instruction *I : DropPoisonGeneratingInsts) + I->dropPoisonGeneratingAnnotations(); + + // Record delta if none has been found yet, or the new delta is + // a constant that is better than the existing delta. + if (!C.Delta || isa<ConstantInt>(Delta)) { + C.Delta = Delta; + C.Basis = Basis; + C.DeltaKind = K; + } + return isa<ConstantInt>(C.Delta); +} + +// return true if find a Basis with constant delta and stop searching, +// return false if did not find a Basis or the delta is not a constant +// and continue searching for a Basis with constant delta +bool StraightLineStrengthReduce::searchFrom( + const CandidateDictTy::BBToCandsTy &BBToCands, Candidate &C, + Candidate::DKind K) { + + // Stride delta rewrite on Mul form is usually non-profitable, and Base + // delta rewrite sometimes is profitable, so we do not support them on Mul. + if (C.CandidateKind == Candidate::Mul && K != Candidate::IndexDelta) + return false; + + // Search dominating candidates by walking the immediate-dominator chain + // from the candidate's defining block upward. Visiting blocks in this + // order ensures we prefer the closest dominating basis. + const BasicBlock *BB = C.Ins->getParent(); + while (BB) { + auto It = BBToCands.find(BB); + if (It != BBToCands.end()) + for (Candidate *Basis : reverse(It->second)) + if (candidatePredicate(Basis, C, K)) + return true; + + const DomTreeNode *Node = DT->getNode(BB); + if (!Node) + break; + Node = Node->getIDom(); + BB = Node ? 
Node->getBlock() : nullptr; + } + return false; +} + +void StraightLineStrengthReduce::setBasisAndDeltaFor(Candidate &C) { + if (const auto *BaseDeltaCandidates = + CandidateDict.getCandidatesWithDeltaKind(C, Candidate::BaseDelta)) + if (searchFrom(*BaseDeltaCandidates, C, Candidate::BaseDelta)) { + LLVM_DEBUG(dbgs() << "Found delta from Base: " << *C.Delta << "\n"); + return; + } + + if (const auto *StrideDeltaCandidates = + CandidateDict.getCandidatesWithDeltaKind(C, Candidate::StrideDelta)) + if (searchFrom(*StrideDeltaCandidates, C, Candidate::StrideDelta)) { + LLVM_DEBUG(dbgs() << "Found delta from Stride: " << *C.Delta << "\n"); + return; + } + + if (const auto *IndexDeltaCandidates = + CandidateDict.getCandidatesWithDeltaKind(C, Candidate::IndexDelta)) + if (searchFrom(*IndexDeltaCandidates, C, Candidate::IndexDelta)) { + LLVM_DEBUG(dbgs() << "Found delta from Index: " << *C.Delta << "\n"); + return; + } + + // If we did not find a constant delta, we might have found a variable delta + if (C.Delta) { + LLVM_DEBUG({ + dbgs() << "Found delta from "; + if (C.DeltaKind == Candidate::BaseDelta) + dbgs() << "Base: "; + else + dbgs() << "Stride: "; + dbgs() << *C.Delta << "\n"; + }); + assert(C.DeltaKind != Candidate::InvalidDelta && C.Basis); + } +} + +// Compress the path from `Basis` to the deepest Basis in the Basis chain +// to avoid non-profitable data dependency and improve ILP. +// X = A + 1 +// Y = X + 1 +// Z = Y + 1 +// -> +// X = A + 1 +// Y = A + 2 +// Z = A + 3 +// Return the delta info for C aginst the new Basis +auto StraightLineStrengthReduce::compressPath(Candidate &C, + Candidate *Basis) const + -> DeltaInfo { + if (!Basis || !Basis->Basis || C.CandidateKind == Candidate::Mul) + return {}; + Candidate *Root = Basis; + Value *NewDelta = nullptr; + auto NewKind = Candidate::InvalidDelta; + + while (Root->Basis) { + Candidate *NextRoot = Root->Basis; + if (C.Base == NextRoot->Base && C.StrideSCEV == NextRoot->StrideSCEV && + isSimilar(C, *NextRoot, Candidate::IndexDelta)) { + ConstantInt *CI = + cast<ConstantInt>(getDelta(C, *NextRoot, Candidate::IndexDelta)); + if (CI->isZero() || CI->isOne() || isa<SCEVConstant>(C.StrideSCEV)) { + Root = NextRoot; + NewKind = Candidate::IndexDelta; + NewDelta = CI; + continue; + } + } + + const SCEV *CandPart = nullptr; + const SCEV *BasisPart = nullptr; + auto CurrKind = Candidate::InvalidDelta; + if (C.Base == NextRoot->Base && C.Index == NextRoot->Index) { + CandPart = C.StrideSCEV; + BasisPart = NextRoot->StrideSCEV; + CurrKind = Candidate::StrideDelta; + } else if (C.StrideSCEV == NextRoot->StrideSCEV && + C.Index == NextRoot->Index) { + CandPart = C.Base; + BasisPart = NextRoot->Base; + CurrKind = Candidate::BaseDelta; + } else + break; + + assert(CandPart && BasisPart); + if (!isSimilar(C, *NextRoot, CurrKind)) + break; + + if (auto DeltaVal = + dyn_cast<SCEVConstant>(SE->getMinusSCEV(CandPart, BasisPart))) { + Root = NextRoot; + NewDelta = DeltaVal->getValue(); + NewKind = CurrKind; + } else + break; + } + + if (Root != Basis) { + assert(NewKind != Candidate::InvalidDelta && NewDelta); + LLVM_DEBUG(dbgs() << "Found new Basis with " << *NewDelta + << " from path compression.\n"); + return {Root, NewKind, NewDelta}; + } + + return {}; +} + +// Topologically sort candidate instructions based on their relationship in +// dependency graph. 
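The function defined next is essentially Kahn's algorithm over the dependency graph, so every basis (and any stride or delta that will itself be rewritten) is processed before the candidates that reuse it. A self-contained sketch of that algorithm on string node names; the pass does the same over Instruction pointers:

#include <cassert>
#include <map>
#include <queue>
#include <string>
#include <vector>

int main() {
  // X is a basis for Y and Z; Y is also a basis for Z.
  std::map<std::string, std::vector<std::string>> Graph = {
      {"X", {"Y", "Z"}}, {"Y", {"Z"}}, {"Z", {}}};

  std::map<std::string, int> InDegree;
  for (auto &KV : Graph) {
    InDegree[KV.first]; // ensure every node has an entry
    for (auto &Succ : KV.second)
      ++InDegree[Succ];
  }

  std::queue<std::string> Worklist;
  for (auto &KV : InDegree)
    if (KV.second == 0)
      Worklist.push(KV.first);

  std::vector<std::string> Order;
  while (!Worklist.empty()) {
    std::string N = Worklist.front();
    Worklist.pop();
    Order.push_back(N);
    for (auto &Succ : Graph[N])
      if (--InDegree[Succ] == 0)
        Worklist.push(Succ);
  }

  assert((Order == std::vector<std::string>{"X", "Y", "Z"}));
  return 0;
}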
+void StraightLineStrengthReduce::sortCandidateInstructions() { + SortedCandidateInsts.clear(); + // An instruction may have multiple candidates that get different Basis + // instructions, and each candidate can get dependencies from Basis and + // Stride when Stride will also be rewritten by SLSR. Hence, an instruction + // may have multiple dependencies. Use InDegree to ensure all dependencies + // processed before processing itself. + DenseMap<Instruction *, int> InDegree; + for (auto &KV : DependencyGraph) { + InDegree.try_emplace(KV.first, 0); + + for (auto *Child : KV.second) { + InDegree[Child]++; + } + } + std::queue<Instruction *> WorkList; + DenseSet<Instruction *> Visited; + + for (auto &KV : DependencyGraph) + if (InDegree[KV.first] == 0) + WorkList.push(KV.first); + + while (!WorkList.empty()) { + Instruction *I = WorkList.front(); + WorkList.pop(); + if (!Visited.insert(I).second) + continue; + + SortedCandidateInsts.push_back(I); + + for (auto *Next : DependencyGraph[I]) { + auto &Degree = InDegree[Next]; + if (--Degree == 0) + WorkList.push(Next); + } + } + + assert(SortedCandidateInsts.size() == DependencyGraph.size() && + "Dependency graph should not have cycles"); +} + +auto StraightLineStrengthReduce::pickRewriteCandidate(Instruction *I) const + -> Candidate * { + // Return the candidate of instruction I that has the highest profit. + auto It = RewriteCandidates.find(I); + if (It == RewriteCandidates.end()) + return nullptr; + + Candidate *BestC = nullptr; + auto BestEfficiency = Candidate::Unknown; + for (Candidate *C : reverse(It->second)) + if (C->Basis) { + auto Efficiency = C->getRewriteEfficiency(); + if (Efficiency > BestEfficiency) { + BestEfficiency = Efficiency; + BestC = C; + } + } + + return BestC; } static bool isGEPFoldable(GetElementPtrInst *GEP, @@ -299,8 +941,7 @@ static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, } bool StraightLineStrengthReduce::isFoldable(const Candidate &C, - TargetTransformInfo *TTI, - const DataLayout *DL) { + TargetTransformInfo *TTI) { if (C.CandidateKind == Candidate::Add) return isAddFoldable(C.Base, C.Index, C.Stride, TTI); if (C.CandidateKind == Candidate::GEP) @@ -308,75 +949,39 @@ bool StraightLineStrengthReduce::isFoldable(const Candidate &C, return false; } -// Returns true if GEP has zero or one non-zero index. -static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP) { - unsigned NumNonZeroIndices = 0; - for (Use &Idx : GEP->indices()) { - ConstantInt *ConstIdx = dyn_cast<ConstantInt>(Idx); - if (ConstIdx == nullptr || !ConstIdx->isZero()) - ++NumNonZeroIndices; - } - return NumNonZeroIndices <= 1; -} - -bool StraightLineStrengthReduce::isSimplestForm(const Candidate &C) { - if (C.CandidateKind == Candidate::Add) { - // B + 1 * S or B + (-1) * S - return C.Index->isOne() || C.Index->isMinusOne(); - } - if (C.CandidateKind == Candidate::Mul) { - // (B + 0) * S - return C.Index->isZero(); - } - if (C.CandidateKind == Candidate::GEP) { - // (char*)B + S or (char*)B - S - return ((C.Index->isOne() || C.Index->isMinusOne()) && - hasOnlyOneNonZeroIndex(cast<GetElementPtrInst>(C.Ins))); - } - return false; -} - -// TODO: We currently implement an algorithm whose time complexity is linear in -// the number of existing candidates. However, we could do better by using -// ScopedHashTable. Specifically, while traversing the dominator tree, we could -// maintain all the candidates that dominate the basic block being traversed in -// a ScopedHashTable. 
This hash table is indexed by the base and the stride of -// a candidate. Therefore, finding the immediate basis of a candidate boils down -// to one hash-table look up. void StraightLineStrengthReduce::allocateCandidatesAndFindBasis( Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, Instruction *I) { - Candidate C(CT, B, Idx, S, I); - // SLSR can complicate an instruction in two cases: - // - // 1. If we can fold I into an addressing mode, computing I is likely free or - // takes only one instruction. - // - // 2. I is already in a simplest form. For example, when - // X = B + 8 * S - // Y = B + S, - // rewriting Y to X - 7 * S is probably a bad idea. + // Record the SCEV of S that we may use it as a variable delta. + // Ensure that we rewrite C with a existing IR that reproduces delta value. + + Candidate C(CT, B, Idx, S, I, getAndRecordSCEV(S)); + // If we can fold I into an addressing mode, computing I is likely free or + // takes only one instruction. So, we don't need to analyze or rewrite it. // - // In the above cases, we still add I to the candidate list so that I can be - // the basis of other candidates, but we leave I's basis blank so that I - // won't be rewritten. - if (!isFoldable(C, TTI, DL) && !isSimplestForm(C)) { - // Try to compute the immediate basis of C. - unsigned NumIterations = 0; - // Limit the scan radius to avoid running in quadratice time. - static const unsigned MaxNumIterations = 50; - for (auto Basis = Candidates.rbegin(); - Basis != Candidates.rend() && NumIterations < MaxNumIterations; - ++Basis, ++NumIterations) { - if (isBasisFor(*Basis, C)) { - C.Basis = &(*Basis); - break; - } + // Currently, this algorithm can at best optimize complex computations into + // a `variable +/* constant` form. However, some targets have stricter + // constraints on the their addressing mode. + // For example, a `variable + constant` can only be folded to an addressing + // mode if the constant falls within a certain range. + // So, we also check if the instruction is already high efficient enough + // for the strength reduction algorithm. + if (!isFoldable(C, TTI) && !C.isHighEfficiency()) { + setBasisAndDeltaFor(C); + + // Compress unnecessary rewrite to improve ILP + if (auto Res = compressPath(C, C.Basis)) { + C.Basis = Res.Cand; + C.DeltaKind = Res.DeltaKind; + C.Delta = Res.Delta; } } // Regardless of whether we find a basis for C, we need to push C to the // candidate list so that it can be the basis of other candidates. + LLVM_DEBUG(dbgs() << "Allocated Candidate: " << C << "\n"); Candidates.push_back(C); + RewriteCandidates[C.Ins].push_back(&Candidates.back()); + CandidateDict.add(Candidates.back()); } void StraightLineStrengthReduce::allocateCandidatesAndFindBasis( @@ -476,54 +1081,6 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul( } void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( - const SCEV *B, ConstantInt *Idx, Value *S, uint64_t ElementSize, - Instruction *I) { - // I = B + sext(Idx *nsw S) * ElementSize - // = B + (sext(Idx) * sext(S)) * ElementSize - // = B + (sext(Idx) * ElementSize) * sext(S) - // Casting to IntegerType is safe because we skipped vector GEPs. 
- IntegerType *PtrIdxTy = cast<IntegerType>(DL->getIndexType(I->getType())); - ConstantInt *ScaledIdx = ConstantInt::get( - PtrIdxTy, Idx->getSExtValue() * (int64_t)ElementSize, true); - allocateCandidatesAndFindBasis(Candidate::GEP, B, ScaledIdx, S, I); -} - -void StraightLineStrengthReduce::factorArrayIndex(Value *ArrayIdx, - const SCEV *Base, - uint64_t ElementSize, - GetElementPtrInst *GEP) { - // At least, ArrayIdx = ArrayIdx *nsw 1. - allocateCandidatesAndFindBasisForGEP( - Base, ConstantInt::get(cast<IntegerType>(ArrayIdx->getType()), 1), - ArrayIdx, ElementSize, GEP); - Value *LHS = nullptr; - ConstantInt *RHS = nullptr; - // One alternative is matching the SCEV of ArrayIdx instead of ArrayIdx - // itself. This would allow us to handle the shl case for free. However, - // matching SCEVs has two issues: - // - // 1. this would complicate rewriting because the rewriting procedure - // would have to translate SCEVs back to IR instructions. This translation - // is difficult when LHS is further evaluated to a composite SCEV. - // - // 2. ScalarEvolution is designed to be control-flow oblivious. It tends - // to strip nsw/nuw flags which are critical for SLSR to trace into - // sext'ed multiplication. - if (match(ArrayIdx, m_NSWMul(m_Value(LHS), m_ConstantInt(RHS)))) { - // SLSR is currently unsafe if i * S may overflow. - // GEP = Base + sext(LHS *nsw RHS) * ElementSize - allocateCandidatesAndFindBasisForGEP(Base, RHS, LHS, ElementSize, GEP); - } else if (match(ArrayIdx, m_NSWShl(m_Value(LHS), m_ConstantInt(RHS)))) { - // GEP = Base + sext(LHS <<nsw RHS) * ElementSize - // = Base + sext(LHS *nsw (1 << RHS)) * ElementSize - APInt One(RHS->getBitWidth(), 1); - ConstantInt *PowerOf2 = - ConstantInt::get(RHS->getContext(), One << RHS->getValue()); - allocateCandidatesAndFindBasisForGEP(Base, PowerOf2, LHS, ElementSize, GEP); - } -} - -void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( GetElementPtrInst *GEP) { // TODO: handle vector GEPs if (GEP->getType()->isVectorTy()) @@ -546,11 +1103,14 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs); Value *ArrayIdx = GEP->getOperand(I); uint64_t ElementSize = GTI.getSequentialElementStride(*DL); + IntegerType *PtrIdxTy = cast<IntegerType>(DL->getIndexType(GEP->getType())); + ConstantInt *ElementSizeIdx = ConstantInt::get(PtrIdxTy, ElementSize, true); if (ArrayIdx->getType()->getIntegerBitWidth() <= DL->getIndexSizeInBits(GEP->getAddressSpace())) { // Skip factoring if ArrayIdx is wider than the index size, because // ArrayIdx is implicitly truncated to the index size. - factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP); + allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx, + ArrayIdx, GEP); } // When ArrayIdx is the sext of a value, we try to factor that value as // well. Handling this case is important because array indices are @@ -561,118 +1121,159 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( DL->getIndexSizeInBits(GEP->getAddressSpace())) { // Skip factoring if TruncatedArrayIdx is wider than the pointer size, // because TruncatedArrayIdx is implicitly truncated to the pointer size. - factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP); + allocateCandidatesAndFindBasis(Candidate::GEP, BaseExpr, ElementSizeIdx, + TruncatedArrayIdx, GEP); } IndexExprs[I - 1] = OrigIndexExpr; } } -// A helper function that unifies the bitwidth of A and B. 
-static void unifyBitWidth(APInt &A, APInt &B) { - if (A.getBitWidth() < B.getBitWidth()) - A = A.sext(B.getBitWidth()); - else if (A.getBitWidth() > B.getBitWidth()) - B = B.sext(A.getBitWidth()); -} - Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis, const Candidate &C, IRBuilder<> &Builder, const DataLayout *DL) { - APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue(); - unifyBitWidth(Idx, BasisIdx); - APInt IndexOffset = Idx - BasisIdx; - - // Compute Bump = C - Basis = (i' - i) * S. - // Common case 1: if (i' - i) is 1, Bump = S. - if (IndexOffset == 1) - return C.Stride; - // Common case 2: if (i' - i) is -1, Bump = -S. - if (IndexOffset.isAllOnes()) - return Builder.CreateNeg(C.Stride); - - // Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may - // have different bit widths. - IntegerType *DeltaType = - IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth()); - Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType); - if (IndexOffset.isPowerOf2()) { - // If (i' - i) is a power of 2, Bump = sext/trunc(S) << log(i' - i). - ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2()); - return Builder.CreateShl(ExtendedStride, Exponent); + auto CreateMul = [&](Value *LHS, Value *RHS) { + if (ConstantInt *CR = dyn_cast<ConstantInt>(RHS)) { + const APInt &ConstRHS = CR->getValue(); + IntegerType *DeltaType = + IntegerType::get(C.Ins->getContext(), ConstRHS.getBitWidth()); + if (ConstRHS.isPowerOf2()) { + ConstantInt *Exponent = + ConstantInt::get(DeltaType, ConstRHS.logBase2()); + return Builder.CreateShl(LHS, Exponent); + } + if (ConstRHS.isNegatedPowerOf2()) { + ConstantInt *Exponent = + ConstantInt::get(DeltaType, (-ConstRHS).logBase2()); + return Builder.CreateNeg(Builder.CreateShl(LHS, Exponent)); + } + } + + return Builder.CreateMul(LHS, RHS); + }; + + Value *Delta = C.Delta; + // If Delta is 0, C is a fully redundant of C.Basis, + // just replace C.Ins with Basis.Ins + if (ConstantInt *CI = dyn_cast<ConstantInt>(Delta); + CI && CI->getValue().isZero()) + return nullptr; + + if (C.DeltaKind == Candidate::IndexDelta) { + APInt IndexDelta = cast<ConstantInt>(C.Delta)->getValue(); + // IndexDelta + // X = B + i * S + // Y = B + i` * S + // = B + (i + IndexDelta) * S + // = B + i * S + IndexDelta * S + // = X + IndexDelta * S + // Bump = (i' - i) * S + + // Common case 1: if (i' - i) is 1, Bump = S. + if (IndexDelta == 1) + return C.Stride; + // Common case 2: if (i' - i) is -1, Bump = -S. + if (IndexDelta.isAllOnes()) + return Builder.CreateNeg(C.Stride); + + IntegerType *DeltaType = + IntegerType::get(Basis.Ins->getContext(), IndexDelta.getBitWidth()); + Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType); + + return CreateMul(ExtendedStride, C.Delta); } - if (IndexOffset.isNegatedPowerOf2()) { - // If (i - i') is a power of 2, Bump = -sext/trunc(S) << log(i' - i). - ConstantInt *Exponent = - ConstantInt::get(DeltaType, (-IndexOffset).logBase2()); - return Builder.CreateNeg(Builder.CreateShl(ExtendedStride, Exponent)); + + assert(C.DeltaKind == Candidate::StrideDelta || + C.DeltaKind == Candidate::BaseDelta); + assert(C.CandidateKind != Candidate::Mul); + // StrideDelta + // X = B + i * S + // Y = B + i * S' + // = B + i * (S + StrideDelta) + // = B + i * S + i * StrideDelta + // = X + i * StrideDelta + // Bump = i * (S' - S) + // + // BaseDelta + // X = B + i * S + // Y = B' + i * S + // = (B + BaseDelta) + i * S + // = X + BaseDelta + // Bump = (B' - B). 
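The CreateMul helper earlier in emitBump, and the IndexDelta path above, strength-reduce a multiply by a constant +/-2^k into a shift (negated for the negative case). A plain-integer sketch of that folding:

#include <cassert>
#include <cstdint>

static bool isPowerOf2(uint64_t X) { return X && (X & (X - 1)) == 0; }

static unsigned log2u(uint64_t X) {
  unsigned L = 0;
  while (X >>= 1)
    ++L;
  return L;
}

static int64_t mulByConst(int64_t V, int64_t C) {
  if (C > 0 && isPowerOf2(static_cast<uint64_t>(C)))
    return V << log2u(static_cast<uint64_t>(C));
  if (C < 0 && isPowerOf2(static_cast<uint64_t>(-C)))
    return -(V << log2u(static_cast<uint64_t>(-C)));
  return V * C; // not a power of two: keep the multiply
}

int main() {
  assert(mulByConst(7, 8) == 56);   // 7 << 3
  assert(mulByConst(7, -8) == -56); // -(7 << 3)
  assert(mulByConst(7, 6) == 42);   // plain multiply
  return 0;
}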
+  Value *Bump = C.Delta;
+  if (C.DeltaKind == Candidate::StrideDelta) {
+    // If this value is consumed by a GEP, promote StrideDelta before doing
+    // StrideDelta * Index to ensure the same semantics as the original GEP.
+    if (C.CandidateKind == Candidate::GEP) {
+      auto *GEP = cast<GetElementPtrInst>(C.Ins);
+      Type *NewScalarIndexTy =
+          DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
+      Bump = Builder.CreateSExtOrTrunc(Bump, NewScalarIndexTy);
+    }
+    if (!C.Index->isOne()) {
+      Value *ExtendedIndex =
+          Builder.CreateSExtOrTrunc(C.Index, Bump->getType());
+      Bump = CreateMul(Bump, ExtendedIndex);
+    }
   }
-  Constant *Delta = ConstantInt::get(DeltaType, IndexOffset);
-  return Builder.CreateMul(ExtendedStride, Delta);
+  return Bump;
 }
 
-void StraightLineStrengthReduce::rewriteCandidateWithBasis(
-    const Candidate &C, const Candidate &Basis) {
+void StraightLineStrengthReduce::rewriteCandidate(const Candidate &C) {
   if (!DebugCounter::shouldExecute(StraightLineStrengthReduceCounter))
     return;
 
-  assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base &&
-         C.Stride == Basis.Stride);
-  // We run rewriteCandidateWithBasis on all candidates in a post-order, so the
-  // basis of a candidate cannot be unlinked before the candidate.
-  assert(Basis.Ins->getParent() != nullptr && "the basis is unlinked");
-
-  // An instruction can correspond to multiple candidates. Therefore, instead of
-  // simply deleting an instruction when we rewrite it, we mark its parent as
-  // nullptr (i.e. unlink it) so that we can skip the candidates whose
-  // instruction is already rewritten.
-  if (!C.Ins->getParent())
-    return;
+  const Candidate &Basis = *C.Basis;
+  assert(C.Delta && C.CandidateKind == Basis.CandidateKind &&
+         C.hasValidDelta(Basis));
 
   IRBuilder<> Builder(C.Ins);
   Value *Bump = emitBump(Basis, C, Builder, DL);
   Value *Reduced = nullptr; // equivalent to but weaker than C.Ins
-  switch (C.CandidateKind) {
-  case Candidate::Add:
-  case Candidate::Mul: {
-    // C = Basis + Bump
-    Value *NegBump;
-    if (match(Bump, m_Neg(m_Value(NegBump)))) {
-      // If Bump is a neg instruction, emit C = Basis - (-Bump).
-      Reduced = Builder.CreateSub(Basis.Ins, NegBump);
-      // We only use the negative argument of Bump, and Bump itself may be
-      // trivially dead.
-      RecursivelyDeleteTriviallyDeadInstructions(Bump);
-    } else {
-      // It's tempting to preserve nsw on Bump and/or Reduced. However, it's
-      // usually unsound, e.g.,
-      //
-      // X = (-2 +nsw 1) *nsw INT_MAX
-      // Y = (-2 +nsw 3) *nsw INT_MAX
-      // =>
-      // Y = X + 2 * INT_MAX
-      //
-      // Neither + and * in the resultant expression are nsw.
-      Reduced = Builder.CreateAdd(Basis.Ins, Bump);
+  // If Delta is 0, C is fully redundant with Basis and Bump is nullptr;
+  // just replace C.Ins with Basis.Ins.
+  if (!Bump)
+    Reduced = Basis.Ins;
+  else {
+    switch (C.CandidateKind) {
+    case Candidate::Add:
+    case Candidate::Mul: {
+      // C = Basis + Bump
+      Value *NegBump;
+      if (match(Bump, m_Neg(m_Value(NegBump)))) {
+        // If Bump is a neg instruction, emit C = Basis - (-Bump).
+        Reduced = Builder.CreateSub(Basis.Ins, NegBump);
+        // We only use the negative argument of Bump, and Bump itself may be
+        // trivially dead.
+        RecursivelyDeleteTriviallyDeadInstructions(Bump);
+      } else {
+        // It's tempting to preserve nsw on Bump and/or Reduced. However, it's
+        // usually unsound, e.g.,
+        //
        // X = (-2 +nsw 1) *nsw INT_MAX
+        // Y = (-2 +nsw 3) *nsw INT_MAX
+        // =>
+        // Y = X + 2 * INT_MAX
+        //
+        // Neither + and * in the resultant expression are nsw.
+        Reduced = Builder.CreateAdd(Basis.Ins, Bump);
+      }
+      break;
     }
-    break;
-  }
-  case Candidate::GEP: {
-    bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
-    // C = (char *)Basis + Bump
-    Reduced = Builder.CreatePtrAdd(Basis.Ins, Bump, "", InBounds);
-    break;
+    case Candidate::GEP: {
+      bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds();
+      // C = (char *)Basis + Bump
+      Reduced = Builder.CreatePtrAdd(Basis.Ins, Bump, "", InBounds);
+      break;
+    }
+    default:
+      llvm_unreachable("C.CandidateKind is invalid");
+    };
+    Reduced->takeName(C.Ins);
   }
-  default:
-    llvm_unreachable("C.CandidateKind is invalid");
-  };
-  Reduced->takeName(C.Ins);
   C.Ins->replaceAllUsesWith(Reduced);
-  // Unlink C.Ins so that we can skip other candidates also corresponding to
-  // C.Ins. The actual deletion is postponed to the end of runOnFunction.
-  C.Ins->removeFromParent();
-  UnlinkedInstructions.push_back(C.Ins);
+  DeadInstructions.push_back(C.Ins);
 }
 
 bool StraightLineStrengthReduceLegacyPass::runOnFunction(Function &F) {
@@ -686,33 +1287,42 @@ bool StraightLineStrengthReduceLegacyPass::runOnFunction(Function &F) {
 }
 
 bool StraightLineStrengthReduce::runOnFunction(Function &F) {
+  LLVM_DEBUG(dbgs() << "SLSR on Function: " << F.getName() << "\n");
   // Traverse the dominator tree in the depth-first order. This order makes sure
   // all bases of a candidate are in Candidates when we process it.
   for (const auto Node : depth_first(DT))
     for (auto &I : *(Node->getBlock()))
      allocateCandidatesAndFindBasis(&I);
 
-  // Rewrite candidates in the reverse depth-first order. This order makes sure
-  // a candidate being rewritten is not a basis for any other candidate.
-  while (!Candidates.empty()) {
-    const Candidate &C = Candidates.back();
-    if (C.Basis != nullptr) {
-      rewriteCandidateWithBasis(C, *C.Basis);
-    }
-    Candidates.pop_back();
-  }
-
-  // Delete all unlink instructions.
-  for (auto *UnlinkedInst : UnlinkedInstructions) {
-    for (unsigned I = 0, E = UnlinkedInst->getNumOperands(); I != E; ++I) {
-      Value *Op = UnlinkedInst->getOperand(I);
-      UnlinkedInst->setOperand(I, nullptr);
-      RecursivelyDeleteTriviallyDeadInstructions(Op);
-    }
-    UnlinkedInst->deleteValue();
+  // Build the dependency graph and sort candidate instructions from
+  // dependency roots to leaves.
+  for (auto &C : Candidates) {
+    DependencyGraph.try_emplace(C.Ins);
+    addDependency(C, C.Basis);
   }
-  bool Ret = !UnlinkedInstructions.empty();
-  UnlinkedInstructions.clear();
+  sortCandidateInstructions();
+
+  // Rewrite candidates in a topological order so that a Candidate is always
+  // rewritten before its Basis.
+  for (Instruction *I : reverse(SortedCandidateInsts))
+    if (Candidate *C = pickRewriteCandidate(I))
+      rewriteCandidate(*C);
+
+  for (auto *DeadIns : DeadInstructions)
+    // A dead instruction may be another dead instruction's operand;
+    // don't delete an instruction twice.
+    if (DeadIns->getParent())
+      RecursivelyDeleteTriviallyDeadInstructions(DeadIns);
+
+  bool Ret = !DeadInstructions.empty();
+  DeadInstructions.clear();
+  DependencyGraph.clear();
+  RewriteCandidates.clear();
+  SortedCandidateInsts.clear();
+  // First clear all references to candidates in the list.
+  CandidateDict.clear();
+  // Then destroy the list.
+  Candidates.clear();
   return Ret;
 }
 
diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
index 0f3978f..0a8f5ea 100644
--- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
+++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp
@@ -143,8 +143,8 @@ struct SubGraphTraits {
   class WrappedSuccIterator
       : public iterator_adaptor_base<
            WrappedSuccIterator, BaseSuccIterator,
-            typename std::iterator_traits<BaseSuccIterator>::iterator_category,
-            NodeRef, std::ptrdiff_t, NodeRef *, NodeRef> {
+            std::iterator_traits<BaseSuccIterator>::iterator_category, NodeRef,
+            std::ptrdiff_t, NodeRef *, NodeRef> {
     SmallDenseSet<RegionNode *> *Nodes;
 
   public:
@@ -558,11 +558,10 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) {
   } else {
     // Test for successors as back edge
     BasicBlock *BB = N->getNodeAs<BasicBlock>();
-    BranchInst *Term = cast<BranchInst>(BB->getTerminator());
-
-    for (BasicBlock *Succ : Term->successors())
-      if (Visited.count(Succ))
-        Loops[Succ] = BB;
+    if (BranchInst *Term = dyn_cast<BranchInst>(BB->getTerminator()))
+      for (BasicBlock *Succ : Term->successors())
+        if (Visited.count(Succ))
+          Loops[Succ] = BB;
   }
 }
 
@@ -594,7 +593,7 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) {
 
   for (BasicBlock *P : predecessors(BB)) {
     // Ignore it if it's a branch from outside into our region entry
-    if (!ParentRegion->contains(P))
+    if (!ParentRegion->contains(P) || !dyn_cast<BranchInst>(P->getTerminator()))
      continue;
 
     Region *R = RI->getRegionFor(P);
@@ -1402,13 +1401,17 @@ bool StructurizeCFG::makeUniformRegion(Region *R, UniformityInfo &UA) {
 /// Run the transformation for each region found
 bool StructurizeCFG::run(Region *R, DominatorTree *DT,
                          const TargetTransformInfo *TTI) {
-  if (R->isTopLevelRegion())
+  // CallBr and its corresponding direct target blocks are for now ignored by
+  // this pass. This is not a limitation for the currently intended use cases
+  // of callbr in the AMDGPU backend.
+  // Parent and child regions are not affected by this (current) restriction.
+  // See `llvm/test/Transforms/StructurizeCFG/callbr.ll` for details.
+  if (R->isTopLevelRegion() || isa<CallBrInst>(R->getEntry()->getTerminator()))
     return false;
 
   this->DT = DT;
   this->TTI = TTI;
   Func = R->getEntry()->getParent();
-  assert(hasOnlySimpleTerminator(*Func) && "Unsupported block terminator.");
 
   ParentRegion = R;
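Editor's note, not part of the patch above: the emitBump comments in the StraightLineStrengthReduce diff describe the rewrite algebra as X = B + i * S and Y = B + i' * S, with Y rebuilt from X by a single bump. The sketch below restates the IndexDelta case with plain 64-bit integers standing in for the SCEV/IR values; every name here is illustrative only and overflow is ignored.

#include <cassert>
#include <cstdint>

// Y = B + I2 * S shares base B and stride S with X = B + I1 * S, so it can be
// recomputed from X with one addition: Y = X + (I2 - I1) * S. When (I2 - I1)
// is a power of two, the multiply can additionally be lowered to a shift, as
// the CreateMul helper in the patch does.
static int64_t rebuildFromIndexDelta(int64_t X, int64_t I1, int64_t I2,
                                     int64_t S) {
  int64_t Bump = (I2 - I1) * S; // what emitBump materializes
  return X + Bump;
}

int main() {
  const int64_t B = 100, S = 8, I1 = 3, I2 = 7;
  const int64_t X = B + I1 * S;
  const int64_t Y = B + I2 * S;
  assert(rebuildFromIndexDelta(X, I1, I2, S) == Y); // Y is reachable from X
  return 0;
}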

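Editor's note, also not part of the patch: the StructurizeCFG hunk adds an early-out so the pass skips the top-level region and any region whose entry block terminates in a callbr. A hedged, standalone sketch of that check is below; it assumes the usual LLVM headers, and the helper name is hypothetical.

#include "llvm/Analysis/RegionInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"

using namespace llvm;

// Mirrors the guard added to StructurizeCFG::run: regions entered through a
// callbr terminator are currently left unstructurized, which the patch notes
// is acceptable for the intended AMDGPU uses of callbr.
static bool isSkippedByStructurizer(const Region *R) {
  if (R->isTopLevelRegion())
    return true;
  return isa<CallBrInst>(R->getEntry()->getTerminator());
}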