Diffstat (limited to 'llvm/lib/Transforms/Utils')
27 files changed, 1116 insertions, 490 deletions
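The first hunk below (BasicBlockUtils.cpp) adds HasLoopOrEntryConvergenceToken and uses it to stop MergeBlockIntoPredecessor from merging two blocks that each carry a loop or entry convergence token, since the merged block could otherwise end up with more than one. A minimal sketch of applying the same check before a merge; canMergeWithPredecessor is a hypothetical helper, not part of the patch:

#include "llvm/IR/BasicBlock.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

// A block may hold at most one loop/entry convergence control intrinsic, so
// refuse to merge two blocks that both define one (hypothetical helper).
static bool canMergeWithPredecessor(BasicBlock *BB, BasicBlock *PredBB) {
  return !(HasLoopOrEntryConvergenceToken(BB) &&
           HasLoopOrEntryConvergenceToken(PredBB));
}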
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index 11db0ec..b0c0408 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -92,6 +92,15 @@ emptyAndDetachBlock(BasicBlock *BB, "applying corresponding DTU updates."); } +bool llvm::HasLoopOrEntryConvergenceToken(const BasicBlock *BB) { + for (const Instruction &I : *BB) { + const ConvergenceControlInst *CCI = dyn_cast<ConvergenceControlInst>(&I); + if (CCI && (CCI->isLoop() || CCI->isEntry())) + return true; + } + return false; +} + void llvm::detachDeadBlocks(ArrayRef<BasicBlock *> BBs, SmallVectorImpl<DominatorTree::UpdateType> *Updates, bool KeepOneInputPHIs) { @@ -259,6 +268,13 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, if (llvm::is_contained(PN.incoming_values(), &PN)) return false; + // Don't break if both the basic block and the predecessor contain loop or + // entry convergent intrinsics, since there may only be one convergence token + // per block. + if (HasLoopOrEntryConvergenceToken(BB) && + HasLoopOrEntryConvergenceToken(PredBB)) + return false; + LLVM_DEBUG(dbgs() << "Merging: " << BB->getName() << " into " << PredBB->getName() << "\n"); @@ -739,9 +755,11 @@ BasicBlock *llvm::SplitCallBrEdge(BasicBlock *CallBrBlock, BasicBlock *Succ, updateCycleLoopInfo<CycleInfo, Cycle>(CI, CallBrBlock, CallBrTarget, Succ); if (DTU) { DTU->applyUpdates({{DominatorTree::Insert, CallBrBlock, CallBrTarget}}); - if (DTU->getDomTree().dominates(CallBrBlock, Succ)) - DTU->applyUpdates({{DominatorTree::Delete, CallBrBlock, Succ}, - {DominatorTree::Insert, CallBrTarget, Succ}}); + if (DTU->getDomTree().dominates(CallBrBlock, Succ)) { + if (!is_contained(successors(CallBrBlock), Succ)) + DTU->applyUpdates({{DominatorTree::Delete, CallBrBlock, Succ}}); + DTU->applyUpdates({{DominatorTree::Insert, CallBrTarget, Succ}}); + } } return CallBrTarget; diff --git a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp index 42b1fdf..8aa8aa2 100644 --- a/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -39,36 +39,36 @@ using namespace llvm; STATISTIC(NumBroken, "Number of blocks inserted"); namespace { - struct BreakCriticalEdges : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - BreakCriticalEdges() : FunctionPass(ID) { - initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry()); - } +struct BreakCriticalEdges : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + BreakCriticalEdges() : FunctionPass(ID) { + initializeBreakCriticalEdgesPass(*PassRegistry::getPassRegistry()); + } - bool runOnFunction(Function &F) override { - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + bool runOnFunction(Function &F) override { + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; - auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); - auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; + auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); + auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; - auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); - auto *LI = LIWP ? 
&LIWP->getLoopInfo() : nullptr; - unsigned N = - SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT)); - NumBroken += N; - return N > 0; - } + auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); + auto *LI = LIWP ? &LIWP->getLoopInfo() : nullptr; + unsigned N = SplitAllCriticalEdges( + F, CriticalEdgeSplittingOptions(DT, LI, nullptr, PDT)); + NumBroken += N; + return N > 0; + } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); - // No loop canonicalization guarantees are broken by this pass. - AU.addPreservedID(LoopSimplifyID); - } - }; -} + // No loop canonicalization guarantees are broken by this pass. + AU.addPreservedID(LoopSimplifyID); + } +}; +} // namespace char BreakCriticalEdges::ID = 0; INITIALIZE_PASS(BreakCriticalEdges, "break-crit-edges", @@ -76,6 +76,7 @@ INITIALIZE_PASS(BreakCriticalEdges, "break-crit-edges", // Publicly exposed interface to pass... char &llvm::BreakCriticalEdgesID = BreakCriticalEdges::ID; + FunctionPass *llvm::createBreakCriticalEdgesPass() { return new BreakCriticalEdges(); } diff --git a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index 573a781..a245b94 100644 --- a/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -1227,9 +1227,6 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_atanhf: case LibFunc_atanhl: case LibFunc_atanl: - case LibFunc_ceil: - case LibFunc_ceilf: - case LibFunc_ceill: case LibFunc_cos: case LibFunc_cosh: case LibFunc_coshf: @@ -1283,6 +1280,12 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_ilogbl: case LibFunc_logf: case LibFunc_logl: + case LibFunc_nextafter: + case LibFunc_nextafterf: + case LibFunc_nextafterl: + case LibFunc_nexttoward: + case LibFunc_nexttowardf: + case LibFunc_nexttowardl: case LibFunc_pow: case LibFunc_powf: case LibFunc_powl: @@ -1292,9 +1295,6 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_rint: case LibFunc_rintf: case LibFunc_rintl: - case LibFunc_round: - case LibFunc_roundf: - case LibFunc_roundl: case LibFunc_scalbln: case LibFunc_scalblnf: case LibFunc_scalblnl: @@ -1332,6 +1332,9 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_copysign: case LibFunc_copysignf: case LibFunc_copysignl: + case LibFunc_ceil: + case LibFunc_ceilf: + case LibFunc_ceill: case LibFunc_fabs: case LibFunc_fabsf: case LibFunc_fabsl: @@ -1350,11 +1353,23 @@ bool llvm::inferNonMandatoryLibFuncAttrs(Function &F, case LibFunc_fmin: case LibFunc_fminf: case LibFunc_fminl: + case LibFunc_fmaximum_num: + case LibFunc_fmaximum_numf: + case LibFunc_fmaximum_numl: + case LibFunc_fminimum_num: + case LibFunc_fminimum_numf: + case LibFunc_fminimum_numl: case LibFunc_labs: case LibFunc_llabs: case LibFunc_nearbyint: case LibFunc_nearbyintf: case LibFunc_nearbyintl: + case LibFunc_round: + case LibFunc_roundf: + case LibFunc_roundl: + case LibFunc_roundeven: + case LibFunc_roundevenf: + case LibFunc_roundevenl: case LibFunc_toascii: case LibFunc_trunc: case LibFunc_truncf: diff --git a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp index 7343c79..66d8fea 100644 --- a/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp +++ 
b/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -40,22 +40,22 @@ using namespace llvm; namespace { - struct QuotRemPair { - Value *Quotient; - Value *Remainder; - - QuotRemPair(Value *InQuotient, Value *InRemainder) - : Quotient(InQuotient), Remainder(InRemainder) {} - }; - - /// A quotient and remainder, plus a BB from which they logically "originate". - /// If you use Quotient or Remainder in a Phi node, you should use BB as its - /// corresponding predecessor. - struct QuotRemWithBB { - BasicBlock *BB = nullptr; - Value *Quotient = nullptr; - Value *Remainder = nullptr; - }; +struct QuotRemPair { + Value *Quotient; + Value *Remainder; + + QuotRemPair(Value *InQuotient, Value *InRemainder) + : Quotient(InQuotient), Remainder(InRemainder) {} +}; + +/// A quotient and remainder, plus a BB from which they logically "originate". +/// If you use Quotient or Remainder in a Phi node, you should use BB as its +/// corresponding predecessor. +struct QuotRemWithBB { + BasicBlock *BB = nullptr; + Value *Quotient = nullptr; + Value *Remainder = nullptr; +}; using DivCacheTy = DenseMap<DivRemMapKey, QuotRemPair>; using BypassWidthsTy = DenseMap<unsigned, unsigned>; @@ -335,10 +335,10 @@ Value *FastDivInsertionTask::insertOperandRuntimeCheck(Value *Op1, Value *Op2) { else OrV = Op1 ? Op1 : Op2; - // BitMask is inverted to check if the operands are - // larger than the bypass type - uint64_t BitMask = ~BypassType->getBitMask(); - Value *AndV = Builder.CreateAnd(OrV, BitMask); + // Check whether the operands are larger than the bypass type. + Value *AndV = Builder.CreateAnd( + OrV, APInt::getBitsSetFrom(OrV->getType()->getIntegerBitWidth(), + BypassType->getBitWidth())); // Compare operand values Value *ZeroV = ConstantInt::getSigned(getSlowType(), 0); diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 5ba6f95f..0ca1fa2 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -63,7 +63,6 @@ #include <cstdint> #include <iterator> #include <map> -#include <utility> #include <vector> using namespace llvm; @@ -933,6 +932,7 @@ Function *CodeExtractor::constructFunctionDeclaration( case Attribute::CoroDestroyOnlyWhenComplete: case Attribute::CoroElideSafe: case Attribute::NoDivergenceSource: + case Attribute::NoCreateUndefOrPoison: continue; // Those attributes should be safe to propagate to the extracted function. case Attribute::AlwaysInline: diff --git a/llvm/lib/Transforms/Utils/DebugSSAUpdater.cpp b/llvm/lib/Transforms/Utils/DebugSSAUpdater.cpp index c0e7609..cceabd8 100644 --- a/llvm/lib/Transforms/Utils/DebugSSAUpdater.cpp +++ b/llvm/lib/Transforms/Utils/DebugSSAUpdater.cpp @@ -291,7 +291,6 @@ void DbgValueRangeTable::addVariable(Function *F, DebugVariableAggregate DVA) { // We don't have a single location for the variable's entire scope, so instead // we must now perform a liveness analysis to create a location list. 
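For reference, the BypassSlowDivision.cpp hunk above replaces the inverted getBitMask() value with a mask built by APInt::getBitsSetFrom. A hedged numeric sketch of that mask logic, assuming a 64-bit division bypassed through 32-bit operations:

#include "llvm/ADT/APInt.h"
#include <cassert>
#include <cstdint>

using namespace llvm;

int main() {
  // Bits [32, 64) are set, so (Op1 | Op2) & Mask is nonzero exactly when an
  // operand needs more than the 32-bit bypass type.
  APInt Mask = APInt::getBitsSetFrom(/*numBits=*/64, /*loBit=*/32);
  assert(Mask == APInt(64, 0xFFFFFFFF00000000ULL));

  APInt Small(64, 12345);           // fits in 32 bits -> fast path
  APInt Big(64, uint64_t(1) << 40); // needs the full 64-bit division
  assert((Small & Mask) == 0);
  assert((Big & Mask) != 0);
}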
- DenseMap<BasicBlock *, DbgValueDef> LiveInMap; SmallVector<DbgSSAPhi *> HypotheticalPHIs; DebugSSAUpdater SSAUpdater(&HypotheticalPHIs); SSAUpdater.initialize(); diff --git a/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp b/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp index 0642d51..94e8a33 100644 --- a/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp +++ b/llvm/lib/Transforms/Utils/DeclareRuntimeLibcalls.cpp @@ -11,27 +11,70 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/DeclareRuntimeLibcalls.h" +#include "llvm/Analysis/RuntimeLibcallInfo.h" #include "llvm/IR/Module.h" #include "llvm/IR/RuntimeLibcalls.h" using namespace llvm; +static void mergeAttributes(LLVMContext &Ctx, const Module &M, + const DataLayout &DL, const Triple &TT, + Function *Func, FunctionType *FuncTy, + AttributeList FuncAttrs) { + AttributeList OldAttrs = Func->getAttributes(); + AttributeList NewAttrs = OldAttrs; + + { + AttrBuilder OldBuilder(Ctx, OldAttrs.getFnAttrs()); + AttrBuilder NewBuilder(Ctx, FuncAttrs.getFnAttrs()); + OldBuilder.merge(NewBuilder); + NewAttrs = NewAttrs.addFnAttributes(Ctx, OldBuilder); + } + + { + AttrBuilder OldBuilder(Ctx, OldAttrs.getRetAttrs()); + AttrBuilder NewBuilder(Ctx, FuncAttrs.getRetAttrs()); + OldBuilder.merge(NewBuilder); + NewAttrs = NewAttrs.addRetAttributes(Ctx, OldBuilder); + } + + for (unsigned I = 0, E = FuncTy->getNumParams(); I != E; ++I) { + AttrBuilder OldBuilder(Ctx, OldAttrs.getParamAttrs(I)); + AttrBuilder NewBuilder(Ctx, FuncAttrs.getParamAttrs(I)); + OldBuilder.merge(NewBuilder); + NewAttrs = NewAttrs.addParamAttributes(Ctx, I, OldBuilder); + } + + Func->setAttributes(NewAttrs); +} + PreservedAnalyses DeclareRuntimeLibcallsPass::run(Module &M, ModuleAnalysisManager &MAM) { - RTLIB::RuntimeLibcallsInfo RTLCI(M.getTargetTriple()); + const RTLIB::RuntimeLibcallsInfo &RTLCI = + MAM.getResult<RuntimeLibraryAnalysis>(M); + LLVMContext &Ctx = M.getContext(); + const DataLayout &DL = M.getDataLayout(); + const Triple &TT = M.getTargetTriple(); - for (RTLIB::LibcallImpl Impl : RTLCI.getLibcallImpls()) { - if (Impl == RTLIB::Unsupported) + for (RTLIB::LibcallImpl Impl : RTLIB::libcall_impls()) { + if (!RTLCI.isAvailable(Impl)) continue; - // TODO: Declare with correct type, calling convention, and attributes. + auto [FuncTy, FuncAttrs] = RTLCI.getFunctionTy(Ctx, TT, DL, Impl); - FunctionType *FuncTy = - FunctionType::get(Type::getVoidTy(Ctx), {}, /*IsVarArgs=*/true); + // TODO: Declare with correct type, calling convention, and attributes. + if (!FuncTy) + FuncTy = FunctionType::get(Type::getVoidTy(Ctx), {}, /*IsVarArgs=*/true); StringRef FuncName = RTLCI.getLibcallImplName(Impl); - M.getOrInsertFunction(FuncName, FuncTy); + + Function *Func = + cast<Function>(M.getOrInsertFunction(FuncName, FuncTy).getCallee()); + if (Func->getFunctionType() == FuncTy) { + mergeAttributes(Ctx, M, DL, TT, Func, FuncTy, FuncAttrs); + Func->setCallingConv(RTLCI.getLibcallImplCallingConv(Impl)); + } } return PreservedAnalyses::none(); diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp index 46f2903..f7842a2 100644 --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -3416,7 +3416,11 @@ DIExpression *llvm::getExpressionForConstant(DIBuilder &DIB, const Constant &C, // Create integer constant expression. 
auto createIntegerExpression = [&DIB](const Constant &CV) -> DIExpression * { const APInt &API = cast<ConstantInt>(&CV)->getValue(); - std::optional<int64_t> InitIntOpt = API.trySExtValue(); + std::optional<int64_t> InitIntOpt; + if (API.getBitWidth() == 1) + InitIntOpt = API.tryZExtValue(); + else + InitIntOpt = API.trySExtValue(); return InitIntOpt ? DIB.createConstantValueExpression( static_cast<uint64_t>(*InitIntOpt)) : nullptr; @@ -3880,6 +3884,12 @@ bool llvm::canReplaceOperandWithVariable(const Instruction *I, unsigned OpIdx) { if (Op->isSwiftError()) return false; + // Protected pointer field loads/stores should be paired with the intrinsic + // to avoid unnecessary address escapes. + if (auto *II = dyn_cast<IntrinsicInst>(Op)) + if (II->getIntrinsicID() == Intrinsic::protected_field_ptr) + return false; + // Cannot replace alloca argument with phi/select. if (I->isLifetimeStartOrEnd()) return false; diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp index e1dcaa85..174a21a 100644 --- a/llvm/lib/Transforms/Utils/LoopPeel.cpp +++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp @@ -54,6 +54,7 @@ using namespace llvm::SCEVPatternMatch; STATISTIC(NumPeeled, "Number of loops peeled"); STATISTIC(NumPeeledEnd, "Number of loops peeled from end"); +namespace llvm { static cl::opt<unsigned> UnrollPeelCount( "unroll-peel-count", cl::Hidden, cl::desc("Set the unroll peeling count, for testing purposes")); @@ -87,6 +88,9 @@ static cl::opt<bool> EnablePeelingForIV( static const char *PeeledCountMetaData = "llvm.loop.peeled.count"; +extern cl::opt<bool> ProfcheckDisableMetadataFixes; +} // namespace llvm + // Check whether we are capable of peeling this loop. bool llvm::canPeel(const Loop *L) { // Make sure the loop is in simplified form @@ -415,9 +419,9 @@ std::optional<unsigned> PhiAnalyzer::calculateIterationsToPeel() { // the remainder loop after peeling. The load must also be used (transitively) // by an exit condition. Returns the number of iterations to peel off (at the // moment either 0 or 1). -static unsigned peelToTurnInvariantLoadsDerefencebale(Loop &L, - DominatorTree &DT, - AssumptionCache *AC) { +static unsigned peelToTurnInvariantLoadsDereferenceable(Loop &L, + DominatorTree &DT, + AssumptionCache *AC) { // Skip loops with a single exiting block, because there should be no benefit // for the heuristic below. if (L.getExitingBlock()) @@ -812,7 +816,7 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, DesiredPeelCount = std::max(DesiredPeelCount, CountToEliminateCmps); if (DesiredPeelCount == 0) - DesiredPeelCount = peelToTurnInvariantLoadsDerefencebale(*L, DT, AC); + DesiredPeelCount = peelToTurnInvariantLoadsDereferenceable(*L, DT, AC); if (DesiredPeelCount > 0) { DesiredPeelCount = std::min(DesiredPeelCount, MaxPeelCount); @@ -1179,8 +1183,34 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, bool PeelLast, LoopInfo *LI, // If the original loop may only execute a single iteration we need to // insert a trip count check and skip the original loop with the last - // iteration peeled off if necessary. - if (!SE->isKnownNonZero(BTC)) { + // iteration peeled off if necessary. Either way, we must update branch + // weights to maintain the loop body frequency. + if (SE->isKnownNonZero(BTC)) { + // We have just proven that, when reached, the original loop always + // executes at least two iterations. Thus, we unconditionally execute + // both the remaining loop's initial iteration and the peeled iteration. 
+ // But that increases the latter's frequency above its frequency in the + // original loop. To maintain the total frequency, we compensate by + // decreasing the remaining loop body's frequency to indicate one less + // iteration. + // + // We use this formula to convert probability to/from frequency: + // Sum(i=0..inf)(P^i) = 1/(1-P) = Freq. + if (BranchProbability P = getLoopProbability(L); !P.isUnknown()) { + // Trying to subtract one from an infinite loop is pointless, and our + // formulas then produce division by zero, so skip that case. + if (BranchProbability ExitP = P.getCompl(); !ExitP.isZero()) { + double Freq = 1 / ExitP.toDouble(); + // No branch weights can produce a frequency of less than one given + // the initial iteration, and our formulas produce a negative + // probability if we try. + assert(Freq >= 1.0 && "expected freq >= 1 due to initial iteration"); + double NewFreq = std::max(Freq - 1, 1.0); + setLoopProbability( + L, BranchProbability::getBranchProbability(1 - 1 / NewFreq)); + } + } + } else { NewPreHeader = SplitEdge(PreHeader, Header, &DT, LI); SCEVExpander Expander(*SE, Latch->getDataLayout(), "loop-peel"); @@ -1190,7 +1220,24 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, bool PeelLast, LoopInfo *LI, IRBuilder<> B(PreHeaderBR); Value *Cond = B.CreateICmpNE(BTCValue, ConstantInt::get(BTCValue->getType(), 0)); - B.CreateCondBr(Cond, NewPreHeader, InsertTop); + auto *BI = B.CreateCondBr(Cond, NewPreHeader, InsertTop); + SmallVector<uint32_t> Weights; + auto *OrigLatchBr = Latch->getTerminator(); + auto HasBranchWeights = !ProfcheckDisableMetadataFixes && + extractBranchWeights(*OrigLatchBr, Weights); + if (HasBranchWeights) { + // The probability that the new guard skips the loop to execute just one + // iteration is the original loop's probability of exiting at the latch + // after any iteration. That should maintain the original loop body + // frequency. Upon arriving at the loop, due to the guard, the + // probability of reaching iteration i of the new loop is the + // probability of reaching iteration i+1 of the original loop. The + // probability of reaching the peeled iteration is 1, which is the + // probability of reaching iteration 0 of the original loop. + if (L->getExitBlock() == OrigLatchBr->getSuccessor(0)) + std::swap(Weights[0], Weights[1]); + setBranchWeights(*BI, Weights, /*IsExpected=*/false); + } PreHeaderBR->eraseFromParent(); // PreHeader now dominates InsertTop. diff --git a/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 61ffb49..8da6a980 100644 --- a/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -378,7 +378,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, if (P != Preheader) BackedgeBlocks.push_back(P); } - // Create and insert the new backedge block... + // Create and insert the new backedge block. 
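As an aside on the getExpressionForConstant change in Local.cpp above: it special-cases 1-bit constants because sign-extending an i1 "true" yields -1, while the debug-info constant for a bool should be 1. A small illustration:

#include "llvm/ADT/APInt.h"
#include <cassert>

using namespace llvm;

int main() {
  APInt True(/*numBits=*/1, /*val=*/1);
  assert(*True.trySExtValue() == -1); // previous behavior: true became -1
  assert(*True.tryZExtValue() == 1);  // new behavior for bit width 1
}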
BasicBlock *BEBlock = BasicBlock::Create(Header->getContext(), Header->getName() + ".backedge", F); BranchInst *BETerminator = BranchInst::Create(Header, BEBlock); @@ -737,39 +737,39 @@ bool llvm::simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, } namespace { - struct LoopSimplify : public FunctionPass { - static char ID; // Pass identification, replacement for typeid - LoopSimplify() : FunctionPass(ID) { - initializeLoopSimplifyPass(*PassRegistry::getPassRegistry()); - } +struct LoopSimplify : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + LoopSimplify() : FunctionPass(ID) { + initializeLoopSimplifyPass(*PassRegistry::getPassRegistry()); + } - bool runOnFunction(Function &F) override; + bool runOnFunction(Function &F) override; - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); - // We need loop information to identify the loops... - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); + // We need loop information to identify the loops. + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<BasicAAWrapperPass>(); - AU.addPreserved<AAResultsWrapperPass>(); - AU.addPreserved<GlobalsAAWrapperPass>(); - AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addPreserved<SCEVAAWrapperPass>(); - AU.addPreservedID(LCSSAID); - AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. - AU.addPreserved<BranchProbabilityInfoWrapperPass>(); - AU.addPreserved<MemorySSAWrapperPass>(); - } + AU.addPreserved<BasicAAWrapperPass>(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addPreserved<SCEVAAWrapperPass>(); + AU.addPreservedID(LCSSAID); + AU.addPreservedID(BreakCriticalEdgesID); // No critical edges added. + AU.addPreserved<BranchProbabilityInfoWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } - /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees. - void verifyAnalysis() const override; - }; -} + /// verifyAnalysis() - Verify LoopSimplifyForm's guarantees. + void verifyAnalysis() const override; +}; +} // namespace char LoopSimplify::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplify, "loop-simplify", @@ -780,12 +780,12 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(LoopSimplify, "loop-simplify", "Canonicalize natural loops", false, false) -// Publicly exposed interface to pass... +// Publicly exposed interface to pass. char &llvm::LoopSimplifyID = LoopSimplify::ID; Pass *llvm::createLoopSimplifyPass() { return new LoopSimplify(); } /// runOnFunction - Run down all loops in the CFG (recursively, but we could do -/// it in any convenient order) inserting preheaders... +/// it in any convenient order) inserting preheaders. 
/// bool LoopSimplify::runOnFunction(Function &F) { bool Changed = false; diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index 94dfd3a..0f25639 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -66,7 +66,6 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include <assert.h> #include <numeric> -#include <type_traits> #include <vector> namespace llvm { @@ -1094,6 +1093,7 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI, if (!RdxResult) { RdxResult = PartialReductions.front(); IRBuilder Builder(ExitBlock, ExitBlock->getFirstNonPHIIt()); + Builder.setFastMathFlags(Reductions.begin()->second.getFastMathFlags()); RecurKind RK = Reductions.begin()->second.getRecurrenceKind(); for (Instruction *RdxPart : drop_begin(PartialReductions)) { RdxResult = Builder.CreateBinOp( @@ -1254,16 +1254,19 @@ llvm::canParallelizeReductionWhenUnrolling(PHINode &Phi, Loop *L, /*DemandedBits=*/nullptr, /*AC=*/nullptr, /*DT=*/nullptr, SE)) return std::nullopt; + if (RdxDesc.hasUsesOutsideReductionChain()) + return std::nullopt; RecurKind RK = RdxDesc.getRecurrenceKind(); // Skip unsupported reductions. - // TODO: Handle additional reductions, including FP and min-max - // reductions. - if (!RecurrenceDescriptor::isIntegerRecurrenceKind(RK) || - RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) || + // TODO: Handle additional reductions, including min-max reductions. + if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) || RecurrenceDescriptor::isFindIVRecurrenceKind(RK) || RecurrenceDescriptor::isMinMaxRecurrenceKind(RK)) return std::nullopt; + if (RdxDesc.hasExactFPMath()) + return std::nullopt; + if (RdxDesc.IntermediateStore) return std::nullopt; diff --git a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp index ca90bb6..1e614bd 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -53,7 +53,6 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include <assert.h> #include <memory> -#include <type_traits> #include <vector> using namespace llvm; diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 1e8f6cc..7de8683 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -202,6 +202,27 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, /// probability of executing at least one more iteration? static BranchProbability probOfNextInRemainder(BranchProbability OriginalLoopProb, unsigned N) { + // OriginalLoopProb == 1 would produce a division by zero in the calculation + // below. The problem is that case indicates an always infinite loop, but a + // remainder loop cannot be calculated at run time if the original loop is + // infinite as infinity % UnrollCount is undefined. We then choose + // probabilities indicating that all remainder loop iterations will always + // execute. + // + // Currently, the remainder loop here is an epilogue, which cannot be reached + // if the original loop is infinite, so the aforementioned choice is + // arbitrary. + // + // FIXME: Branch weights still need to be fixed in the case of prologues + // (issue #135812). In that case, the aforementioned choice seems reasonable + // for the goal of maintaining the original loop's block frequencies. 
That + // is, an infinite loop's initial iterations are not skipped, and the prologue + // loop body might have unique blocks that execute a finite number of times + // if, for example, the original loop body contains conditionals like i < + // UnrollCount. + if (OriginalLoopProb == BranchProbability::getOne()) + return BranchProbability::getOne(); + // Each of these variables holds the original loop's probability that the // number of iterations it will execute is some m in the specified range. BranchProbability ProbOne = OriginalLoopProb; // 1 <= m @@ -474,16 +495,13 @@ static Loop *CloneLoopBlocks(Loop *L, Value *NewIter, BranchProbability ProbReaching = BranchProbability::getOne(); for (unsigned N = Count - 2; N >= 1; --N) { ProbReaching *= probOfNextInRemainder(OriginalLoopProb, N); - FreqRemIters += double(ProbReaching.getNumerator()) / - ProbReaching.getDenominator(); + FreqRemIters += ProbReaching.toDouble(); } } // Solve for the loop probability that would produce that frequency. // Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters. - double ProbDouble = 1 - 1 / FreqRemIters; - BranchProbability Prob = BranchProbability::getBranchProbability( - std::round(ProbDouble * BranchProbability::getDenominator()), - BranchProbability::getDenominator()); + BranchProbability Prob = + BranchProbability::getBranchProbability(1 - 1 / FreqRemIters); setBranchProbability(RemainderLoopLatch, Prob, /*ForFirstTarget=*/true); } NewIdx->addIncoming(Zero, InsertTop); diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 6e60b94..8e2a4f8 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -913,11 +913,26 @@ llvm::getLoopEstimatedTripCount(Loop *L, // Return the estimated trip count from metadata unless the metadata is // missing or has no value. + // + // Some passes set llvm.loop.estimated_trip_count to 0. For example, after + // peeling 10 or more iterations from a loop with an estimated trip count of + // 10, llvm.loop.estimated_trip_count becomes 0 on the remaining loop. It + // indicates that, each time execution reaches the peeled iterations, + // execution is estimated to exit them without reaching the remaining loop's + // header. + // + // Even if the probability of reaching a loop's header is low, if it is + // reached, it is the start of an iteration. Consequently, some passes + // historically assume that llvm::getLoopEstimatedTripCount always returns a + // positive count or std::nullopt. Thus, return std::nullopt when + // llvm.loop.estimated_trip_count is 0. if (auto TC = getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount)) { LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: " << LLVMLoopEstimatedTripCount << " metadata has trip " - << "count of " << *TC << " for " << DbgLoop(L) << "\n"); - return TC; + << "count of " << *TC + << (*TC == 0 ? " (returning std::nullopt)" : "") + << " for " << DbgLoop(L) << "\n"); + return *TC == 0 ? std::nullopt : std::optional(*TC); } // Estimate the trip count from latch branch weights. 
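Both the LoopPeel and LoopUnrollRuntime hunks above convert between a latch probability P and an expected loop-body frequency via the geometric series Sum(i=0..inf)(P^i) = 1/(1-P). A hedged numeric sketch of the peeling adjustment, assuming a latch take-probability of 3/4:

#include <algorithm>
#include <cassert>

int main() {
  // P = 0.75 means the body is expected to run Freq = 1/(1-P) = 4 times.
  double P = 0.75;
  double Freq = 1 / (1 - P);
  assert(Freq == 4.0);
  // Peeling one iteration leaves the remaining loop one iteration shorter;
  // clamp at 1, as the patch does.
  double NewFreq = std::max(Freq - 1, 1.0);
  // Solve 1/(1 - NewP) = NewFreq for the adjusted latch probability.
  double NewP = 1 - 1 / NewFreq; // 2/3 for this example
  assert(NewP > 0.66 && NewP < 0.67);
}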
diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp index ec2e6c1..9c8b6ef 100644 --- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -109,8 +110,12 @@ void LoopVersioning::versionLoop( // Insert the conditional branch based on the result of the memchecks. Instruction *OrigTerm = RuntimeCheckBB->getTerminator(); Builder.SetInsertPoint(OrigTerm); - Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(), - VersionedLoop->getLoopPreheader()); + auto *BI = + Builder.CreateCondBr(RuntimeCheck, NonVersionedLoop->getLoopPreheader(), + VersionedLoop->getLoopPreheader()); + // We don't know what the probability of executing the versioned vs the + // unversioned variants is. + setExplicitlyUnknownBranchWeightsIfProfiled(*BI, DEBUG_TYPE); OrigTerm->eraseFromParent(); // The loops merge in the original exit block. This is now dominated by the diff --git a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 18b0f61..4ab99ed 100644 --- a/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/llvm/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -21,6 +21,218 @@ using namespace llvm; +/// \returns \p Len urem \p OpSize, checking for optimization opportunities. +/// \p OpSizeVal must be the integer value of the \c ConstantInt \p OpSize. +static Value *getRuntimeLoopRemainder(IRBuilderBase &B, Value *Len, + Value *OpSize, unsigned OpSizeVal) { + // For powers of 2, we can and by (OpSizeVal - 1) instead of using urem. + if (isPowerOf2_32(OpSizeVal)) + return B.CreateAnd(Len, OpSizeVal - 1); + return B.CreateURem(Len, OpSize); +} + +/// \returns (\p Len udiv \p OpSize) mul \p OpSize, checking for optimization +/// opportunities. +/// If \p RTLoopRemainder is provided, it must be the result of +/// \c getRuntimeLoopRemainder() with the same arguments. +static Value *getRuntimeLoopUnits(IRBuilderBase &B, Value *Len, Value *OpSize, + unsigned OpSizeVal, + Value *RTLoopRemainder = nullptr) { + if (!RTLoopRemainder) + RTLoopRemainder = getRuntimeLoopRemainder(B, Len, OpSize, OpSizeVal); + return B.CreateSub(Len, RTLoopRemainder); +} + +namespace { +/// Container for the return values of insertLoopExpansion. +struct LoopExpansionInfo { + /// The instruction at the end of the main loop body. + Instruction *MainLoopIP = nullptr; + + /// The unit index in the main loop body. + Value *MainLoopIndex = nullptr; + + /// The instruction at the end of the residual loop body. Can be nullptr if no + /// residual is required. + Instruction *ResidualLoopIP = nullptr; + + /// The unit index in the residual loop body. Can be nullptr if no residual is + /// required. + Value *ResidualLoopIndex = nullptr; +}; +} // namespace + +/// Insert the control flow and loop counters for a memcpy/memset loop +/// expansion. 
+/// +/// This function inserts IR corresponding to the following C code before +/// \p InsertBefore: +/// \code +/// LoopUnits = (Len / MainLoopStep) * MainLoopStep; +/// ResidualUnits = Len - LoopUnits; +/// MainLoopIndex = 0; +/// if (LoopUnits > 0) { +/// do { +/// // MainLoopIP +/// MainLoopIndex += MainLoopStep; +/// } while (MainLoopIndex < LoopUnits); +/// } +/// for (size_t i = 0; i < ResidualUnits; i += ResidualLoopStep) { +/// ResidualLoopIndex = LoopUnits + i; +/// // ResidualLoopIP +/// } +/// \endcode +/// +/// \p MainLoopStep and \p ResidualLoopStep determine by how many "units" the +/// loop index is increased in each iteration of the main and residual loops, +/// respectively. In most cases, the "unit" will be bytes, but larger units are +/// useful for lowering memset.pattern. +/// +/// The computation of \c LoopUnits and \c ResidualUnits is performed at compile +/// time if \p Len is a \c ConstantInt. +/// The second (residual) loop is omitted if \p ResidualLoopStep is 0 or equal +/// to \p MainLoopStep. +/// The generated \c MainLoopIP, \c MainLoopIndex, \c ResidualLoopIP, and +/// \c ResidualLoopIndex are returned in a \c LoopExpansionInfo object. +static LoopExpansionInfo insertLoopExpansion(Instruction *InsertBefore, + Value *Len, unsigned MainLoopStep, + unsigned ResidualLoopStep, + StringRef BBNamePrefix) { + assert((ResidualLoopStep == 0 || MainLoopStep % ResidualLoopStep == 0) && + "ResidualLoopStep must divide MainLoopStep if specified"); + assert(ResidualLoopStep <= MainLoopStep && + "ResidualLoopStep cannot be larger than MainLoopStep"); + assert(MainLoopStep > 0 && "MainLoopStep must be non-zero"); + LoopExpansionInfo LEI; + BasicBlock *PreLoopBB = InsertBefore->getParent(); + BasicBlock *PostLoopBB = PreLoopBB->splitBasicBlock( + InsertBefore, BBNamePrefix + "-post-expansion"); + Function *ParentFunc = PreLoopBB->getParent(); + LLVMContext &Ctx = PreLoopBB->getContext(); + IRBuilder<> PreLoopBuilder(PreLoopBB->getTerminator()); + + // Calculate the main loop trip count and remaining units to cover after the + // loop. + Type *LenType = Len->getType(); + IntegerType *ILenType = cast<IntegerType>(LenType); + ConstantInt *CIMainLoopStep = ConstantInt::get(ILenType, MainLoopStep); + + Value *LoopUnits = Len; + Value *ResidualUnits = nullptr; + // We can make a conditional branch unconditional if we know that the + // MainLoop must be executed at least once. + bool MustTakeMainLoop = false; + if (MainLoopStep != 1) { + if (auto *CLen = dyn_cast<ConstantInt>(Len)) { + uint64_t TotalUnits = CLen->getZExtValue(); + uint64_t LoopEndCount = alignDown(TotalUnits, MainLoopStep); + uint64_t ResidualCount = TotalUnits - LoopEndCount; + LoopUnits = ConstantInt::get(LenType, LoopEndCount); + ResidualUnits = ConstantInt::get(LenType, ResidualCount); + MustTakeMainLoop = LoopEndCount > 0; + // As an optimization, we could skip generating the residual loop if + // ResidualCount is known to be 0. However, current uses of this function + // don't request a residual loop if the length is constant (they generate + // a (potentially empty) sequence of loads and stores instead), so this + // optimization would have no effect here. 
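  // Worked example (illustrative, not part of the patch): with Len = 23 and
  // MainLoopStep = 8, alignDown(23, 8) gives LoopUnits = 16 and
  // ResidualUnits = 7, so the main loop body executes twice and the residual
  // loop covers the final 7 units.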
+ } else { + ResidualUnits = getRuntimeLoopRemainder(PreLoopBuilder, Len, + CIMainLoopStep, MainLoopStep); + LoopUnits = getRuntimeLoopUnits(PreLoopBuilder, Len, CIMainLoopStep, + MainLoopStep, ResidualUnits); + } + } else if (auto *CLen = dyn_cast<ConstantInt>(Len)) { + MustTakeMainLoop = CLen->getZExtValue() > 0; + } + + BasicBlock *MainLoopBB = BasicBlock::Create( + Ctx, BBNamePrefix + "-expansion-main-body", ParentFunc, PostLoopBB); + IRBuilder<> LoopBuilder(MainLoopBB); + + PHINode *LoopIndex = LoopBuilder.CreatePHI(LenType, 2, "loop-index"); + LEI.MainLoopIndex = LoopIndex; + LoopIndex->addIncoming(ConstantInt::get(LenType, 0U), PreLoopBB); + + Value *NewIndex = + LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(LenType, MainLoopStep)); + LoopIndex->addIncoming(NewIndex, MainLoopBB); + + // One argument of the addition is a loop-variant PHI, so it must be an + // Instruction (i.e., it cannot be a Constant). + LEI.MainLoopIP = cast<Instruction>(NewIndex); + + if (ResidualLoopStep > 0 && ResidualLoopStep < MainLoopStep) { + // Loop body for the residual accesses. + BasicBlock *ResLoopBB = + BasicBlock::Create(Ctx, BBNamePrefix + "-expansion-residual-body", + PreLoopBB->getParent(), PostLoopBB); + // BB to check if the residual loop is needed. + BasicBlock *ResidualCondBB = + BasicBlock::Create(Ctx, BBNamePrefix + "-expansion-residual-cond", + PreLoopBB->getParent(), ResLoopBB); + + // Enter the MainLoop unless no main loop iteration is required. + ConstantInt *Zero = ConstantInt::get(ILenType, 0U); + if (MustTakeMainLoop) + PreLoopBuilder.CreateBr(MainLoopBB); + else + PreLoopBuilder.CreateCondBr(PreLoopBuilder.CreateICmpNE(LoopUnits, Zero), + MainLoopBB, ResidualCondBB); + PreLoopBB->getTerminator()->eraseFromParent(); + + // Stay in the MainLoop until we have handled all the LoopUnits. Then go to + // the residual condition BB. + LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, LoopUnits), + MainLoopBB, ResidualCondBB); + + // Determine if we need to branch to the residual loop or bypass it. + IRBuilder<> RCBuilder(ResidualCondBB); + RCBuilder.CreateCondBr(RCBuilder.CreateICmpNE(ResidualUnits, Zero), + ResLoopBB, PostLoopBB); + + IRBuilder<> ResBuilder(ResLoopBB); + PHINode *ResidualIndex = + ResBuilder.CreatePHI(LenType, 2, "residual-loop-index"); + ResidualIndex->addIncoming(Zero, ResidualCondBB); + + // Add the offset at the end of the main loop to the loop counter of the + // residual loop to get the proper index. + Value *FullOffset = ResBuilder.CreateAdd(LoopUnits, ResidualIndex); + LEI.ResidualLoopIndex = FullOffset; + + Value *ResNewIndex = ResBuilder.CreateAdd( + ResidualIndex, ConstantInt::get(LenType, ResidualLoopStep)); + ResidualIndex->addIncoming(ResNewIndex, ResLoopBB); + + // One argument of the addition is a loop-variant PHI, so it must be an + // Instruction (i.e., it cannot be a Constant). + LEI.ResidualLoopIP = cast<Instruction>(ResNewIndex); + + // Stay in the residual loop until all ResidualUnits are handled. + ResBuilder.CreateCondBr( + ResBuilder.CreateICmpULT(ResNewIndex, ResidualUnits), ResLoopBB, + PostLoopBB); + } else { + // There is no need for a residual loop after the main loop. We do however + // need to patch up the control flow by creating the terminators for the + // preloop block and the main loop. + + // Enter the MainLoop unless no main loop iteration is required. 
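  // Note (not part of the patch): MustTakeMainLoop is only ever set on the
  // constant-length paths above, so for a runtime Len the guard below always
  // materializes the LoopUnits != 0 compare.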
+ if (MustTakeMainLoop) { + PreLoopBuilder.CreateBr(MainLoopBB); + } else { + ConstantInt *Zero = ConstantInt::get(ILenType, 0U); + PreLoopBuilder.CreateCondBr(PreLoopBuilder.CreateICmpNE(LoopUnits, Zero), + MainLoopBB, PostLoopBB); + } + PreLoopBB->getTerminator()->eraseFromParent(); + // Stay in the MainLoop until we have handled all the LoopUnits. + LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, LoopUnits), + MainLoopBB, PostLoopBB); + } + return LEI; +} + void llvm::createMemCpyLoopKnownSize( Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, ConstantInt *CopyLen, Align SrcAlign, Align DstAlign, bool SrcIsVolatile, @@ -31,7 +243,6 @@ void llvm::createMemCpyLoopKnownSize( return; BasicBlock *PreLoopBB = InsertBefore->getParent(); - BasicBlock *PostLoopBB = nullptr; Function *ParentFunc = PreLoopBB->getParent(); LLVMContext &Ctx = PreLoopBB->getContext(); const DataLayout &DL = ParentFunc->getDataLayout(); @@ -56,37 +267,32 @@ void llvm::createMemCpyLoopKnownSize( uint64_t LoopEndCount = alignDown(CopyLen->getZExtValue(), LoopOpSize); + // Skip the loop expansion entirely if the loop would never be taken. if (LoopEndCount != 0) { - // Split - PostLoopBB = PreLoopBB->splitBasicBlock(InsertBefore, "memcpy-split"); - BasicBlock *LoopBB = - BasicBlock::Create(Ctx, "load-store-loop", ParentFunc, PostLoopBB); - PreLoopBB->getTerminator()->setSuccessor(0, LoopBB); - - IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); + LoopExpansionInfo LEI = insertLoopExpansion(InsertBefore, CopyLen, + LoopOpSize, 0, "static-memcpy"); + // Fill MainLoopBB + IRBuilder<> MainLoopBuilder(LEI.MainLoopIP); Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize)); Align PartSrcAlign(commonAlignment(SrcAlign, LoopOpSize)); - IRBuilder<> LoopBuilder(LoopBB); - PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 2, "loop-index"); - LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0U), PreLoopBB); - // Loop Body - // If we used LoopOpType as GEP element type, we would iterate over the // buffers in TypeStoreSize strides while copying TypeAllocSize bytes, i.e., // we would miss bytes if TypeStoreSize != TypeAllocSize. Therefore, use // byte offsets computed from the TypeStoreSize. - Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr, LoopIndex); - LoadInst *Load = LoopBuilder.CreateAlignedLoad(LoopOpType, SrcGEP, - PartSrcAlign, SrcIsVolatile); + Value *SrcGEP = + MainLoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr, LEI.MainLoopIndex); + LoadInst *Load = MainLoopBuilder.CreateAlignedLoad( + LoopOpType, SrcGEP, PartSrcAlign, SrcIsVolatile); if (!CanOverlap) { // Set alias scope for loads. Load->setMetadata(LLVMContext::MD_alias_scope, MDNode::get(Ctx, NewScope)); } - Value *DstGEP = LoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LoopIndex); - StoreInst *Store = LoopBuilder.CreateAlignedStore( + Value *DstGEP = + MainLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LEI.MainLoopIndex); + StoreInst *Store = MainLoopBuilder.CreateAlignedStore( Load, DstGEP, PartDstAlign, DstIsVolatile); if (!CanOverlap) { // Indicate that stores don't overlap loads. @@ -96,96 +302,63 @@ void llvm::createMemCpyLoopKnownSize( Load->setAtomic(AtomicOrdering::Unordered); Store->setAtomic(AtomicOrdering::Unordered); } - Value *NewIndex = LoopBuilder.CreateAdd( - LoopIndex, ConstantInt::get(TypeOfCopyLen, LoopOpSize)); - LoopIndex->addIncoming(NewIndex, LoopBB); - - // Create the loop branch condition. 
- Constant *LoopEndCI = ConstantInt::get(TypeOfCopyLen, LoopEndCount); - LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, LoopEndCI), - LoopBB, PostLoopBB); + assert(!LEI.ResidualLoopIP && !LEI.ResidualLoopIndex && + "No residual loop was requested"); } + // Copy the remaining bytes with straight-line code. uint64_t BytesCopied = LoopEndCount; uint64_t RemainingBytes = CopyLen->getZExtValue() - BytesCopied; - if (RemainingBytes) { - BasicBlock::iterator InsertIt = PostLoopBB ? PostLoopBB->getFirstNonPHIIt() - : InsertBefore->getIterator(); - IRBuilder<> RBuilder(InsertIt->getParent(), InsertIt); + if (RemainingBytes == 0) + return; - SmallVector<Type *, 5> RemainingOps; - TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes, - SrcAS, DstAS, SrcAlign, DstAlign, - AtomicElementSize); + IRBuilder<> RBuilder(InsertBefore); + SmallVector<Type *, 5> RemainingOps; + TTI.getMemcpyLoopResidualLoweringType(RemainingOps, Ctx, RemainingBytes, + SrcAS, DstAS, SrcAlign, DstAlign, + AtomicElementSize); - for (auto *OpTy : RemainingOps) { - Align PartSrcAlign(commonAlignment(SrcAlign, BytesCopied)); - Align PartDstAlign(commonAlignment(DstAlign, BytesCopied)); - - unsigned OperandSize = DL.getTypeStoreSize(OpTy); - assert( - (!AtomicElementSize || OperandSize % *AtomicElementSize == 0) && - "Atomic memcpy lowering is not supported for selected operand size"); - - Value *SrcGEP = RBuilder.CreateInBoundsGEP( - Int8Type, SrcAddr, ConstantInt::get(TypeOfCopyLen, BytesCopied)); - LoadInst *Load = - RBuilder.CreateAlignedLoad(OpTy, SrcGEP, PartSrcAlign, SrcIsVolatile); - if (!CanOverlap) { - // Set alias scope for loads. - Load->setMetadata(LLVMContext::MD_alias_scope, - MDNode::get(Ctx, NewScope)); - } - Value *DstGEP = RBuilder.CreateInBoundsGEP( - Int8Type, DstAddr, ConstantInt::get(TypeOfCopyLen, BytesCopied)); - StoreInst *Store = RBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, - DstIsVolatile); - if (!CanOverlap) { - // Indicate that stores don't overlap loads. - Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); - } - if (AtomicElementSize) { - Load->setAtomic(AtomicOrdering::Unordered); - Store->setAtomic(AtomicOrdering::Unordered); - } - BytesCopied += OperandSize; + for (auto *OpTy : RemainingOps) { + Align PartSrcAlign(commonAlignment(SrcAlign, BytesCopied)); + Align PartDstAlign(commonAlignment(DstAlign, BytesCopied)); + + unsigned OperandSize = DL.getTypeStoreSize(OpTy); + assert((!AtomicElementSize || OperandSize % *AtomicElementSize == 0) && + "Atomic memcpy lowering is not supported for selected operand size"); + + Value *SrcGEP = RBuilder.CreateInBoundsGEP( + Int8Type, SrcAddr, ConstantInt::get(TypeOfCopyLen, BytesCopied)); + LoadInst *Load = + RBuilder.CreateAlignedLoad(OpTy, SrcGEP, PartSrcAlign, SrcIsVolatile); + if (!CanOverlap) { + // Set alias scope for loads. + Load->setMetadata(LLVMContext::MD_alias_scope, + MDNode::get(Ctx, NewScope)); + } + Value *DstGEP = RBuilder.CreateInBoundsGEP( + Int8Type, DstAddr, ConstantInt::get(TypeOfCopyLen, BytesCopied)); + StoreInst *Store = + RBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); + if (!CanOverlap) { + // Indicate that stores don't overlap loads. 
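  // Note (not part of the patch): tagging the loads with !alias.scope and the
  // stores with !noalias in the same scope domain is what lets alias analysis
  // treat the copy loop's loads and stores as independent when CanOverlap is
  // false.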
+ Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); } + if (AtomicElementSize) { + Load->setAtomic(AtomicOrdering::Unordered); + Store->setAtomic(AtomicOrdering::Unordered); + } + BytesCopied += OperandSize; } assert(BytesCopied == CopyLen->getZExtValue() && "Bytes copied should match size in the call!"); } -// \returns \p Len urem \p OpSize, checking for optimization opportunities. -static Value *getRuntimeLoopRemainder(const DataLayout &DL, IRBuilderBase &B, - Value *Len, Value *OpSize, - unsigned OpSizeVal) { - // For powers of 2, we can and by (OpSizeVal - 1) instead of using urem. - if (isPowerOf2_32(OpSizeVal)) - return B.CreateAnd(Len, OpSizeVal - 1); - return B.CreateURem(Len, OpSize); -} - -// \returns (\p Len udiv \p OpSize) mul \p OpSize, checking for optimization -// opportunities. -// If RTLoopRemainder is provided, it must be the result of -// getRuntimeLoopRemainder() with the same arguments. -static Value *getRuntimeLoopBytes(const DataLayout &DL, IRBuilderBase &B, - Value *Len, Value *OpSize, unsigned OpSizeVal, - Value *RTLoopRemainder = nullptr) { - if (!RTLoopRemainder) - RTLoopRemainder = getRuntimeLoopRemainder(DL, B, Len, OpSize, OpSizeVal); - return B.CreateSub(Len, RTLoopRemainder); -} - void llvm::createMemCpyLoopUnknownSize( Instruction *InsertBefore, Value *SrcAddr, Value *DstAddr, Value *CopyLen, Align SrcAlign, Align DstAlign, bool SrcIsVolatile, bool DstIsVolatile, bool CanOverlap, const TargetTransformInfo &TTI, std::optional<uint32_t> AtomicElementSize) { BasicBlock *PreLoopBB = InsertBefore->getParent(); - BasicBlock *PostLoopBB = - PreLoopBB->splitBasicBlock(InsertBefore, "post-loop-memcpy-expansion"); - Function *ParentFunc = PreLoopBB->getParent(); const DataLayout &DL = ParentFunc->getDataLayout(); LLVMContext &Ctx = PreLoopBB->getContext(); @@ -205,50 +378,39 @@ void llvm::createMemCpyLoopUnknownSize( assert((!AtomicElementSize || LoopOpSize % *AtomicElementSize == 0) && "Atomic memcpy lowering is not supported for selected operand size"); - IRBuilder<> PLBuilder(PreLoopBB->getTerminator()); - - // Calculate the loop trip count, and remaining bytes to copy after the loop. - Type *CopyLenType = CopyLen->getType(); - IntegerType *ILengthType = dyn_cast<IntegerType>(CopyLenType); - assert(ILengthType && - "expected size argument to memcpy to be an integer type!"); Type *Int8Type = Type::getInt8Ty(Ctx); - bool LoopOpIsInt8 = LoopOpType == Int8Type; - ConstantInt *CILoopOpSize = ConstantInt::get(ILengthType, LoopOpSize); - Value *RuntimeLoopBytes = CopyLen; - Value *RuntimeResidualBytes = nullptr; - if (!LoopOpIsInt8) { - RuntimeResidualBytes = getRuntimeLoopRemainder(DL, PLBuilder, CopyLen, - CILoopOpSize, LoopOpSize); - RuntimeLoopBytes = getRuntimeLoopBytes(DL, PLBuilder, CopyLen, CILoopOpSize, - LoopOpSize, RuntimeResidualBytes); - } + Type *ResidualLoopOpType = AtomicElementSize + ? Type::getIntNTy(Ctx, *AtomicElementSize * 8) + : Int8Type; + unsigned ResidualLoopOpSize = DL.getTypeStoreSize(ResidualLoopOpType); + assert(ResidualLoopOpSize == (AtomicElementSize ? 
*AtomicElementSize : 1) && + "Store size is expected to match type size"); - BasicBlock *LoopBB = - BasicBlock::Create(Ctx, "loop-memcpy-expansion", ParentFunc, PostLoopBB); - IRBuilder<> LoopBuilder(LoopBB); + LoopExpansionInfo LEI = insertLoopExpansion( + InsertBefore, CopyLen, LoopOpSize, ResidualLoopOpSize, "dynamic-memcpy"); + // Fill MainLoopBB + IRBuilder<> MainLoopBuilder(LEI.MainLoopIP); Align PartSrcAlign(commonAlignment(SrcAlign, LoopOpSize)); Align PartDstAlign(commonAlignment(DstAlign, LoopOpSize)); - PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLenType, 2, "loop-index"); - LoopIndex->addIncoming(ConstantInt::get(CopyLenType, 0U), PreLoopBB); - // If we used LoopOpType as GEP element type, we would iterate over the // buffers in TypeStoreSize strides while copying TypeAllocSize bytes, i.e., // we would miss bytes if TypeStoreSize != TypeAllocSize. Therefore, use byte // offsets computed from the TypeStoreSize. - Value *SrcGEP = LoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr, LoopIndex); - LoadInst *Load = LoopBuilder.CreateAlignedLoad(LoopOpType, SrcGEP, - PartSrcAlign, SrcIsVolatile); + Value *SrcGEP = + MainLoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr, LEI.MainLoopIndex); + LoadInst *Load = MainLoopBuilder.CreateAlignedLoad( + LoopOpType, SrcGEP, PartSrcAlign, SrcIsVolatile); if (!CanOverlap) { // Set alias scope for loads. Load->setMetadata(LLVMContext::MD_alias_scope, MDNode::get(Ctx, NewScope)); } - Value *DstGEP = LoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LoopIndex); - StoreInst *Store = - LoopBuilder.CreateAlignedStore(Load, DstGEP, PartDstAlign, DstIsVolatile); + Value *DstGEP = + MainLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, LEI.MainLoopIndex); + StoreInst *Store = MainLoopBuilder.CreateAlignedStore( + Load, DstGEP, PartDstAlign, DstIsVolatile); if (!CanOverlap) { // Indicate that stores don't overlap loads. Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); @@ -257,95 +419,35 @@ void llvm::createMemCpyLoopUnknownSize( Load->setAtomic(AtomicOrdering::Unordered); Store->setAtomic(AtomicOrdering::Unordered); } - Value *NewIndex = LoopBuilder.CreateAdd( - LoopIndex, ConstantInt::get(CopyLenType, LoopOpSize)); - LoopIndex->addIncoming(NewIndex, LoopBB); - - bool RequiresResidual = - !LoopOpIsInt8 && !(AtomicElementSize && LoopOpSize == AtomicElementSize); - if (RequiresResidual) { - Type *ResLoopOpType = AtomicElementSize - ? Type::getIntNTy(Ctx, *AtomicElementSize * 8) - : Int8Type; - unsigned ResLoopOpSize = DL.getTypeStoreSize(ResLoopOpType); - assert((ResLoopOpSize == AtomicElementSize ? *AtomicElementSize : 1) && - "Store size is expected to match type size"); - - Align ResSrcAlign(commonAlignment(PartSrcAlign, ResLoopOpSize)); - Align ResDstAlign(commonAlignment(PartDstAlign, ResLoopOpSize)); - - // Loop body for the residual copy. - BasicBlock *ResLoopBB = BasicBlock::Create( - Ctx, "loop-memcpy-residual", PreLoopBB->getParent(), PostLoopBB); - // Residual loop header. - BasicBlock *ResHeaderBB = BasicBlock::Create( - Ctx, "loop-memcpy-residual-header", PreLoopBB->getParent(), nullptr); - - // Need to update the pre-loop basic block to branch to the correct place. - // branch to the main loop if the count is non-zero, branch to the residual - // loop if the copy size is smaller then 1 iteration of the main loop but - // non-zero and finally branch to after the residual loop if the memcpy - // size is zero. 
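  // Note (not part of the patch): the pre-loop/residual-header plumbing
  // removed here is now emitted centrally by insertLoopExpansion, so the
  // known-size and unknown-size memcpy lowerings share a single loop skeleton.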
- ConstantInt *Zero = ConstantInt::get(ILengthType, 0U); - PLBuilder.CreateCondBr(PLBuilder.CreateICmpNE(RuntimeLoopBytes, Zero), - LoopBB, ResHeaderBB); - PreLoopBB->getTerminator()->eraseFromParent(); - LoopBuilder.CreateCondBr( - LoopBuilder.CreateICmpULT(NewIndex, RuntimeLoopBytes), LoopBB, - ResHeaderBB); - - // Determine if we need to branch to the residual loop or bypass it. - IRBuilder<> RHBuilder(ResHeaderBB); - RHBuilder.CreateCondBr(RHBuilder.CreateICmpNE(RuntimeResidualBytes, Zero), - ResLoopBB, PostLoopBB); + // Fill ResidualLoopBB. + if (!LEI.ResidualLoopIP) + return; - // Copy the residual with single byte load/store loop. - IRBuilder<> ResBuilder(ResLoopBB); - PHINode *ResidualIndex = - ResBuilder.CreatePHI(CopyLenType, 2, "residual-loop-index"); - ResidualIndex->addIncoming(Zero, ResHeaderBB); + Align ResSrcAlign(commonAlignment(PartSrcAlign, ResidualLoopOpSize)); + Align ResDstAlign(commonAlignment(PartDstAlign, ResidualLoopOpSize)); - Value *FullOffset = ResBuilder.CreateAdd(RuntimeLoopBytes, ResidualIndex); - Value *SrcGEP = ResBuilder.CreateInBoundsGEP(Int8Type, SrcAddr, FullOffset); - LoadInst *Load = ResBuilder.CreateAlignedLoad(ResLoopOpType, SrcGEP, - ResSrcAlign, SrcIsVolatile); - if (!CanOverlap) { - // Set alias scope for loads. - Load->setMetadata(LLVMContext::MD_alias_scope, - MDNode::get(Ctx, NewScope)); - } - Value *DstGEP = ResBuilder.CreateInBoundsGEP(Int8Type, DstAddr, FullOffset); - StoreInst *Store = - ResBuilder.CreateAlignedStore(Load, DstGEP, ResDstAlign, DstIsVolatile); - if (!CanOverlap) { - // Indicate that stores don't overlap loads. - Store->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); - } - if (AtomicElementSize) { - Load->setAtomic(AtomicOrdering::Unordered); - Store->setAtomic(AtomicOrdering::Unordered); - } - Value *ResNewIndex = ResBuilder.CreateAdd( - ResidualIndex, ConstantInt::get(CopyLenType, ResLoopOpSize)); - ResidualIndex->addIncoming(ResNewIndex, ResLoopBB); - - // Create the loop branch condition. - ResBuilder.CreateCondBr( - ResBuilder.CreateICmpULT(ResNewIndex, RuntimeResidualBytes), ResLoopBB, - PostLoopBB); - } else { - // In this case the loop operand type was a byte, and there is no need for a - // residual loop to copy the remaining memory after the main loop. - // We do however need to patch up the control flow by creating the - // terminators for the preloop block and the memcpy loop. - ConstantInt *Zero = ConstantInt::get(ILengthType, 0U); - PLBuilder.CreateCondBr(PLBuilder.CreateICmpNE(RuntimeLoopBytes, Zero), - LoopBB, PostLoopBB); - PreLoopBB->getTerminator()->eraseFromParent(); - LoopBuilder.CreateCondBr( - LoopBuilder.CreateICmpULT(NewIndex, RuntimeLoopBytes), LoopBB, - PostLoopBB); + IRBuilder<> ResLoopBuilder(LEI.ResidualLoopIP); + Value *ResSrcGEP = ResLoopBuilder.CreateInBoundsGEP(Int8Type, SrcAddr, + LEI.ResidualLoopIndex); + LoadInst *ResLoad = ResLoopBuilder.CreateAlignedLoad( + ResidualLoopOpType, ResSrcGEP, ResSrcAlign, SrcIsVolatile); + if (!CanOverlap) { + // Set alias scope for loads. + ResLoad->setMetadata(LLVMContext::MD_alias_scope, + MDNode::get(Ctx, NewScope)); + } + Value *ResDstGEP = ResLoopBuilder.CreateInBoundsGEP(Int8Type, DstAddr, + LEI.ResidualLoopIndex); + StoreInst *ResStore = ResLoopBuilder.CreateAlignedStore( + ResLoad, ResDstGEP, ResDstAlign, DstIsVolatile); + if (!CanOverlap) { + // Indicate that stores don't overlap loads. 
+ ResStore->setMetadata(LLVMContext::MD_noalias, MDNode::get(Ctx, NewScope)); + } + if (AtomicElementSize) { + ResLoad->setAtomic(AtomicOrdering::Unordered); + ResStore->setAtomic(AtomicOrdering::Unordered); } } @@ -439,9 +541,9 @@ static void createMemMoveLoopUnknownSize(Instruction *InsertBefore, Value *RuntimeLoopRemainder = nullptr; Value *SkipResidualCondition = nullptr; if (RequiresResidual) { - RuntimeLoopRemainder = getRuntimeLoopRemainder(DL, PLBuilder, CopyLen, - CILoopOpSize, LoopOpSize); - RuntimeLoopBytes = getRuntimeLoopBytes(DL, PLBuilder, CopyLen, CILoopOpSize, + RuntimeLoopRemainder = + getRuntimeLoopRemainder(PLBuilder, CopyLen, CILoopOpSize, LoopOpSize); + RuntimeLoopBytes = getRuntimeLoopUnits(PLBuilder, CopyLen, CILoopOpSize, LoopOpSize, RuntimeLoopRemainder); SkipResidualCondition = PLBuilder.CreateICmpEQ(RuntimeLoopRemainder, Zero, "skip_residual"); diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp index 596849e..63a2349 100644 --- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -11,17 +11,17 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/ModuleUtils.h" -#include "llvm/Analysis/VectorUtils.h" #include "llvm/ADT/SmallString.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Hash.h" #include "llvm/Support/MD5.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Support/xxhash.h" using namespace llvm; @@ -208,10 +208,16 @@ void llvm::setKCFIType(Module &M, Function &F, StringRef MangledType) { std::string Type = MangledType.str(); if (M.getModuleFlag("cfi-normalize-integers")) Type += ".normalized"; + + // Determine which hash algorithm to use + auto *MD = dyn_cast_or_null<MDString>(M.getModuleFlag("kcfi-hash")); + KCFIHashAlgorithm Algorithm = + parseKCFIHashAlgorithm(MD ? MD->getString() : ""); + F.setMetadata(LLVMContext::MD_kcfi_type, MDNode::get(Ctx, MDB.createConstant(ConstantInt::get( Type::getInt32Ty(Ctx), - static_cast<uint32_t>(xxHash64(Type)))))); + getKCFITypeID(Type, Algorithm))))); // If the module was compiled with -fpatchable-function-entry, ensure // we use the same patchable-function-prefix. 
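  // Note (hedged, not part of the patch): a frontend would select a
  // non-default KCFI hash by setting the "kcfi-hash" module flag that
  // setKCFIType reads above, e.g. something like
  //   M.addModuleFlag(Module::Error, "kcfi-hash", MDString::get(Ctx, Name));
  // where the accepted Name strings are whatever parseKCFIHashAlgorithm
  // recognizes; an absent flag presumably keeps the previous xxHash64
  // behavior.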
if (auto *MD = mdconst::extract_or_null<ConstantInt>( diff --git a/llvm/lib/Transforms/Utils/ProfileVerify.cpp b/llvm/lib/Transforms/Utils/ProfileVerify.cpp index c578b4b..69e03f0 100644 --- a/llvm/lib/Transforms/Utils/ProfileVerify.cpp +++ b/llvm/lib/Transforms/Utils/ProfileVerify.cpp @@ -11,13 +11,19 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/IR/Analysis.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" using namespace llvm; @@ -65,11 +71,26 @@ public: ProfileInjector(Function &F, FunctionAnalysisManager &FAM) : F(F), FAM(FAM) {} bool inject(); }; + +bool isAsmOnly(const Function &F) { + if (!F.hasFnAttribute(Attribute::AttrKind::Naked)) + return false; + for (const auto &BB : F) + for (const auto &I : drop_end(BB.instructionsWithoutDebug())) { + const auto *CB = dyn_cast<CallBase>(&I); + if (!CB || !CB->isInlineAsm()) + return false; + } + return true; +} } // namespace // FIXME: currently this injects only for terminators. Select isn't yet // supported. bool ProfileInjector::inject() { + // skip purely asm functions + if (isAsmOnly(F)) + return false; // Get whatever branch probability info can be derived from the given IR - // whether it has or not metadata. The main intention for this pass is to // ensure that other passes don't drop or "forget" to update MD_prof. We do @@ -102,9 +123,14 @@ bool ProfileInjector::inject() { for (auto &BB : F) { if (AnnotateSelect) { for (auto &I : BB) { - if (isa<SelectInst>(I) && !I.getMetadata(LLVMContext::MD_prof)) + if (auto *SI = dyn_cast<SelectInst>(&I)) { + if (SI->getCondition()->getType()->isVectorTy()) + continue; + if (I.getMetadata(LLVMContext::MD_prof)) + continue; setBranchWeights(I, {SelectTrueWeight, SelectFalseWeight}, /*IsExpected=*/false); + } } } auto *Term = getTerminatorBenefitingFromMDProf(BB); @@ -169,8 +195,44 @@ PreservedAnalyses ProfileInjectorPass::run(Function &F, return PreservedAnalyses::none(); } +PreservedAnalyses ProfileVerifierPass::run(Module &M, + ModuleAnalysisManager &MAM) { + auto PopulateIgnoreList = [&](StringRef GVName) { + if (const auto *CT = M.getGlobalVariable(GVName)) + if (const auto *CA = + dyn_cast_if_present<ConstantArray>(CT->getInitializer())) + for (const auto &Elt : CA->operands()) + if (const auto *CS = dyn_cast<ConstantStruct>(Elt)) + if (CS->getNumOperands() >= 2 && CS->getOperand(1)) + if (const auto *F = dyn_cast<Function>( + CS->getOperand(1)->stripPointerCasts())) + IgnoreList.insert(F); + }; + PopulateIgnoreList("llvm.global_ctors"); + PopulateIgnoreList("llvm.global_dtors"); + + // expose the function-level run as public through a wrapper, so we can use + // pass manager mechanisms dealing with declarations and with composing the + // returned PreservedAnalyses values. 
+ struct Wrapper : PassInfoMixin<Wrapper> { + ProfileVerifierPass &PVP; + PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM) { + return PVP.run(F, FAM); + } + explicit Wrapper(ProfileVerifierPass &PVP) : PVP(PVP) {} + }; + + return createModuleToFunctionPassAdaptor(Wrapper(*this)).run(M, MAM); +} + PreservedAnalyses ProfileVerifierPass::run(Function &F, FunctionAnalysisManager &FAM) { + // skip purely asm functions + if (isAsmOnly(F)) + return PreservedAnalyses::all(); + if (IgnoreList.contains(&F)) + return PreservedAnalyses::all(); + const auto EntryCount = F.getEntryCount(/*AllowSynthetic=*/true); if (!EntryCount) { auto *MD = F.getMetadata(LLVMContext::MD_prof); @@ -185,9 +247,14 @@ PreservedAnalyses ProfileVerifierPass::run(Function &F, for (const auto &BB : F) { if (AnnotateSelect) { for (const auto &I : BB) - if (isa<SelectInst>(I) && !I.getMetadata(LLVMContext::MD_prof)) + if (auto *SI = dyn_cast<SelectInst>(&I)) { + if (SI->getCondition()->getType()->isVectorTy()) + continue; + if (I.getMetadata(LLVMContext::MD_prof)) + continue; F.getContext().emitError( "Profile verification failed: select annotation missing"); + } } if (const auto *Term = ProfileInjector::getTerminatorBenefitingFromMDProf(BB)) diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 4947d03..021bf06 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -2098,6 +2098,38 @@ void SCCPInstVisitor::handleCallResult(CallBase &CB) { return (void)mergeInValue(ValueState[II], II, ValueLatticeElement::getRange(Result)); } + if (II->getIntrinsicID() == Intrinsic::experimental_get_vector_length) { + Value *CountArg = II->getArgOperand(0); + Value *VF = II->getArgOperand(1); + bool Scalable = cast<ConstantInt>(II->getArgOperand(2))->isOne(); + + // Computation happens in the larger type. + unsigned BitWidth = std::max(CountArg->getType()->getScalarSizeInBits(), + VF->getType()->getScalarSizeInBits()); + + ConstantRange Count = getValueState(CountArg) + .asConstantRange(CountArg->getType(), false) + .zeroExtend(BitWidth); + ConstantRange MaxLanes = getValueState(VF) + .asConstantRange(VF->getType(), false) + .zeroExtend(BitWidth); + if (Scalable) + MaxLanes = + MaxLanes.multiply(getVScaleRange(II->getFunction(), BitWidth)); + + // The result is always less than both Count and MaxLanes. + ConstantRange Result( + APInt::getZero(BitWidth), + APIntOps::umin(Count.getUpper(), MaxLanes.getUpper())); + + // If Count <= MaxLanes, getvectorlength(Count, MaxLanes) = Count + if (Count.icmp(CmpInst::ICMP_ULE, MaxLanes)) + Result = Count; + + Result = Result.truncate(II->getType()->getScalarSizeInBits()); + return (void)mergeInValue(ValueState[II], II, + ValueLatticeElement::getRange(Result)); + } if (ConstantRange::isIntrinsicSupported(II->getIntrinsicID())) { // Compute result range for intrinsics supported by ConstantRange. 
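The get_vector_length handling added to SCCPSolver.cpp above is easier to follow with concrete numbers. Below is a minimal standalone sketch of the same clamping arithmetic, using plain integers in place of LLVM's ConstantRange; the half-open [Lo, Hi) convention and the vscale multiply are modeled by hand, and the names and sample values are illustrative assumptions rather than anything from the patch.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Half-open interval [Lo, Hi), mirroring ConstantRange's exclusive upper
// bound.
struct Range {
  uint64_t Lo, Hi;
};

// Sketch of the clamp: get_vector_length(Count, VF, Scalable) never exceeds
// min(Count, MaxLanes), and is exactly Count when every value in the Count
// range is <= every value in the MaxLanes range.
Range vectorLengthRange(Range Count, uint64_t VF, Range VScale, bool Scalable) {
  Range MaxLanes = {VF, VF + 1}; // VF is a constant operand.
  if (Scalable) {
    // Crude interval multiply; the real code uses ConstantRange::multiply.
    MaxLanes.Lo *= VScale.Lo;
    MaxLanes.Hi = (MaxLanes.Hi - 1) * (VScale.Hi - 1) + 1;
  }
  if (Count.Hi - 1 <= MaxLanes.Lo) // Count ULE MaxLanes as whole ranges.
    return Count;
  return {0, std::min(Count.Hi, MaxLanes.Hi)};
}

int main() {
  // Fixed VF = 4 and Count known to lie in [0, 10]: the result lies in [0, 4].
  Range R = vectorLengthRange({0, 11}, 4, {1, 2}, /*Scalable=*/false);
  assert(R.Lo == 0 && R.Hi == 5);
  // Count in [0, 3] with VF = 4: Count <= MaxLanes, so the result is Count.
  R = vectorLengthRange({0, 4}, 4, {1, 2}, /*Scalable=*/false);
  assert(R.Lo == 0 && R.Hi == 4);
}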
diff --git a/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp b/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp index fb39fdd..c5bd056 100644 --- a/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp +++ b/llvm/lib/Transforms/Utils/SSAUpdaterBulk.cpp @@ -267,6 +267,8 @@ static bool replaceIfIdentical(PHINode &PHI, PHINode &ReplPHI) { return true; } +namespace llvm { + bool EliminateNewDuplicatePHINodes(BasicBlock *BB, BasicBlock::phi_iterator FirstExistingPN) { assert(!PHIAreRefEachOther(make_range(BB->phis().begin(), FirstExistingPN))); @@ -293,6 +295,8 @@ bool EliminateNewDuplicatePHINodes(BasicBlock *BB, return Changed; } +} // end namespace llvm + static void deduplicatePass(ArrayRef<PHINode *> Worklist) { SmallDenseMap<BasicBlock *, unsigned> BBs; for (PHINode *PHI : Worklist) { diff --git a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp index 53bcaa6..934d158 100644 --- a/llvm/lib/Transforms/Utils/SampleProfileInference.cpp +++ b/llvm/lib/Transforms/Utils/SampleProfileInference.cpp @@ -1174,8 +1174,6 @@ std::pair<int64_t, int64_t> assignJumpCosts(const ProfiParams &Params, else CostInc = Params.CostJumpUnknownInc; CostDec = 0; - } else { - assert(Jump.Weight > 0 && "found zero-weight jump with a positive weight"); } return std::make_pair(CostInc, CostDec); } diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 9035e58..54e26b2 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -2208,17 +2208,6 @@ Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, // negative. If Step is known to be positive or negative, only create // either 1. or 2. auto ComputeEndCheck = [&]() -> Value * { - // Checking <u 0 is always false, if (Step * trunc ExitCount) does not wrap. - // TODO: Predicates that can be proven true/false should be discarded when - // the predicates are created, not late during expansion. - if (!Signed && Start->isZero() && SE.isKnownPositive(Step) && - DstBits < SrcBits && - ExitCount == SE.getZeroExtendExpr(SE.getTruncateExpr(ExitCount, ARTy), - ExitCount->getType()) && - SE.willNotOverflow(Instruction::Mul, Signed, Step, - SE.getTruncateExpr(ExitCount, ARTy))) - return ConstantInt::getFalse(Loc->getContext()); - // Get the backedge taken count and truncate or extended to the AR type. 
Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty); diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index cbc604e..66b8c69 100644 --- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -302,11 +302,14 @@ class SimplifyCFGOpt { bool tryToSimplifyUncondBranchWithICmpInIt(ICmpInst *ICI, IRBuilder<> &Builder); - + bool tryToSimplifyUncondBranchWithICmpSelectInIt(ICmpInst *ICI, + SelectInst *Select, + IRBuilder<> &Builder); bool hoistCommonCodeFromSuccessors(Instruction *TI, bool AllInstsEqOnly); bool hoistSuccIdenticalTerminatorToSwitchOrIf( Instruction *TI, Instruction *I1, - SmallVectorImpl<Instruction *> &OtherSuccTIs); + SmallVectorImpl<Instruction *> &OtherSuccTIs, + ArrayRef<BasicBlock *> UniqueSuccessors); bool speculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB); bool simplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond, BasicBlock *TrueBB, BasicBlock *FalseBB, @@ -778,8 +781,10 @@ private: return false; // Add all values from the range to the set - for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp) + APInt Tmp = Span.getLower(); + do Vals.push_back(ConstantInt::get(I->getContext(), Tmp)); + while (++Tmp != Span.getUpper()); UsedICmps++; return true; @@ -1867,10 +1872,13 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(Instruction *TI, // If either of the blocks has its address taken, then we can't do this fold, // because the code we'd hoist would no longer run when we jump into the block // by its address. - for (auto *Succ : successors(BB)) { + SmallSetVector<BasicBlock *, 4> UniqueSuccessors(from_range, successors(BB)); + for (auto *Succ : UniqueSuccessors) { if (Succ->hasAddressTaken()) return false; - if (Succ->getSinglePredecessor()) + // Use getUniquePredecessor instead of getSinglePredecessor to support + // multi-case successors in a switch. + if (Succ->getUniquePredecessor()) continue; // If Succ has >1 predecessors, continue to check if the Succ contains only // one `unreachable` inst. Since executing `unreachable` inst is an UB, we @@ -1883,7 +1891,7 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(Instruction *TI, // The second of pair is a SkipFlags bitmask. using SuccIterPair = std::pair<BasicBlock::iterator, unsigned>; SmallVector<SuccIterPair, 8> SuccIterPairs; - for (auto *Succ : successors(BB)) { + for (auto *Succ : UniqueSuccessors) { BasicBlock::iterator SuccItr = Succ->begin(); if (isa<PHINode>(*SuccItr)) return false; @@ -1894,19 +1902,20 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(Instruction *TI, // Check if all instructions in the successor blocks match. This allows // hoisting all instructions and removing the blocks we are hoisting from, // so does not add any new instructions. - SmallVector<BasicBlock *> Succs = to_vector(successors(BB)); + // Check if sizes and terminators of all successors match.
- bool AllSame = none_of(Succs, [&Succs](BasicBlock *Succ) { - Instruction *Term0 = Succs[0]->getTerminator(); - Instruction *Term = Succ->getTerminator(); - return !Term->isSameOperationAs(Term0) || - !equal(Term->operands(), Term0->operands()) || - Succs[0]->size() != Succ->size(); - }); + bool AllSame = + none_of(UniqueSuccessors, [&UniqueSuccessors](BasicBlock *Succ) { + Instruction *Term0 = UniqueSuccessors[0]->getTerminator(); + Instruction *Term = Succ->getTerminator(); + return !Term->isSameOperationAs(Term0) || + !equal(Term->operands(), Term0->operands()) || + UniqueSuccessors[0]->size() != Succ->size(); + }); if (!AllSame) return false; if (AllSame) { - LockstepReverseIterator<true> LRI(Succs); + LockstepReverseIterator<true> LRI(UniqueSuccessors.getArrayRef()); while (LRI.isValid()) { Instruction *I0 = (*LRI)[0]; if (any_of(*LRI, [I0](Instruction *I) { @@ -1970,7 +1979,8 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(Instruction *TI, return Changed; } - return hoistSuccIdenticalTerminatorToSwitchOrIf(TI, I1, OtherInsts) || + return hoistSuccIdenticalTerminatorToSwitchOrIf( + TI, I1, OtherInsts, UniqueSuccessors.getArrayRef()) || Changed; } @@ -2043,7 +2053,8 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(Instruction *TI, bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf( Instruction *TI, Instruction *I1, - SmallVectorImpl<Instruction *> &OtherSuccTIs) { + SmallVectorImpl<Instruction *> &OtherSuccTIs, + ArrayRef<BasicBlock *> UniqueSuccessors) { auto *BI = dyn_cast<BranchInst>(TI); @@ -2157,9 +2168,12 @@ bool SimplifyCFGOpt::hoistSuccIdenticalTerminatorToSwitchOrIf( Updates.push_back({DominatorTree::Insert, TIParent, Succ}); } - if (DTU) - for (BasicBlock *Succ : successors(TI)) + if (DTU) { + // TI might be a switch with multi-case destinations, so we need to account + // for duplicated successors. + for (BasicBlock *Succ : UniqueSuccessors) Updates.push_back({DominatorTree::Delete, TIParent, Succ}); + } eraseTerminatorAndDCECond(TI); if (DTU) @@ -5021,16 +5035,65 @@ bool SimplifyCFGOpt::simplifyIndirectBrOnSelect(IndirectBrInst *IBI, /// the PHI, merging the third icmp into the switch. bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( ICmpInst *ICI, IRBuilder<> &Builder) { + // Select == nullptr means we assume that there is a hidden no-op select + // instruction of `_ = select %icmp, true, false` after `%icmp = icmp ...` + return tryToSimplifyUncondBranchWithICmpSelectInIt(ICI, nullptr, Builder); +} + +/// Similar to tryToSimplifyUncondBranchWithICmpInIt, but handles a more generic +/// case. This is called when we find an icmp instruction (a seteq/setne with a +/// constant) and its following select instruction as the only TWO instructions +/// in a block that ends with an uncond branch. We are looking for a very +/// specific pattern that occurs when " +/// if (A == 1) return C1; +/// if (A == 2) return C2; +/// if (A < 3) return C3; +/// return C4; +/// " gets simplified. In this case, we merge the first two "branches of icmp" +/// into a switch, but then the default value goes to an uncond block with a lt +/// icmp and select in it, as InstCombine cannot simplify "A < 3" as "A == 2".
+/// After SimplifyCFG and other subsequent optimizations (e.g., SCCP), we might +/// get something like: +/// +/// case1: +/// switch i8 %A, label %DEFAULT [ i8 0, label %end i8 1, label %case2 ] +/// case2: +/// br label %end +/// DEFAULT: +/// %tmp = icmp eq i8 %A, 2 +/// %val = select i1 %tmp, i8 C3, i8 C4 +/// br label %end +/// end: +/// _ = phi i8 [ C1, %case1 ], [ C2, %case2 ], [ %val, %DEFAULT ] +/// +/// We prefer to split the edge to 'end' so that there are TWO entries of C3/C4 +/// to the PHI, merging the icmp & select into the switch, as follows: +/// +/// case1: +/// switch i8 %A, label %DEFAULT [ +/// i8 0, label %end +/// i8 1, label %case2 +/// i8 2, label %case3 +/// ] +/// case2: +/// br label %end +/// case3: +/// br label %end +/// DEFAULT: +/// br label %end +/// end: +/// _ = phi i8 [ C1, %case1 ], [ C2, %case2 ], [ C3, %case3 ], [ C4, %DEFAULT] +bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpSelectInIt( + ICmpInst *ICI, SelectInst *Select, IRBuilder<> &Builder) { BasicBlock *BB = ICI->getParent(); - // If the block has any PHIs in it or the icmp has multiple uses, it is too - // complex. - if (isa<PHINode>(BB->begin()) || !ICI->hasOneUse()) + // If the block has any PHIs in it or the icmp/select has multiple uses, it is + // too complex. + // TODO: Support multiple PHIs in the successor BB of the select's BB. + if (isa<PHINode>(BB->begin()) || !ICI->hasOneUse() || + (Select && !Select->hasOneUse())) return false; - Value *V = ICI->getOperand(0); - ConstantInt *Cst = cast<ConstantInt>(ICI->getOperand(1)); - // The pattern we're looking for is where our only predecessor is a switch on // 'V' and this block is the default case for the switch. In this case we can // fold the compared value into the switch to simplify things. @@ -5038,8 +5101,36 @@ bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( if (!Pred || !isa<SwitchInst>(Pred->getTerminator())) return false; + Value *IcmpCond; + ConstantInt *NewCaseVal; + CmpPredicate Predicate; + + // Match icmp X, C + if (!match(ICI, + m_ICmp(Predicate, m_Value(IcmpCond), m_ConstantInt(NewCaseVal)))) + return false; + + Value *SelectCond, *SelectTrueVal, *SelectFalseVal; + Instruction *User; + if (!Select) { + // If Select == nullptr, we can assume that there is a hidden no-op select + // just after the icmp. + SelectCond = ICI; + SelectTrueVal = Builder.getTrue(); + SelectFalseVal = Builder.getFalse(); + User = ICI->user_back(); + } else { + SelectCond = Select->getCondition(); + // Check if the select condition is the same as the icmp condition. + if (SelectCond != ICI) + return false; + SelectTrueVal = Select->getTrueValue(); + SelectFalseVal = Select->getFalseValue(); + User = Select->user_back(); + } + SwitchInst *SI = cast<SwitchInst>(Pred->getTerminator()); - if (SI->getCondition() != V) + if (SI->getCondition() != IcmpCond) return false; // If BB is reachable on a non-default case, then we simply know the value of @@ -5061,9 +5152,9 @@ bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( // Ok, the block is reachable from the default dest. If the constant we're // comparing exists in one of the other edges, then we can constant fold ICI // and zap it.
- if (SI->findCaseValue(Cst) != SI->case_default()) { + if (SI->findCaseValue(NewCaseVal) != SI->case_default()) { Value *V; - if (ICI->getPredicate() == ICmpInst::ICMP_EQ) + if (Predicate == ICmpInst::ICMP_EQ) V = ConstantInt::getFalse(BB->getContext()); else V = ConstantInt::getTrue(BB->getContext()); @@ -5074,25 +5165,30 @@ bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( return requestResimplify(); } - // The use of the icmp has to be in the 'end' block, by the only PHI node in + // The use of the select has to be in the 'end' block, by the only PHI node in // the block. BasicBlock *SuccBlock = BB->getTerminator()->getSuccessor(0); - PHINode *PHIUse = dyn_cast<PHINode>(ICI->user_back()); + PHINode *PHIUse = dyn_cast<PHINode>(User); if (PHIUse == nullptr || PHIUse != &SuccBlock->front() || isa<PHINode>(++BasicBlock::iterator(PHIUse))) return false; - // If the icmp is a SETEQ, then the default dest gets false, the new edge gets - // true in the PHI. - Constant *DefaultCst = ConstantInt::getTrue(BB->getContext()); - Constant *NewCst = ConstantInt::getFalse(BB->getContext()); + // If the icmp is a SETEQ, then the default dest gets SelectFalseVal, the new + // edge gets SelectTrueVal in the PHI. + Value *DefaultCst = SelectFalseVal; + Value *NewCst = SelectTrueVal; - if (ICI->getPredicate() == ICmpInst::ICMP_EQ) + if (ICI->getPredicate() == ICmpInst::ICMP_NE) std::swap(DefaultCst, NewCst); - // Replace ICI (which is used by the PHI for the default value) with true or - // false depending on if it is EQ or NE. - ICI->replaceAllUsesWith(DefaultCst); + // Replace Select (which is used by the PHI for the default value) with + // SelectFalseVal or SelectTrueVal depending on if ICI is EQ or NE. + if (Select) { + Select->replaceAllUsesWith(DefaultCst); + Select->eraseFromParent(); + } else { + ICI->replaceAllUsesWith(DefaultCst); + } ICI->eraseFromParent(); SmallVector<DominatorTree::UpdateType, 2> Updates; @@ -5109,7 +5205,7 @@ bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( NewW = ((uint64_t(*W0) + 1) >> 1); SIW.setSuccessorWeight(0, *NewW); } - SIW.addCase(Cst, NewBB, NewW); + SIW.addCase(NewCaseVal, NewBB, NewW); if (DTU) Updates.push_back({DominatorTree::Insert, Pred, NewBB}); } @@ -5212,8 +5308,7 @@ bool SimplifyCFGOpt::simplifyBranchOnICmpChain(BranchInst *BI, // We don't have any info about this condition. auto *Br = TrueWhenEqual ? 
Builder.CreateCondBr(ExtraCase, EdgeBB, NewBB) : Builder.CreateCondBr(ExtraCase, NewBB, EdgeBB); - setExplicitlyUnknownBranchWeightsIfProfiled(*Br, *NewBB->getParent(), - DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*Br, DEBUG_TYPE); OldTI->eraseFromParent(); @@ -6020,6 +6115,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const DataLayout &DL) { Value *Cond = SI->getCondition(); KnownBits Known = computeKnownBits(Cond, DL, AC, SI); + SmallPtrSet<const Constant *, 4> KnownValues; + bool IsKnownValuesValid = collectPossibleValues(Cond, KnownValues, 4); // We can also eliminate cases by determining that their values are outside of // the limited range of the condition based on how many significant (non-sign) @@ -6039,15 +6136,18 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, UniqueSuccessors.push_back(Successor); ++It->second; } - const APInt &CaseVal = Case.getCaseValue()->getValue(); + ConstantInt *CaseC = Case.getCaseValue(); + const APInt &CaseVal = CaseC->getValue(); if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) || - (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) { - DeadCases.push_back(Case.getCaseValue()); + (CaseVal.getSignificantBits() > MaxSignificantBitsInCond) || + (IsKnownValuesValid && !KnownValues.contains(CaseC))) { + DeadCases.push_back(CaseC); if (DTU) --NumPerSuccessorCases[Successor]; LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal << " is dead.\n"); - } + } else if (IsKnownValuesValid) + KnownValues.erase(CaseC); } // If we can prove that the cases must cover all possible values, the @@ -6058,33 +6158,41 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU, const unsigned NumUnknownBits = Known.getBitWidth() - (Known.Zero | Known.One).popcount(); assert(NumUnknownBits <= Known.getBitWidth()); - if (HasDefault && DeadCases.empty() && - NumUnknownBits < 64 /* avoid overflow */) { - uint64_t AllNumCases = 1ULL << NumUnknownBits; - if (SI->getNumCases() == AllNumCases) { + if (HasDefault && DeadCases.empty()) { + if (IsKnownValuesValid && all_of(KnownValues, IsaPred<UndefValue>)) { createUnreachableSwitchDefault(SI, DTU); return true; } - // When only one case value is missing, replace default with that case. - // Eliminating the default branch will provide more opportunities for - // optimization, such as lookup tables. - if (SI->getNumCases() == AllNumCases - 1) { - assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); - IntegerType *CondTy = cast<IntegerType>(Cond->getType()); - if (CondTy->getIntegerBitWidth() > 64 || - !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) - return false; - uint64_t MissingCaseVal = 0; - for (const auto &Case : SI->cases()) - MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); - auto *MissingCase = - cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal)); - SwitchInstProfUpdateWrapper SIW(*SI); - SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0)); - createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false); - SIW.setSuccessorWeight(0, 0); - return true; + if (NumUnknownBits < 64 /* avoid overflow */) { + uint64_t AllNumCases = 1ULL << NumUnknownBits; + if (SI->getNumCases() == AllNumCases) { + createUnreachableSwitchDefault(SI, DTU); + return true; + } + // When only one case value is missing, replace default with that case. 
+ // Eliminating the default branch will provide more opportunities for + // optimization, such as lookup tables. + if (SI->getNumCases() == AllNumCases - 1) { + assert(NumUnknownBits > 1 && "Should be canonicalized to a branch"); + IntegerType *CondTy = cast<IntegerType>(Cond->getType()); + if (CondTy->getIntegerBitWidth() > 64 || + !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) + return false; + + uint64_t MissingCaseVal = 0; + for (const auto &Case : SI->cases()) + MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue(); + auto *MissingCase = cast<ConstantInt>( + ConstantInt::get(Cond->getType(), MissingCaseVal)); + SwitchInstProfUpdateWrapper SIW(*SI); + SIW.addCase(MissingCase, SI->getDefaultDest(), + SIW.getSuccessorWeight(0)); + createUnreachableSwitchDefault(SI, DTU, + /*RemoveOrigDefaultBlock*/ false); + SIW.setSuccessorWeight(0, 0); + return true; + } } } @@ -7570,6 +7678,81 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, return true; } +/// Tries to transform the switch when the condition is a umin with a constant. +/// In that case, the default branch can be replaced by the constant's branch. +/// This method also removes dead cases when the simplification cannot replace +/// the default branch. +/// +/// For example: +/// switch(umin(a, 3)) { +/// case 0: +/// case 1: +/// case 2: +/// case 3: +/// case 4: +/// // ... +/// default: +/// unreachable +/// } +/// +/// Transforms into: +/// +/// switch(a) { +/// case 0: +/// case 1: +/// case 2: +/// default: +/// // This is case 3 +/// } +static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) { + Value *A; + ConstantInt *Constant; + + if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant)))) + return false; + + SmallVector<DominatorTree::UpdateType> Updates; + SwitchInstProfUpdateWrapper SIW(*SI); + BasicBlock *BB = SIW->getParent(); + + // Dead cases are removed even when the simplification fails. + // A case is dead when its value is higher than the Constant. + for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) { + if (!I->getCaseValue()->getValue().ugt(Constant->getValue())) { + ++I; + continue; + } + BasicBlock *DeadCaseBB = I->getCaseSuccessor(); + DeadCaseBB->removePredecessor(BB); + Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB}); + I = SIW->removeCase(I); + E = SIW->case_end(); + } + + auto Case = SI->findCaseValue(Constant); + // If the case value is not found, `findCaseValue` returns the default case. + // In this scenario, since there is no explicit `case 3:`, the simplification + // fails. The simplification also fails when the switch's default destination + // is reachable. + if (!SI->defaultDestUnreachable() || Case == SI->case_default()) { + if (DTU) + DTU->applyUpdates(Updates); + return !Updates.empty(); + } + + BasicBlock *Unreachable = SI->getDefaultDest(); + SIW.replaceDefaultDest(Case); + SIW.removeCase(Case); + SIW->setCondition(A); + + Updates.push_back({DominatorTree::Delete, BB, Unreachable}); + + if (DTU) + DTU->applyUpdates(Updates); + + return true; +} + /// Tries to transform switch of powers of two to reduce switch range. /// For example, switch like: /// switch (C) { case 1: case 2: case 64: case 128: } @@ -7642,19 +7825,24 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder, // label. The other is those powers of 2 that don't appear in the case // statement.
We don't know the distribution of the values coming in, so // the safest is to split 50-50 the original probability to `default`. - uint64_t OrigDenominator = sum_of(map_range( Weights, [](const auto &V) { return static_cast<uint64_t>(V); })); + uint64_t OrigDenominator = + sum_of(map_range(Weights, StaticCastTo<uint64_t>)); SmallVector<uint64_t> NewWeights(2); NewWeights[1] = Weights[0] / 2; NewWeights[0] = OrigDenominator - NewWeights[1]; setFittedBranchWeights(*BI, NewWeights, /*IsExpected=*/false); - - // For the original switch, we reduce the weight of the default by the - // amount by which the previous branch contributes to getting to default, - // and then make sure the remaining weights have the same relative ratio - // wrt eachother. + // The probability of executing the default block stays constant. It was + // p_d = Weights[0] / OrigDenominator, + // which we rewrite as W/D. + // We want to find the probability of the default branch of the switch + // statement. Let's call it X. We have W/D = W/(2D) + X * (1 - W/(2D)), + // i.e. the original probability is the probability that we go to the default + // branch from the BI branch, or take the default branch on the SI. + // Meaning X = W / (2D - W), or (W/2) / (D - W/2). + // This matches using W/2 for the default branch probability numerator and + // D - W/2 as the denominator. + Weights[0] = NewWeights[1]; uint64_t CasesDenominator = OrigDenominator - Weights[0]; - Weights[0] /= 2; for (auto &W : drop_begin(Weights)) W = NewWeights[0] * static_cast<double>(W) / CasesDenominator; @@ -8037,6 +8225,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (simplifyDuplicateSwitchArms(SI, DTU)) return requestResimplify(); + if (simplifySwitchWhenUMin(SI, DTU)) + return requestResimplify(); + return false; } @@ -8205,13 +8396,18 @@ bool SimplifyCFGOpt::simplifyUncondBranch(BranchInst *BI, // If the only instruction in the block is a seteq/setne comparison against a // constant, try to simplify the block. - if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) + if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) { if (ICI->isEquality() && isa<ConstantInt>(ICI->getOperand(1))) { ++I; if (I->isTerminator() && tryToSimplifyUncondBranchWithICmpInIt(ICI, Builder)) return true; + if (isa<SelectInst>(I) && I->getNextNode()->isTerminator() && + tryToSimplifyUncondBranchWithICmpSelectInIt(ICI, cast<SelectInst>(I), + Builder)) + return true; } + } // See if we can merge an empty landing pad block with another which is // equivalent.
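Before the diff moves on to SimplifyIndVar.cpp, the branch-weight algebra in the simplifySwitchOfPowersOfTwo comment above can be sanity-checked numerically: with W the original default weight and D the original denominator, the claim is W/D = W/(2D) + X * (1 - W/(2D)) for X = W/(2D - W). A small standalone check follows; the sample weights are arbitrary, and the variable names are mine rather than the patch's.

#include <cassert>
#include <cmath>

int main() {
  // W = weight of the original default edge, D = sum of all original weights.
  const double W = 30.0, D = 100.0;

  // Probability of reaching the default block before the transform.
  const double OrigDefaultProb = W / D;

  // After the split, the new branch reaches the default side directly with
  // probability W/(2D); otherwise the switch takes its own default with
  // probability X = W / (2D - W).
  const double BranchProb = W / (2 * D);
  const double X = W / (2 * D - W);

  // The combined probability of landing in the default block is unchanged.
  const double Total = BranchProb + X * (1.0 - BranchProb);
  assert(std::fabs(Total - OrigDefaultProb) < 1e-12);

  // The two closed forms quoted in the comment agree:
  // W/(2D - W) == (W/2) / (D - W/2).
  assert(std::fabs(X - (W / 2) / (D - W / 2)) < 1e-12);
}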
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp index 43264cc..61acf3a 100644 --- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -43,7 +43,9 @@ STATISTIC( STATISTIC( NumSimplifiedSRem, "Number of IV signed remainder operations converted to unsigned remainder"); -STATISTIC(NumElimCmp , "Number of IV comparisons eliminated"); +STATISTIC(NumElimCmp, "Number of IV comparisons eliminated"); +STATISTIC(NumInvariantCmp, "Number of IV comparisons made loop invariant"); +STATISTIC(NumSameSign, "Number of IV comparisons with new samesign flags"); namespace { /// This is a utility for simplifying induction variables @@ -275,25 +277,33 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, ICmp->replaceAllUsesWith(ConstantInt::getBool(ICmp->getContext(), *Ev)); DeadInsts.emplace_back(ICmp); LLVM_DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); - } else if (makeIVComparisonInvariant(ICmp, IVOperand)) { - // fallthrough to end of function - } else if (ICmpInst::isSigned(OriginalPred) && - SE->isKnownNonNegative(S) && SE->isKnownNonNegative(X)) { - // If we were unable to make anything above, all we can is to canonicalize - // the comparison hoping that it will open the doors for other - // optimizations. If we find out that we compare two non-negative values, - // we turn the instruction's predicate to its unsigned version. Note that - // we cannot rely on Pred here unless we check if we have swapped it. + ++NumElimCmp; + Changed = true; + return; + } + + if (makeIVComparisonInvariant(ICmp, IVOperand)) { + ++NumInvariantCmp; + Changed = true; + return; + } + + if ((ICmpInst::isSigned(OriginalPred) || + (ICmpInst::isUnsigned(OriginalPred) && !ICmp->hasSameSign())) && + SE->haveSameSign(S, X)) { + // Set the samesign flag on the compare if legal, and canonicalize to + // the unsigned variant (for signed compares) hoping that it will open + // the doors for other optimizations. Note that we cannot rely on Pred + // here unless we check if we have swapped it. assert(ICmp->getPredicate() == OriginalPred && "Predicate changed?"); - LLVM_DEBUG(dbgs() << "INDVARS: Turn to unsigned comparison: " << *ICmp + LLVM_DEBUG(dbgs() << "INDVARS: Marking comparison samesign: " << *ICmp << '\n'); ICmp->setPredicate(ICmpInst::getUnsignedPredicate(OriginalPred)); ICmp->setSameSign(); - } else + NumSameSign++; + Changed = true; return; - - ++NumElimCmp; - Changed = true; + } } bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) { diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 4a15659..c3537f5 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -577,8 +577,8 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) { // strcmp(x, y) -> cnst (if both x and y are constant strings) if (HasStr1 && HasStr2) - return ConstantInt::get(CI->getType(), - std::clamp(Str1.compare(Str2), -1, 1)); + return ConstantInt::getSigned(CI->getType(), + std::clamp(Str1.compare(Str2), -1, 1)); if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x return B.CreateNeg(B.CreateZExt( @@ -657,8 +657,8 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) { // Avoid truncating the 64-bit Length to 32 bits in ILP32. 
StringRef SubStr1 = substr(Str1, Length); StringRef SubStr2 = substr(Str2, Length); - return ConstantInt::get(CI->getType(), - std::clamp(SubStr1.compare(SubStr2), -1, 1)); + return ConstantInt::getSigned(CI->getType(), + std::clamp(SubStr1.compare(SubStr2), -1, 1)); } if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x @@ -1534,7 +1534,7 @@ static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS, int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1; Value *MaxSize = ConstantInt::get(Size->getType(), Pos); Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize); - Value *Res = ConstantInt::get(CI->getType(), IRes); + Value *Res = ConstantInt::getSigned(CI->getType(), IRes); return B.CreateSelect(Cmp, Zero, Res); } @@ -1806,119 +1806,124 @@ Value *LibCallSimplifier::optimizeNew(CallInst *CI, IRBuilderBase &B, // better to replace the hinted call with a non hinted call, to avoid the // extra parameter and the if condition check of the hint value in the // allocator. This can be considered in the future. + Value *NewCall = nullptr; switch (Func) { case LibFunc_Znwm12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNew(CI->getArgOperand(0), B, TLI, - LibFunc_Znwm12__hot_cold_t, HotCold); + NewCall = emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znwm12__hot_cold_t, HotCold); break; case LibFunc_Znwm: - return emitHotColdNew(CI->getArgOperand(0), B, TLI, - LibFunc_Znwm12__hot_cold_t, HotCold); + NewCall = emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znwm12__hot_cold_t, HotCold); break; case LibFunc_Znam12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNew(CI->getArgOperand(0), B, TLI, - LibFunc_Znam12__hot_cold_t, HotCold); + NewCall = emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znam12__hot_cold_t, HotCold); break; case LibFunc_Znam: - return emitHotColdNew(CI->getArgOperand(0), B, TLI, - LibFunc_Znam12__hot_cold_t, HotCold); + NewCall = emitHotColdNew(CI->getArgOperand(0), B, TLI, + LibFunc_Znam12__hot_cold_t, HotCold); break; case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNewNoThrow( + NewCall = emitHotColdNewNoThrow( CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_ZnwmRKSt9nothrow_t: - return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B, - TLI, LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, - HotCold); + NewCall = emitHotColdNewNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNewNoThrow( + NewCall = emitHotColdNewNoThrow( CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_ZnamRKSt9nothrow_t: - return emitHotColdNewNoThrow(CI->getArgOperand(0), CI->getArgOperand(1), B, - TLI, LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, - HotCold); + NewCall = emitHotColdNewNoThrow( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_ZnwmSt11align_val_t12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNewAligned( + NewCall = emitHotColdNewAligned( CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, LibFunc_ZnwmSt11align_val_t12__hot_cold_t, HotCold); break; case LibFunc_ZnwmSt11align_val_t: - return emitHotColdNewAligned(CI->getArgOperand(0), 
CI->getArgOperand(1), B, - TLI, LibFunc_ZnwmSt11align_val_t12__hot_cold_t, - HotCold); + NewCall = emitHotColdNewAligned( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnwmSt11align_val_t12__hot_cold_t, HotCold); break; case LibFunc_ZnamSt11align_val_t12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNewAligned( + NewCall = emitHotColdNewAligned( CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, LibFunc_ZnamSt11align_val_t12__hot_cold_t, HotCold); break; case LibFunc_ZnamSt11align_val_t: - return emitHotColdNewAligned(CI->getArgOperand(0), CI->getArgOperand(1), B, - TLI, LibFunc_ZnamSt11align_val_t12__hot_cold_t, - HotCold); + NewCall = emitHotColdNewAligned( + CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, + LibFunc_ZnamSt11align_val_t12__hot_cold_t, HotCold); break; case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNewAlignedNoThrow( + NewCall = emitHotColdNewAlignedNoThrow( CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: - return emitHotColdNewAlignedNoThrow( + NewCall = emitHotColdNewAlignedNoThrow( CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t: if (OptimizeExistingHotColdNew) - return emitHotColdNewAlignedNoThrow( + NewCall = emitHotColdNewAlignedNoThrow( CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: - return emitHotColdNewAlignedNoThrow( + NewCall = emitHotColdNewAlignedNoThrow( CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B, TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t, HotCold); break; case LibFunc_size_returning_new: - return emitHotColdSizeReturningNew(CI->getArgOperand(0), B, TLI, - LibFunc_size_returning_new_hot_cold, - HotCold); + NewCall = emitHotColdSizeReturningNew(CI->getArgOperand(0), B, TLI, + LibFunc_size_returning_new_hot_cold, + HotCold); break; case LibFunc_size_returning_new_hot_cold: if (OptimizeExistingHotColdNew) - return emitHotColdSizeReturningNew(CI->getArgOperand(0), B, TLI, - LibFunc_size_returning_new_hot_cold, - HotCold); + NewCall = emitHotColdSizeReturningNew(CI->getArgOperand(0), B, TLI, + LibFunc_size_returning_new_hot_cold, + HotCold); break; case LibFunc_size_returning_new_aligned: - return emitHotColdSizeReturningNewAligned( + NewCall = emitHotColdSizeReturningNewAligned( CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, LibFunc_size_returning_new_aligned_hot_cold, HotCold); break; case LibFunc_size_returning_new_aligned_hot_cold: if (OptimizeExistingHotColdNew) - return emitHotColdSizeReturningNewAligned( + NewCall = emitHotColdSizeReturningNewAligned( CI->getArgOperand(0), CI->getArgOperand(1), B, TLI, LibFunc_size_returning_new_aligned_hot_cold, HotCold); break; default: return nullptr; } - return nullptr; + + if (auto *NewCI = dyn_cast_or_null<Instruction>(NewCall)) + NewCI->copyMetadata(*CI); + + return NewCall; } //===----------------------------------------------------------------------===// @@ -2538,6 +2543,30 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) { CI->getArgOperand(1), FMF)); } +Value 
*LibCallSimplifier::optimizeFMinimumnumFMaximumnum(CallInst *CI, + IRBuilderBase &B) { + Module *M = CI->getModule(); + + // If we can shrink the call to a float function rather than a double + // function, do that first. + Function *Callee = CI->getCalledFunction(); + StringRef Name = Callee->getName(); + if ((Name == "fminimum_num" || Name == "fmaximum_num") && + hasFloatVersion(M, Name)) + if (Value *Ret = optimizeBinaryDoubleFP(CI, B, TLI)) + return Ret; + + // The new fminimum_num/fmaximum_num functions, unlike fmin/fmax, *are* + // sensitive to the sign of zero, so we don't change the fast-math flags like + // we did for those. + + Intrinsic::ID IID = Callee->getName().starts_with("fminimum_num") + ? Intrinsic::minimumnum + : Intrinsic::maximumnum; + return copyFlags(*CI, B.CreateBinaryIntrinsic(IID, CI->getArgOperand(0), + CI->getArgOperand(1), CI)); +} + Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) { Function *LogFn = Log->getCalledFunction(); StringRef LogNm = LogFn->getName(); @@ -2921,7 +2950,7 @@ Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI, .Case("asinh", LibFunc_sinh) .Case("asinhf", LibFunc_sinhf) .Case("asinhl", LibFunc_sinhl) - .Default(NumLibFuncs); // Used as error value + .Default(NotLibFunc); // Used as error value if (Func == inverseFunc) Ret = OpC->getArgOperand(0); } @@ -3154,7 +3183,7 @@ Value *LibCallSimplifier::optimizeRemquo(CallInst *CI, IRBuilderBase &B) { return nullptr; B.CreateAlignedStore( - ConstantInt::get(B.getIntNTy(IntBW), QuotInt.getExtValue()), + ConstantInt::getSigned(B.getIntNTy(IntBW), QuotInt.getExtValue()), CI->getArgOperand(2), CI->getParamAlign(2)); return ConstantFP::get(CI->getType(), Rem); } @@ -4118,6 +4147,13 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_fmax: case LibFunc_fmaxl: return optimizeFMinFMax(CI, Builder); + case LibFunc_fminimum_numf: + case LibFunc_fminimum_num: + case LibFunc_fminimum_numl: + case LibFunc_fmaximum_numf: + case LibFunc_fmaximum_num: + case LibFunc_fmaximum_numl: + return optimizeFMinimumnumFMaximumnum(CI, Builder); case LibFunc_cabs: case LibFunc_cabsf: case LibFunc_cabsl: diff --git a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp index 94c5c170..e86ab13 100644 --- a/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp +++ b/llvm/lib/Transforms/Utils/UnifyLoopExits.cpp @@ -158,6 +158,7 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { SmallVector<BasicBlock *, 8> CallBrTargetBlocksToFix; // Redirect exiting edges through a control flow hub. ControlFlowHub CHub; + bool Changed = false; for (unsigned I = 0; I < ExitingBlocks.size(); ++I) { BasicBlock *BB = ExitingBlocks[I]; @@ -182,6 +183,10 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { bool UpdatedLI = false; BasicBlock *NewSucc = SplitCallBrEdge(BB, Succ, J, &DTU, nullptr, &LI, &UpdatedLI); + // SplitCallBrEdge modifies the CFG because it creates an intermediate + // block. So we need to set the changed flag no matter what the + // ControlFlowHub is going to do later. + Changed = true; // Even if CallBr and Succ do not have a common parent loop, we need to // add the new target block to the parent loop of the current loop. 
if (!UpdatedLI) @@ -207,6 +212,7 @@ static bool unifyLoopExits(DominatorTree &DT, LoopInfo &LI, Loop *L) { bool ChangedCFG; std::tie(LoopExitBlock, ChangedCFG) = CHub.finalize( &DTU, GuardBlocks, "loop.exit", MaxBooleansInControlFlowHub.getValue()); + ChangedCFG |= Changed; if (!ChangedCFG) return false; diff --git a/llvm/lib/Transforms/Utils/ValueMapper.cpp b/llvm/lib/Transforms/Utils/ValueMapper.cpp index 8d8a60b..6e36006 100644 --- a/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -77,7 +77,7 @@ struct WorklistEntry { }; struct AppendingGVTy { GlobalVariable *GV; - Constant *InitPrefix; + GlobalVariable *OldGV; }; struct AliasOrIFuncTy { GlobalValue *GV; @@ -162,7 +162,7 @@ public: void scheduleMapGlobalInitializer(GlobalVariable &GV, Constant &Init, unsigned MCID); - void scheduleMapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, + void scheduleMapAppendingVariable(GlobalVariable &GV, GlobalVariable *OldGV, bool IsOldCtorDtor, ArrayRef<Constant *> NewMembers, unsigned MCID); @@ -173,7 +173,7 @@ public: void flush(); private: - void mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, + void mapAppendingVariable(GlobalVariable &GV, GlobalVariable *OldGV, bool IsOldCtorDtor, ArrayRef<Constant *> NewMembers); @@ -526,8 +526,9 @@ Value *Mapper::mapValue(const Value *V) { if (isa<ConstantVector>(C)) return getVM()[V] = ConstantVector::get(Ops); if (isa<ConstantPtrAuth>(C)) - return getVM()[V] = ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]), - cast<ConstantInt>(Ops[2]), Ops[3]); + return getVM()[V] = + ConstantPtrAuth::get(Ops[0], cast<ConstantInt>(Ops[1]), + cast<ConstantInt>(Ops[2]), Ops[3], Ops[4]); // If this is a no-operand constant, it must be because the type was remapped. if (isa<PoisonValue>(C)) return getVM()[V] = PoisonValue::get(NewTy); @@ -944,7 +945,7 @@ void Mapper::flush() { drop_begin(AppendingInits, PrefixSize)); AppendingInits.resize(PrefixSize); mapAppendingVariable(*E.Data.AppendingGV.GV, - E.Data.AppendingGV.InitPrefix, + E.Data.AppendingGV.OldGV, E.AppendingGVIsOldCtorDtor, ArrayRef(NewInits)); break; } @@ -1094,15 +1095,21 @@ void Mapper::remapFunction(Function &F) { } } -void Mapper::mapAppendingVariable(GlobalVariable &GV, Constant *InitPrefix, +void Mapper::mapAppendingVariable(GlobalVariable &GV, GlobalVariable *OldGV, bool IsOldCtorDtor, ArrayRef<Constant *> NewMembers) { + Constant *InitPrefix = + (OldGV && !OldGV->isDeclaration()) ? 
OldGV->getInitializer() : nullptr; + SmallVector<Constant *, 16> Elements; if (InitPrefix) { unsigned NumElements = cast<ArrayType>(InitPrefix->getType())->getNumElements(); for (unsigned I = 0; I != NumElements; ++I) Elements.push_back(InitPrefix->getAggregateElement(I)); + OldGV->setInitializer(nullptr); + if (InitPrefix->hasUseList() && InitPrefix->use_empty()) + InitPrefix->destroyConstant(); } PointerType *VoidPtrTy; @@ -1148,7 +1155,7 @@ void Mapper::scheduleMapGlobalInitializer(GlobalVariable &GV, Constant &Init, } void Mapper::scheduleMapAppendingVariable(GlobalVariable &GV, - Constant *InitPrefix, + GlobalVariable *OldGV, bool IsOldCtorDtor, ArrayRef<Constant *> NewMembers, unsigned MCID) { @@ -1159,7 +1166,7 @@ void Mapper::scheduleMapAppendingVariable(GlobalVariable &GV, WE.Kind = WorklistEntry::MapAppendingVar; WE.MCID = MCID; WE.Data.AppendingGV.GV = &GV; - WE.Data.AppendingGV.InitPrefix = InitPrefix; + WE.Data.AppendingGV.OldGV = OldGV; WE.AppendingGVIsOldCtorDtor = IsOldCtorDtor; WE.AppendingGVNumNewMembers = NewMembers.size(); Worklist.push_back(WE); @@ -1282,12 +1289,12 @@ void ValueMapper::scheduleMapGlobalInitializer(GlobalVariable &GV, } void ValueMapper::scheduleMapAppendingVariable(GlobalVariable &GV, - Constant *InitPrefix, + GlobalVariable *OldGV, bool IsOldCtorDtor, ArrayRef<Constant *> NewMembers, unsigned MCID) { getAsMapper(pImpl)->scheduleMapAppendingVariable( - GV, InitPrefix, IsOldCtorDtor, NewMembers, MCID); + GV, OldGV, IsOldCtorDtor, NewMembers, MCID); } void ValueMapper::scheduleMapGlobalAlias(GlobalAlias &GA, Constant &Aliasee,

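Finally, returning to the simplifySwitchWhenUMin transform in the SimplifyCFG.cpp hunk above: the before/after shapes from its doc comment can be verified exhaustively at the source level. A self-contained sketch, where the function names and returned constants are made up for illustration:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Before: switch on umin(a, 3). Case 4 is dead because umin(a, 3) <= 3, and
// the default is unreachable because cases 0..3 cover every possible value.
static int before(uint8_t A) {
  switch (std::min<uint8_t>(A, 3)) {
  case 0: return 10;
  case 1: return 11;
  case 2: return 12;
  case 3: return 13;
  case 4: return 99; // dead case, removed by the transform
  default: assert(false && "unreachable default"); return -1;
  }
}

// After: switch directly on A; the old "case 3" branch becomes the default.
static int after(uint8_t A) {
  switch (A) {
  case 0: return 10;
  case 1: return 11;
  case 2: return 12;
  default: return 13;
  }
}

int main() {
  for (int A = 0; A <= 255; ++A)
    assert(before(static_cast<uint8_t>(A)) == after(static_cast<uint8_t>(A)));
}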