Diffstat (limited to 'llvm/lib/Analysis')

 -rwxr-xr-x  llvm/lib/Analysis/ConstantFolding.cpp      |   4
 -rw-r--r--  llvm/lib/Analysis/DependenceAnalysis.cpp   | 294
 -rw-r--r--  llvm/lib/Analysis/InstructionSimplify.cpp  |  30
 -rw-r--r--  llvm/lib/Analysis/LazyValueInfo.cpp        |  10
 -rw-r--r--  llvm/lib/Analysis/MLInlineAdvisor.cpp      |  58
 -rw-r--r--  llvm/lib/Analysis/MemoryLocation.cpp       |   4
 -rw-r--r--  llvm/lib/Analysis/ScalarEvolution.cpp      | 133
 -rw-r--r--  llvm/lib/Analysis/ValueTracking.cpp        |   5

8 files changed, 440 insertions, 98 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index a5ba197..e9e2e7d 100755
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -4056,8 +4056,8 @@ static Constant *ConstantFoldFixedVectorCall(
   switch (IntrinsicID) {
   case Intrinsic::masked_load: {
     auto *SrcPtr = Operands[0];
-    auto *Mask = Operands[2];
-    auto *Passthru = Operands[3];
+    auto *Mask = Operands[1];
+    auto *Passthru = Operands[2];
 
     Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, FVTy, DL);
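A note on the operand renumbering above (and the matching changes in InstructionSimplify.cpp and MemoryLocation.cpp further down): the mask and passthru indices each shift down by one, which is consistent with the masked load/store intrinsics having dropped their explicit alignment operand, leaving the three-operand form (pointer, mask, passthru). That reading is an inference from the index shifts, not stated in the diff itself. A minimal sketch of the layout the folder now assumes, with an illustrative helper name that is not part of the patch:

  #include "llvm/ADT/ArrayRef.h"
  #include "llvm/IR/Constant.h"
  using namespace llvm;

  // Sketch only: operand layout of a masked load call after this change.
  static void unpackMaskedLoad(ArrayRef<Constant *> Operands) {
    Constant *SrcPtr = Operands[0];   // source pointer
    Constant *Mask = Operands[1];     // <N x i1> lane mask (was Operands[2])
    Constant *Passthru = Operands[2]; // masked-off lanes (was Operands[3])
    (void)SrcPtr; (void)Mask; (void)Passthru;
  }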
diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp
index 805b682..0a8c2f8 100644
--- a/llvm/lib/Analysis/DependenceAnalysis.cpp
+++ b/llvm/lib/Analysis/DependenceAnalysis.cpp
@@ -128,6 +128,18 @@ static cl::opt<bool> RunSIVRoutinesOnly(
         "The purpose is mainly to exclude the influence of those routines "
         "in regression tests for SIV routines."));
 
+// TODO: This flag is disabled by default because it is still under
+// development. Enable it or delete this flag when the feature is ready.
+static cl::opt<bool> EnableMonotonicityCheck(
+    "da-enable-monotonicity-check", cl::init(false), cl::Hidden,
+    cl::desc("Check if the subscripts are monotonic. If they are not, "
+             "dependence is reported as unknown."));
+
+static cl::opt<bool> DumpMonotonicityReport(
+    "da-dump-monotonicity-report", cl::init(false), cl::Hidden,
+    cl::desc(
+        "When printing analysis, dump the results of monotonicity checks."));
+
 //===----------------------------------------------------------------------===//
 // basics
 
@@ -177,13 +189,196 @@ void DependenceAnalysisWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequiredTransitive<LoopInfoWrapperPass>();
 }
 
+namespace {
+
+/// The property of monotonicity of a SCEV. To define the monotonicity, assume
+/// a SCEV defined within N-nested loops. Let i_k denote the iteration number
+/// of the k-th loop. Then we can regard the SCEV as an N-ary function:
+///
+///   F(i_1, i_2, ..., i_N)
+///
+/// The domain of i_k is the closed range [0, BTC_k], where BTC_k is the
+/// backedge-taken count of the k-th loop.
+///
+/// A function F is said to be "monotonically increasing with respect to the
+/// k-th loop" if x <= y implies the following condition:
+///
+///   F(i_1, ..., i_{k-1}, x, i_{k+1}, ..., i_N) <=
+///       F(i_1, ..., i_{k-1}, y, i_{k+1}, ..., i_N)
+///
+/// where i_1, ..., i_{k-1}, i_{k+1}, ..., i_N, x, and y are elements of their
+/// respective domains.
+///
+/// Likewise, F is "monotonically decreasing with respect to the k-th loop"
+/// if x <= y implies
+///
+///   F(i_1, ..., i_{k-1}, x, i_{k+1}, ..., i_N) >=
+///       F(i_1, ..., i_{k-1}, y, i_{k+1}, ..., i_N)
+///
+/// A function F that is monotonically increasing or decreasing with respect
+/// to the k-th loop is simply called "monotonic with respect to the k-th
+/// loop".
+///
+/// A function F is said to be "multivariate monotonic" when it is monotonic
+/// with respect to all of the N loops.
+///
+/// Since integer comparison can be either signed or unsigned, we need to
+/// distinguish monotonicity in the signed sense from that in the unsigned
+/// sense. Note that the inequality "x <= y" merely indicates loop progression
+/// and is not affected by the difference between signed and unsigned order.
+///
+/// Currently we only consider monotonicity in the signed sense.
+enum class SCEVMonotonicityType {
+  /// We don't know anything about the monotonicity of the SCEV.
+  Unknown,
+
+  /// The SCEV is loop-invariant with respect to the outermost loop. In other
+  /// words, the function F corresponding to the SCEV is a constant function.
+  Invariant,
+
+  /// The function F corresponding to the SCEV is multivariate monotonic in
+  /// the signed sense. Note that a multivariate monotonic function may also
+  /// be a constant function. The order employed in the definition of
+  /// monotonicity is not a strict order.
+  MultivariateSignedMonotonic,
+};
+
+struct SCEVMonotonicity {
+  SCEVMonotonicity(SCEVMonotonicityType Type,
+                   const SCEV *FailurePoint = nullptr);
+
+  SCEVMonotonicityType getType() const { return Type; }
+
+  const SCEV *getFailurePoint() const { return FailurePoint; }
+
+  bool isUnknown() const { return Type == SCEVMonotonicityType::Unknown; }
+
+  void print(raw_ostream &OS, unsigned Depth) const;
+
+private:
+  SCEVMonotonicityType Type;
+
+  /// The subexpression that caused Unknown. Mainly for debugging purposes.
+  const SCEV *FailurePoint;
+};
+
+/// Check the monotonicity of a SCEV. Since dependence tests (SIV, MIV, etc.)
+/// assume that subscript expressions are (multivariate) monotonic, we need to
+/// verify this property before applying those tests. Violating this
+/// assumption may cause them to produce incorrect results.
+struct SCEVMonotonicityChecker
+    : public SCEVVisitor<SCEVMonotonicityChecker, SCEVMonotonicity> {
+
+  SCEVMonotonicityChecker(ScalarEvolution *SE) : SE(SE) {}
+
+  /// Check the monotonicity of \p Expr. \p Expr must be of integer type. If
+  /// \p OutermostLoop is not null, \p Expr must be defined in
+  /// \p OutermostLoop or one of its nested loops.
+  SCEVMonotonicity checkMonotonicity(const SCEV *Expr,
+                                     const Loop *OutermostLoop);
+
+private:
+  ScalarEvolution *SE;
+
+  /// The outermost loop that DA is analyzing.
+  const Loop *OutermostLoop;
+
+  /// A helper to classify \p Expr as either Invariant or Unknown.
+  SCEVMonotonicity invariantOrUnknown(const SCEV *Expr);
+
+  /// Return true if \p Expr is loop-invariant with respect to the outermost
+  /// loop.
+  bool isLoopInvariant(const SCEV *Expr) const;
+
+  /// A helper to create an Unknown SCEVMonotonicity.
+  SCEVMonotonicity createUnknown(const SCEV *FailurePoint) {
+    return SCEVMonotonicity(SCEVMonotonicityType::Unknown, FailurePoint);
+  }
+
+  SCEVMonotonicity visitAddRecExpr(const SCEVAddRecExpr *Expr);
+
+  SCEVMonotonicity visitConstant(const SCEVConstant *) {
+    return SCEVMonotonicity(SCEVMonotonicityType::Invariant);
+  }
+  SCEVMonotonicity visitVScale(const SCEVVScale *) {
+    return SCEVMonotonicity(SCEVMonotonicityType::Invariant);
+  }
+
+  // TODO: Handle more cases.
+  SCEVMonotonicity visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitAddExpr(const SCEVAddExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitMulExpr(const SCEVMulExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitTruncateExpr(const SCEVTruncateExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitUDivExpr(const SCEVUDivExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitSMaxExpr(const SCEVSMaxExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitUMaxExpr(const SCEVUMaxExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitSMinExpr(const SCEVSMinExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitUMinExpr(const SCEVUMinExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitUnknown(const SCEVUnknown *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+  SCEVMonotonicity visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+    return invariantOrUnknown(Expr);
+  }
+
+  friend struct SCEVVisitor<SCEVMonotonicityChecker, SCEVMonotonicity>;
+};
+
+} // anonymous namespace
+
 // Used to test the dependence analyzer.
 // Looks through the function, noting instructions that may access memory.
 // Calls depends() on every possible pair and prints out the result.
 // Ignores all other instructions.
 static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA,
-                                  ScalarEvolution &SE, bool NormalizeResults) {
+                                  ScalarEvolution &SE, LoopInfo &LI,
+                                  bool NormalizeResults) {
   auto *F = DA->getFunction();
+
+  if (DumpMonotonicityReport) {
+    SCEVMonotonicityChecker Checker(&SE);
+    OS << "Monotonicity check:\n";
+    for (Instruction &Inst : instructions(F)) {
+      if (!isa<LoadInst>(Inst) && !isa<StoreInst>(Inst))
+        continue;
+      Value *Ptr = getLoadStorePointerOperand(&Inst);
+      const Loop *L = LI.getLoopFor(Inst.getParent());
+      const SCEV *PtrSCEV = SE.getSCEVAtScope(Ptr, L);
+      const SCEV *AccessFn = SE.removePointerBase(PtrSCEV);
+      SCEVMonotonicity Mon = Checker.checkMonotonicity(AccessFn, L);
+      OS.indent(2) << "Inst: " << Inst << "\n";
+      OS.indent(4) << "Expr: " << *AccessFn << "\n";
+      Mon.print(OS, 4);
+    }
+    OS << "\n";
+  }
+
   for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); SrcI != SrcE;
        ++SrcI) {
     if (SrcI->mayReadOrWriteMemory()) {
@@ -235,7 +430,8 @@ static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA,
 void DependenceAnalysisWrapperPass::print(raw_ostream &OS,
                                           const Module *) const {
   dumpExampleDependence(
-      OS, info.get(), getAnalysis<ScalarEvolutionWrapperPass>().getSE(), false);
+      OS, info.get(), getAnalysis<ScalarEvolutionWrapperPass>().getSE(),
+      getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), false);
 }
 
 PreservedAnalyses
@@ -244,7 +440,7 @@ DependenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
      << "':\n";
   dumpExampleDependence(OS, &FAM.getResult<DependenceAnalysis>(F),
                         FAM.getResult<ScalarEvolutionAnalysis>(F),
-                        NormalizeResults);
+                        FAM.getResult<LoopAnalysis>(F), NormalizeResults);
   return PreservedAnalyses::all();
 }
 
@@ -671,6 +867,81 @@ bool DependenceInfo::intersectConstraints(Constraint *X, const Constraint *Y) {
 }
 
 //===----------------------------------------------------------------------===//
+// SCEVMonotonicity
+
+SCEVMonotonicity::SCEVMonotonicity(SCEVMonotonicityType Type,
+                                   const SCEV *FailurePoint)
+    : Type(Type), FailurePoint(FailurePoint) {
+  assert(
+      ((Type == SCEVMonotonicityType::Unknown) == (FailurePoint != nullptr)) &&
+      "FailurePoint must be provided iff Type is Unknown");
+}
+
+void SCEVMonotonicity::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << "Monotonicity: ";
+  switch (Type) {
+  case SCEVMonotonicityType::Unknown:
+    assert(FailurePoint && "FailurePoint must be provided for Unknown");
+    OS << "Unknown\n";
+    OS.indent(Depth) << "Reason: " << *FailurePoint << "\n";
+    break;
+  case SCEVMonotonicityType::Invariant:
+    OS << "Invariant\n";
+    break;
+  case SCEVMonotonicityType::MultivariateSignedMonotonic:
+    OS << "MultivariateSignedMonotonic\n";
+    break;
+  }
+}
+
+bool SCEVMonotonicityChecker::isLoopInvariant(const SCEV *Expr) const {
+  return !OutermostLoop || SE->isLoopInvariant(Expr, OutermostLoop);
+}
+
+SCEVMonotonicity
+SCEVMonotonicityChecker::invariantOrUnknown(const SCEV *Expr) {
+  if (isLoopInvariant(Expr))
+    return SCEVMonotonicity(SCEVMonotonicityType::Invariant);
+  return createUnknown(Expr);
+}
+
+SCEVMonotonicity
+SCEVMonotonicityChecker::checkMonotonicity(const SCEV *Expr,
+                                           const Loop *OutermostLoop) {
+  assert(Expr->getType()->isIntegerTy() && "Expr must be of integer type");
+  this->OutermostLoop = OutermostLoop;
+  return visit(Expr);
+}
+
+/// We only care about an affine AddRec at the moment. For an affine AddRec,
+/// the monotonicity can be inferred from its nowrap property. For example,
+/// let X and Y be loop-invariant, and assume Y is non-negative. An AddRec
+/// {X,+,Y}<nsw> implies:
+///
+///   X <=s (X + Y) <=s ((X + Y) + Y) <=s ...
+///
+/// Thus, we can conclude that the AddRec is monotonically increasing with
+/// respect to the associated loop in the signed sense. Similar reasoning
+/// applies when Y is non-positive, leading to a monotonically decreasing
+/// AddRec.
+SCEVMonotonicity
+SCEVMonotonicityChecker::visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+  if (!Expr->isAffine() || !Expr->hasNoSignedWrap())
+    return createUnknown(Expr);
+
+  const SCEV *Start = Expr->getStart();
+  const SCEV *Step = Expr->getStepRecurrence(*SE);
+
+  SCEVMonotonicity StartMon = visit(Start);
+  if (StartMon.isUnknown())
+    return StartMon;
+
+  if (!isLoopInvariant(Step))
+    return createUnknown(Expr);
+
+  return SCEVMonotonicity(SCEVMonotonicityType::MultivariateSignedMonotonic);
+}
+
+//===----------------------------------------------------------------------===//
 // DependenceInfo methods
 
 // For debugging purposes. Dumps a dependence to OS.
@@ -3488,10 +3759,19 @@ bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst,
   // resize Pair to contain as many pairs of subscripts as the delinearization
   // has found, and then initialize the pairs following the delinearization.
   Pair.resize(Size);
+  SCEVMonotonicityChecker MonChecker(SE);
+  const Loop *OutermostLoop = SrcLoop ? SrcLoop->getOutermostLoop() : nullptr;
   for (int I = 0; I < Size; ++I) {
     Pair[I].Src = SrcSubscripts[I];
     Pair[I].Dst = DstSubscripts[I];
     unifySubscriptType(&Pair[I]);
+
+    if (EnableMonotonicityCheck) {
+      if (MonChecker.checkMonotonicity(Pair[I].Src, OutermostLoop).isUnknown())
+        return false;
+      if (MonChecker.checkMonotonicity(Pair[I].Dst, OutermostLoop).isUnknown())
+        return false;
+    }
   }
 
   return true;
@@ -3824,6 +4104,14 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst,
   Pair[0].Src = SrcEv;
   Pair[0].Dst = DstEv;
 
+  SCEVMonotonicityChecker MonChecker(SE);
+  const Loop *OutermostLoop = SrcLoop ? SrcLoop->getOutermostLoop() : nullptr;
+  if (EnableMonotonicityCheck)
+    if (MonChecker.checkMonotonicity(Pair[0].Src, OutermostLoop).isUnknown() ||
+        MonChecker.checkMonotonicity(Pair[0].Dst, OutermostLoop).isUnknown())
+      return std::make_unique<Dependence>(Src, Dst,
+                                          SCEVUnionPredicate(Assume, *SE));
+
   if (Delinearize) {
     if (tryDelinearize(Src, Dst, Pair)) {
       LLVM_DEBUG(dbgs() << " delinearized\n");
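The {X,+,Y}<nsw> argument in visitAddRecExpr above can be sanity-checked on plain integers. Here is a self-contained toy model of the increasing case (values are illustrative, and signed overflow is assumed absent, mirroring the <nsw> precondition):

  #include <cassert>
  #include <cstdint>

  // Closed form of the affine AddRec {X,+,Y} at iteration I.
  static int64_t addRecValue(int64_t X, int64_t Y, int64_t I) {
    return X + Y * I;
  }

  int main() {
    const int64_t X = 7, Y = 3; // loop-invariant start, non-negative step
    for (int64_t I = 0; I < 10; ++I)
      assert(addRecValue(X, Y, I) <= addRecValue(X, Y, I + 1));
    return 0;
  }

With Y non-positive the same loop verifies the decreasing case; the checker collapses both directions into the single MultivariateSignedMonotonic result.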
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index e08ef60..8da51d0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5106,32 +5106,33 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
       return Ptr;
 
     // The following transforms are only safe if the ptrtoint cast
-    // doesn't truncate the pointers.
-    if (Indices[0]->getType()->getScalarSizeInBits() ==
-        Q.DL.getPointerSizeInBits(AS)) {
+    // doesn't truncate the address of the pointers. The non-address bits
+    // must be the same, as the underlying objects are the same.
+    if (Indices[0]->getType()->getScalarSizeInBits() >=
+        Q.DL.getAddressSizeInBits(AS)) {
       auto CanSimplify = [GEPTy, &P, Ptr]() -> bool {
         return P->getType() == GEPTy &&
               getUnderlyingObject(P) == getUnderlyingObject(Ptr);
       };
       // getelementptr V, (sub P, V) -> P if P points to a type of size 1.
       if (TyAllocSize == 1 &&
-          match(Indices[0],
-                m_Sub(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Specific(Ptr)))) &&
+          match(Indices[0], m_Sub(m_PtrToIntOrAddr(m_Value(P)),
+                                  m_PtrToIntOrAddr(m_Specific(Ptr)))) &&
          CanSimplify())
        return P;
 
      // getelementptr V, (ashr (sub P, V), C) -> P if P points to a type of
      // size 1 << C.
-      if (match(Indices[0], m_AShr(m_Sub(m_PtrToInt(m_Value(P)),
-                                         m_PtrToInt(m_Specific(Ptr))),
+      if (match(Indices[0], m_AShr(m_Sub(m_PtrToIntOrAddr(m_Value(P)),
+                                         m_PtrToIntOrAddr(m_Specific(Ptr))),
                                    m_ConstantInt(C))) &&
          TyAllocSize == 1ULL << C && CanSimplify())
        return P;
 
      // getelementptr V, (sdiv (sub P, V), C) -> P if P points to a type of
      // size C.
-      if (match(Indices[0], m_SDiv(m_Sub(m_PtrToInt(m_Value(P)),
-                                         m_PtrToInt(m_Specific(Ptr))),
+      if (match(Indices[0], m_SDiv(m_Sub(m_PtrToIntOrAddr(m_Value(P)),
+                                         m_PtrToIntOrAddr(m_Specific(Ptr))),
                                    m_SpecificInt(TyAllocSize))) &&
          CanSimplify())
        return P;
@@ -5440,9 +5441,10 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
 
   // ptrtoint (ptradd (Ptr, X - ptrtoint(Ptr))) -> X
   Value *Ptr, *X;
-  if (CastOpc == Instruction::PtrToInt &&
-      match(Op, m_PtrAdd(m_Value(Ptr),
-                         m_Sub(m_Value(X), m_PtrToInt(m_Deferred(Ptr))))) &&
+  if ((CastOpc == Instruction::PtrToInt ||
+       CastOpc == Instruction::PtrToAddr) &&
+      match(Op,
+            m_PtrAdd(m_Value(Ptr),
+                     m_Sub(m_Value(X), m_PtrToIntOrAddr(m_Deferred(Ptr))))) &&
      X->getType() == Ty && Ty == Q.DL.getIndexType(Ptr->getType()))
    return X;
 
@@ -6987,8 +6989,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
   switch (IID) {
   case Intrinsic::masked_load:
   case Intrinsic::masked_gather: {
-    Value *MaskArg = Args[2];
-    Value *PassthruArg = Args[3];
+    Value *MaskArg = Args[1];
+    Value *PassthruArg = Args[2];
     // If the mask is all zeros or undef, the "passthru" argument is the
     // result.
     if (maskIsAllZeroOrUndef(MaskArg))
       return PassthruArg;
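The three GEP folds above all rest on one identity: the byte offset (sub (ptrtoint P), (ptrtoint V)), possibly scaled down by the element size, re-derives P when added back to V. A toy version of the TyAllocSize == 1 case with ordinary C++ pointers (illustrative only, not the LLVM API):

  #include <cassert>
  #include <cstdint>

  int main() {
    char Buf[16];
    char *V = Buf + 2, *P = Buf + 9; // same underlying object
    // Model of (sub (ptrtoint P), (ptrtoint V)).
    intptr_t Off =
        reinterpret_cast<intptr_t>(P) - reinterpret_cast<intptr_t>(V);
    assert(V + Off == P); // "gep V, Off" simplifies to P
    return 0;
  }

The patch widens the match from m_PtrToInt to m_PtrToIntOrAddr and compares against the address width instead of the pointer width; since CanSimplify requires both pointers to share an underlying object, any non-address pointer bits cancel in the subtraction, which is what the new comment in the diff asserts.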
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 0e5bc48..df75999 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -947,9 +947,8 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
                                             /*UseBlockValue*/ false));
   }
 
-  ValueLatticeElement Result = TrueVal;
-  Result.mergeIn(FalseVal);
-  return Result;
+  TrueVal.mergeIn(FalseVal);
+  return TrueVal;
 }
 
 std::optional<ConstantRange>
@@ -1778,9 +1777,8 @@ ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
     assert(OptResult && "Value not available after solving");
   }
 
-  ValueLatticeElement Result = *OptResult;
-  LLVM_DEBUG(dbgs() << " Result = " << Result << "\n");
-  return Result;
+  LLVM_DEBUG(dbgs() << " Result = " << *OptResult << "\n");
+  return *OptResult;
 }
 
 ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) {
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 1d1a5560..9a5ae2a 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -324,32 +324,44 @@ void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
     FAM.invalidate(*Caller, PA);
   }
   Advice.updateCachedCallerFPI(FAM);
-  int64_t IRSizeAfter =
-      getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
-  CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+  if (Caller == Callee) {
+    assert(!CalleeWasDeleted);
+    // We double-counted CallerAndCalleeEdges, since the caller and the callee
+    // are the same.
+    assert(Advice.CallerAndCalleeEdges % 2 == 0);
+    CurrentIRSize += getIRSize(*Caller) - Advice.CallerIRSize;
+    EdgeCount += getCachedFPI(*Caller).DirectCallsToDefinedFunctions -
+                 Advice.CallerAndCalleeEdges / 2;
+    // The NodeCount stays the same.
+  } else {
+    int64_t IRSizeAfter =
+        getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
+    CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+
+    // We can delta-update module-wide features. We know the inlining only
+    // changed the caller, and maybe the callee (by deleting the latter).
+    // Nodes are simple to update. For edges, we 'forget' the edges that the
+    // caller and callee used to have before inlining, and add back what they
+    // currently have together.
+    int64_t NewCallerAndCalleeEdges =
+        getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
+
+    // A dead function's node is not actually removed from the call graph
+    // until the end of the call graph walk, but the node no longer belongs
+    // to any valid SCC.
+    if (CalleeWasDeleted) {
+      --NodeCount;
+      NodesInLastSCC.erase(CG.lookup(*Callee));
+      DeadFunctions.insert(Callee);
+    } else {
+      NewCallerAndCalleeEdges +=
+          getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
+    }
+    EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
+  }
   if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
     ForceStop = true;
 
-  // We can delta-update module-wide features. We know the inlining only changed
-  // the caller, and maybe the callee (by deleting the latter).
-  // Nodes are simple to update.
-  // For edges, we 'forget' the edges that the caller and callee used to have
-  // before inlining, and add back what they currently have together.
-  int64_t NewCallerAndCalleeEdges =
-      getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
-
-  // A dead function's node is not actually removed from the call graph until
-  // the end of the call graph walk, but the node no longer belongs to any valid
-  // SCC.
-  if (CalleeWasDeleted) {
-    --NodeCount;
-    NodesInLastSCC.erase(CG.lookup(*Callee));
-    DeadFunctions.insert(Callee);
-  } else {
-    NewCallerAndCalleeEdges +=
-        getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
-  }
-  EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
   assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
 }
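The new Caller == Callee branch corrects a delta update that implicitly assumed distinct functions: when a function is inlined into itself, the advice recorded that function's direct-call edges twice, so only half of the cached figure should be forgotten before adding back the post-inlining count. A toy check of the arithmetic with made-up numbers:

  #include <cassert>
  #include <cstdint>

  int main() {
    int64_t EdgeCount = 100;       // module-wide edge count before inlining
    int64_t CachedCallerEdges = 4; // caller's edges when the advice was made
    // Caller == Callee, so the advice counted those edges twice.
    int64_t CallerAndCalleeEdges = 2 * CachedCallerEdges;
    int64_t EdgesAfterInlining = 6; // re-measured after inlining
    EdgeCount += EdgesAfterInlining - CallerAndCalleeEdges / 2;
    assert(EdgeCount == 102); // forgot the old 4 edges, added back 6
    return 0;
  }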
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index dcc5117..1c5f08e 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -245,7 +245,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
     assert(ArgIdx == 0 && "Invalid argument index");
     auto *Ty = cast<VectorType>(II->getType());
-    if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty))
+    if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(1), Ty))
       return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags);
 
     return MemoryLocation(
@@ -255,7 +255,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
     assert(ArgIdx == 1 && "Invalid argument index");
     auto *Ty = cast<VectorType>(II->getArgOperand(0)->getType());
-    if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(3), Ty))
+    if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty))
       return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags);
 
     return MemoryLocation(
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e06b095..6f7dd79 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15473,6 +15473,38 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
   }
 }
 
+// Return a new SCEV that modifies \p Expr to the closest number that is
+// divisible by \p Divisor and less than or equal to \p Expr. For now, only
+// handle constant \p Expr.
+static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
+                                                     const APInt &DivisorVal,
+                                                     ScalarEvolution &SE) {
+  const APInt *ExprVal;
+  if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+      DivisorVal.isNonPositive())
+    return Expr;
+  APInt Rem = ExprVal->urem(DivisorVal);
+  // Return the SCEV: Expr - Expr % Divisor.
+  return SE.getConstant(*ExprVal - Rem);
+}
+
+// Return a new SCEV that modifies \p Expr to the closest number that is
+// divisible by \p Divisor and greater than or equal to \p Expr. For now,
+// only handle constant \p Expr.
+static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
+                                                 const APInt &DivisorVal,
+                                                 ScalarEvolution &SE) {
+  const APInt *ExprVal;
+  if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+      DivisorVal.isNonPositive())
+    return Expr;
+  APInt Rem = ExprVal->urem(DivisorVal);
+  if (Rem.isZero())
+    return Expr;
+  // Return the SCEV: Expr + Divisor - Expr % Divisor.
+  return SE.getConstant(*ExprVal + DivisorVal - Rem);
+}
+
 void ScalarEvolution::LoopGuards::collectFromBlock(
     ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
     const BasicBlock *Block, const BasicBlock *Pred,
@@ -15540,36 +15572,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     match(LHS, m_scev_APInt(C)) && C->isNonNegative();
   };
 
-  // Return a new SCEV that modifies \p Expr to the closest number divides by
-  // \p Divisor and greater or equal than Expr. For now, only handle constant
-  // Expr.
-  auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
-                                         const APInt &DivisorVal) {
-    const APInt *ExprVal;
-    if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
-        DivisorVal.isNonPositive())
-      return Expr;
-    APInt Rem = ExprVal->urem(DivisorVal);
-    if (Rem.isZero())
-      return Expr;
-    // return the SCEV: Expr + Divisor - Expr % Divisor
-    return SE.getConstant(*ExprVal + DivisorVal - Rem);
-  };
-
-  // Return a new SCEV that modifies \p Expr to the closest number divides by
-  // \p Divisor and less or equal than Expr. For now, only handle constant
-  // Expr.
-  auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
-                                             const APInt &DivisorVal) {
-    const APInt *ExprVal;
-    if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
-        DivisorVal.isNonPositive())
-      return Expr;
-    APInt Rem = ExprVal->urem(DivisorVal);
-    // return the SCEV: Expr - Expr % Divisor
-    return SE.getConstant(*ExprVal - Rem);
-  };
-
   // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
   // recursively. This is done by aligning up/down the constant value to the
   // Divisor.
@@ -15591,8 +15593,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
       assert(SE.isKnownNonNegative(MinMaxLHS) &&
             "Expected non-negative operand!");
       auto *DivisibleExpr =
-          IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
-                : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal);
+          IsMin
+              ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
+              : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
       SmallVector<const SCEV *> Ops = {
           ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
       return SE.getMinMaxExpr(SCTy, Ops);
@@ -15669,21 +15672,21 @@
       [[fallthrough]];
     case CmpInst::ICMP_SLT: {
       RHS = SE.getMinusSCEV(RHS, One);
-      RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+      RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
       break;
     }
     case CmpInst::ICMP_UGT:
     case CmpInst::ICMP_SGT:
       RHS = SE.getAddExpr(RHS, One);
-      RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+      RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
       break;
     case CmpInst::ICMP_ULE:
     case CmpInst::ICMP_SLE:
-      RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+      RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
       break;
    case CmpInst::ICMP_UGE:
    case CmpInst::ICMP_SGE:
-      RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+      RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
       break;
    default:
      break;
@@ -15737,22 +15740,29 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     case CmpInst::ICMP_NE:
       if (match(RHS, m_scev_Zero())) {
         const SCEV *OneAlignedUp =
-            GetNextSCEVDividesByDivisor(One, DividesBy);
+            getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
         To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
       } else {
+        // A guard LHS != RHS lets us rewrite (LHS - RHS) to
+        // UMax(1, LHS - RHS), but creating the subtraction eagerly is
+        // expensive. Track the inequalities in a separate map, and
+        // materialize the rewrite lazily when encountering a suitable
+        // subtraction while rewriting.
         if (LHS->getType()->isPointerTy()) {
           LHS = SE.getLosslessPtrToIntExpr(LHS);
           RHS = SE.getLosslessPtrToIntExpr(RHS);
           if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
             break;
         }
-        auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
-          const SCEV *Sub = SE.getMinusSCEV(A, B);
-          AddRewrite(Sub, Sub,
-                     SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
-        };
-        AddSubRewrite(LHS, RHS);
-        AddSubRewrite(RHS, LHS);
+        const SCEVConstant *C;
+        const SCEV *A, *B;
+        if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
+            match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
+          RHS = A;
+          LHS = B;
+        }
+        if (LHS > RHS)
+          std::swap(LHS, RHS);
+        Guards.NotEqual.insert({LHS, RHS});
         continue;
       }
       break;
@@ -15886,13 +15896,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
   class SCEVLoopGuardRewriter
       : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
     const DenseMap<const SCEV *, const SCEV *> &Map;
+    const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
 
     SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
 
   public:
     SCEVLoopGuardRewriter(ScalarEvolution &SE,
                           const ScalarEvolution::LoopGuards &Guards)
-        : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
+        : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
+          NotEqual(Guards.NotEqual) {
       if (Guards.PreserveNUW)
         FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
       if (Guards.PreserveNSW)
@@ -15947,14 +15959,39 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
     }
 
     const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+      // Helper to check if S is a subtraction (A - B) where A != B, and if
+      // so, return UMax(S, 1).
+      auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
+        const SCEV *LHS, *RHS;
+        if (MatchBinarySub(S, LHS, RHS)) {
+          if (LHS > RHS)
+            std::swap(LHS, RHS);
+          if (NotEqual.contains({LHS, RHS})) {
+            const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
+                SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
+            return SE.getUMaxExpr(OneAlignedUp, S);
+          }
+        }
+        return nullptr;
+      };
+
+      // Check if Expr itself is a subtraction pattern with guard info.
+      if (const SCEV *Rewritten = RewriteSubtraction(Expr))
+        return Rewritten;
+
       // Trip count expressions sometimes consist of adding 3 operands, i.e.
       // (Const + A + B). There may be guard info for A + B, and if so, apply
       // it.
       // TODO: Could more generally apply guards to Add sub-expressions.
       if (isa<SCEVConstant>(Expr->getOperand(0)) &&
           Expr->getNumOperands() == 3) {
-        if (const SCEV *S = Map.lookup(
-                SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
+        const SCEV *Add =
+            SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
+        if (const SCEV *Rewritten = RewriteSubtraction(Add))
+          return SE.getAddExpr(
+              Expr->getOperand(0), Rewritten,
+              ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
+        if (const SCEV *S = Map.lookup(Add))
           return SE.getAddExpr(Expr->getOperand(0), S);
       }
       SmallVector<const SCEV *, 2> Operands;
@@ -15989,7 +16026,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
     }
   };
 
-  if (RewriteMap.empty())
+  if (RewriteMap.empty() && NotEqual.empty())
     return Expr;
 
   SCEVLoopGuardRewriter Rewriter(SE, *this);
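The two hoisted helpers simply round a non-negative constant down or up to a multiple of the divisor; hoisting them out of collectFromBlock lets the rewriter reuse the round-up variant when materializing the lazy NotEqual rewrite (aligning the 1 in UMax(S, 1) to S's known constant multiple). The arithmetic as a standalone sketch, with illustrative values:

  #include <cassert>
  #include <cstdint>

  static uint64_t roundDown(uint64_t V, uint64_t D) { return V - V % D; }
  static uint64_t roundUp(uint64_t V, uint64_t D) {
    uint64_t Rem = V % D;
    return Rem == 0 ? V : V + D - Rem;
  }

  int main() {
    assert(roundDown(17, 5) == 15); // previous value divisible by 5
    assert(roundUp(17, 5) == 20);   // next value divisible by 5
    assert(roundUp(20, 5) == 20);   // already divisible: unchanged
    return 0;
  }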
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9655c88..0a72076 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7695,6 +7695,11 @@ static bool isGuaranteedNotToBeUndefOrPoison(
       }
       if (IsWellDefined)
         return true;
+    } else if (auto *Splat = isa<ShuffleVectorInst>(Opr) ? getSplatValue(Opr)
+                                                         : nullptr) {
+      // For splats we only need to check the value being splatted.
+      if (OpCheck(Splat))
+        return true;
     } else if (all_of(Opr->operands(), OpCheck))
       return true;
   }
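The ValueTracking change exploits the fact that a splat repeats one scalar across every lane, so the vector is undef/poison-free exactly when that scalar is; checking the splatted value once replaces checking all shuffle operands, whose second operand is typically poison in the canonical splat pattern. A toy model of why the single check suffices (stand-ins only, not the LLVM API):

  #include <algorithm>
  #include <array>
  #include <cassert>

  constexpr int Poison = -1; // stand-in for a poison lane

  static std::array<int, 4> makeSplat(int Scalar) {
    return {Scalar, Scalar, Scalar, Scalar};
  }

  int main() {
    int Scalar = 7;
    auto Vec = makeSplat(Scalar);
    bool AllLanesOk =
        std::all_of(Vec.begin(), Vec.end(), [](int L) { return L != Poison; });
    assert(AllLanesOk == (Scalar != Poison)); // one check covers every lane
    return 0;
  }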