Diffstat (limited to 'llvm/lib/Analysis')
-rwxr-xr-x  llvm/lib/Analysis/ConstantFolding.cpp       4
-rw-r--r--  llvm/lib/Analysis/DependenceAnalysis.cpp    294
-rw-r--r--  llvm/lib/Analysis/InstructionSimplify.cpp   30
-rw-r--r--  llvm/lib/Analysis/LazyValueInfo.cpp         10
-rw-r--r--  llvm/lib/Analysis/MLInlineAdvisor.cpp       58
-rw-r--r--  llvm/lib/Analysis/MemoryLocation.cpp        4
-rw-r--r--  llvm/lib/Analysis/ScalarEvolution.cpp       133
-rw-r--r--  llvm/lib/Analysis/ValueTracking.cpp         5
8 files changed, 440 insertions, 98 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index a5ba197..e9e2e7d 100755
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -4056,8 +4056,8 @@ static Constant *ConstantFoldFixedVectorCall(
switch (IntrinsicID) {
case Intrinsic::masked_load: {
auto *SrcPtr = Operands[0];
- auto *Mask = Operands[2];
- auto *Passthru = Operands[3];
+ auto *Mask = Operands[1];
+ auto *Passthru = Operands[2];
Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, FVTy, DL);
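Note on the change above: the mask and passthru now sit one operand earlier, consistent with an operand before them (presumably the explicit alignment argument) having been dropped from the masked-load intrinsic; the same renumbering appears in InstructionSimplify.cpp and MemoryLocation.cpp later in this diff. A minimal sketch of reading the operands at their new positions; maskedLoadPassthruIfDead is a hypothetical helper, not part of this patch, while maskIsAllZeroOrUndef is the VectorUtils helper the patch itself uses further down:

    #include "llvm/Analysis/VectorUtils.h"
    #include "llvm/IR/IntrinsicInst.h"
    #include <cassert>
    using namespace llvm;

    // Hypothetical helper: operands of llvm.masked.load at their new indices.
    static Value *maskedLoadPassthruIfDead(IntrinsicInst *II) {
      assert(II->getIntrinsicID() == Intrinsic::masked_load);
      Value *Mask = II->getArgOperand(1);     // previously operand 2
      Value *Passthru = II->getArgOperand(2); // previously operand 3
      // An all-zero (or undef) mask makes the load a no-op, so the result
      // collapses to the passthru, mirroring simplifyIntrinsic below.
      return maskIsAllZeroOrUndef(Mask) ? Passthru : nullptr;
    }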
diff --git a/llvm/lib/Analysis/DependenceAnalysis.cpp b/llvm/lib/Analysis/DependenceAnalysis.cpp
index 805b682..0a8c2f8 100644
--- a/llvm/lib/Analysis/DependenceAnalysis.cpp
+++ b/llvm/lib/Analysis/DependenceAnalysis.cpp
@@ -128,6 +128,18 @@ static cl::opt<bool> RunSIVRoutinesOnly(
"The purpose is mainly to exclude the influence of those routines "
"in regression tests for SIV routines."));
+// TODO: This flag is disabled by default because it is still under development.
+// Enable it or delete this flag when the feature is ready.
+static cl::opt<bool> EnableMonotonicityCheck(
+ "da-enable-monotonicity-check", cl::init(false), cl::Hidden,
+ cl::desc("Check if the subscripts are monotonic. If it's not, dependence "
+ "is reported as unknown."));
+
+static cl::opt<bool> DumpMonotonicityReport(
+ "da-dump-monotonicity-report", cl::init(false), cl::Hidden,
+ cl::desc(
+ "When printing analysis, dump the results of monotonicity checks."));
+
//===----------------------------------------------------------------------===//
// basics
@@ -177,13 +189,196 @@ void DependenceAnalysisWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.addRequiredTransitive<LoopInfoWrapperPass>();
}
+namespace {
+
+/// The property of monotonicity of a SCEV. To define monotonicity, consider
+/// a SCEV defined within N nested loops. Let i_k denote the iteration number
+/// of the k-th loop. Then we can regard the SCEV as an N-ary function:
+///
+/// F(i_1, i_2, ..., i_N)
+///
+/// The domain of i_k is the closed range [0, BTC_k], where BTC_k is the
+/// backedge-taken count of the k-th loop.
+///
+/// A function F is said to be "monotonically increasing with respect to the
+/// k-th loop" if x <= y implies the following condition:
+///
+/// F(i_1, ..., i_{k-1}, x, i_{k+1}, ..., i_N) <=
+/// F(i_1, ..., i_{k-1}, y, i_{k+1}, ..., i_N)
+///
+/// where i_1, ..., i_{k-1}, i_{k+1}, ..., i_N, x, and y are elements of their
+/// respective domains.
+///
+/// Likewise F is "monotonically decreasing with respect to the k-th loop"
+/// if x <= y implies
+///
+/// F(i_1, ..., i_{k-1}, x, i_{k+1}, ..., i_N) >=
+/// F(i_1, ..., i_{k-1}, y, i_{k+1}, ..., i_N)
+///
+/// A function F that is monotonically increasing or decreasing with respect to
+/// the k-th loop is simply called "monotonic with respect to the k-th loop".
+///
+/// A function F is said to be "multivariate monotonic" when it is monotonic
+/// with respect to all of the N loops.
+///
+/// Since integer comparison can be either signed or unsigned, we need to
+/// distinguish monotonicity in the signed sense from that in the unsigned
+/// sense. Note that the inequality "x <= y" merely indicates loop progression
+/// and is not affected by the difference between signed and unsigned order.
+///
+/// Currently we only consider monotonicity in a signed sense.
+enum class SCEVMonotonicityType {
+ /// We don't know anything about the monotonicity of the SCEV.
+ Unknown,
+
+ /// The SCEV is loop-invariant with respect to the outermost loop. In other
+ /// words, the function F corresponding to the SCEV is a constant function.
+ Invariant,
+
+ /// The function F corresponding to the SCEV is multivariate monotonic in a
+  /// signed sense. Note that a multivariate monotonic function may also be a
+  /// constant function, since the order employed in the definition of
+  /// monotonicity is not strict.
+ MultivariateSignedMonotonic,
+};
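An illustrative instance of these definitions (not taken from the patch): with N = 2, consider

    F(i_1, i_2) = 4*i_1 + i_2,  where i_1 in [0, BTC_1] and i_2 in [0, BTC_2]

Fixing i_2 and increasing i_1 never decreases F, and likewise with the roles swapped, so F is monotonically increasing with respect to both loops and therefore multivariate monotonic, provided the arithmetic does not wrap in the signed sense.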
+
+struct SCEVMonotonicity {
+ SCEVMonotonicity(SCEVMonotonicityType Type,
+ const SCEV *FailurePoint = nullptr);
+
+ SCEVMonotonicityType getType() const { return Type; }
+
+ const SCEV *getFailurePoint() const { return FailurePoint; }
+
+ bool isUnknown() const { return Type == SCEVMonotonicityType::Unknown; }
+
+ void print(raw_ostream &OS, unsigned Depth) const;
+
+private:
+ SCEVMonotonicityType Type;
+
+  /// The subexpression that caused Unknown. Mainly for debugging purposes.
+ const SCEV *FailurePoint;
+};
+
+/// Check the monotonicity of a SCEV. Since dependence tests (SIV, MIV, etc.)
+/// assume that subscript expressions are (multivariate) monotonic, we need to
+/// verify this property before applying those tests. Violating this assumption
+/// may cause them to produce incorrect results.
+struct SCEVMonotonicityChecker
+ : public SCEVVisitor<SCEVMonotonicityChecker, SCEVMonotonicity> {
+
+ SCEVMonotonicityChecker(ScalarEvolution *SE) : SE(SE) {}
+
+  /// Check the monotonicity of \p Expr. \p Expr must be of integer type. If \p
+ /// OutermostLoop is not null, \p Expr must be defined in \p OutermostLoop or
+ /// one of its nested loops.
+ SCEVMonotonicity checkMonotonicity(const SCEV *Expr,
+ const Loop *OutermostLoop);
+
+private:
+ ScalarEvolution *SE;
+
+ /// The outermost loop that DA is analyzing.
+ const Loop *OutermostLoop;
+
+ /// A helper to classify \p Expr as either Invariant or Unknown.
+ SCEVMonotonicity invariantOrUnknown(const SCEV *Expr);
+
+ /// Return true if \p Expr is loop-invariant with respect to the outermost
+ /// loop.
+ bool isLoopInvariant(const SCEV *Expr) const;
+
+ /// A helper to create an Unknown SCEVMonotonicity.
+ SCEVMonotonicity createUnknown(const SCEV *FailurePoint) {
+ return SCEVMonotonicity(SCEVMonotonicityType::Unknown, FailurePoint);
+ }
+
+ SCEVMonotonicity visitAddRecExpr(const SCEVAddRecExpr *Expr);
+
+ SCEVMonotonicity visitConstant(const SCEVConstant *) {
+ return SCEVMonotonicity(SCEVMonotonicityType::Invariant);
+ }
+ SCEVMonotonicity visitVScale(const SCEVVScale *) {
+ return SCEVMonotonicity(SCEVMonotonicityType::Invariant);
+ }
+
+ // TODO: Handle more cases.
+ SCEVMonotonicity visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitAddExpr(const SCEVAddExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitMulExpr(const SCEVMulExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitTruncateExpr(const SCEVTruncateExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitUDivExpr(const SCEVUDivExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitSMaxExpr(const SCEVSMaxExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitUMaxExpr(const SCEVUMaxExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitSMinExpr(const SCEVSMinExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitUMinExpr(const SCEVUMinExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitUnknown(const SCEVUnknown *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+ SCEVMonotonicity visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
+ return invariantOrUnknown(Expr);
+ }
+
+ friend struct SCEVVisitor<SCEVMonotonicityChecker, SCEVMonotonicity>;
+};
+
+} // anonymous namespace
+
// Used to test the dependence analyzer.
// Looks through the function, noting instructions that may access memory.
// Calls depends() on every possible pair and prints out the result.
// Ignores all other instructions.
static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA,
- ScalarEvolution &SE, bool NormalizeResults) {
+ ScalarEvolution &SE, LoopInfo &LI,
+ bool NormalizeResults) {
auto *F = DA->getFunction();
+
+ if (DumpMonotonicityReport) {
+ SCEVMonotonicityChecker Checker(&SE);
+ OS << "Monotonicity check:\n";
+ for (Instruction &Inst : instructions(F)) {
+ if (!isa<LoadInst>(Inst) && !isa<StoreInst>(Inst))
+ continue;
+ Value *Ptr = getLoadStorePointerOperand(&Inst);
+ const Loop *L = LI.getLoopFor(Inst.getParent());
+ const SCEV *PtrSCEV = SE.getSCEVAtScope(Ptr, L);
+ const SCEV *AccessFn = SE.removePointerBase(PtrSCEV);
+ SCEVMonotonicity Mon = Checker.checkMonotonicity(AccessFn, L);
+ OS.indent(2) << "Inst: " << Inst << "\n";
+ OS.indent(4) << "Expr: " << *AccessFn << "\n";
+ Mon.print(OS, 4);
+ }
+ OS << "\n";
+ }
+
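For reference, the report printed above has the following shape; the instruction and SCEV strings in this sample are hypothetical and only illustrate the layout produced by the indent/print calls:

    Monotonicity check:
      Inst:   %v = load i32, ptr %arrayidx, align 4
        Expr: {0,+,4}<nuw><nsw><%for.body>
        Monotonicity: MultivariateSignedMonotonic
      Inst:   store i32 %v, ptr %q, align 4
        Expr: %offset
        Monotonicity: Unknown
        Reason: %offset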
for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); SrcI != SrcE;
++SrcI) {
if (SrcI->mayReadOrWriteMemory()) {
@@ -235,7 +430,8 @@ static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA,
void DependenceAnalysisWrapperPass::print(raw_ostream &OS,
const Module *) const {
dumpExampleDependence(
- OS, info.get(), getAnalysis<ScalarEvolutionWrapperPass>().getSE(), false);
+ OS, info.get(), getAnalysis<ScalarEvolutionWrapperPass>().getSE(),
+ getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), false);
}
PreservedAnalyses
@@ -244,7 +440,7 @@ DependenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) {
<< "':\n";
dumpExampleDependence(OS, &FAM.getResult<DependenceAnalysis>(F),
FAM.getResult<ScalarEvolutionAnalysis>(F),
- NormalizeResults);
+ FAM.getResult<LoopAnalysis>(F), NormalizeResults);
return PreservedAnalyses::all();
}
@@ -671,6 +867,81 @@ bool DependenceInfo::intersectConstraints(Constraint *X, const Constraint *Y) {
}
//===----------------------------------------------------------------------===//
+// SCEVMonotonicity
+
+SCEVMonotonicity::SCEVMonotonicity(SCEVMonotonicityType Type,
+ const SCEV *FailurePoint)
+ : Type(Type), FailurePoint(FailurePoint) {
+ assert(
+ ((Type == SCEVMonotonicityType::Unknown) == (FailurePoint != nullptr)) &&
+ "FailurePoint must be provided iff Type is Unknown");
+}
+
+void SCEVMonotonicity::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << "Monotonicity: ";
+ switch (Type) {
+ case SCEVMonotonicityType::Unknown:
+ assert(FailurePoint && "FailurePoint must be provided for Unknown");
+ OS << "Unknown\n";
+ OS.indent(Depth) << "Reason: " << *FailurePoint << "\n";
+ break;
+ case SCEVMonotonicityType::Invariant:
+ OS << "Invariant\n";
+ break;
+ case SCEVMonotonicityType::MultivariateSignedMonotonic:
+ OS << "MultivariateSignedMonotonic\n";
+ break;
+ }
+}
+
+bool SCEVMonotonicityChecker::isLoopInvariant(const SCEV *Expr) const {
+ return !OutermostLoop || SE->isLoopInvariant(Expr, OutermostLoop);
+}
+
+SCEVMonotonicity SCEVMonotonicityChecker::invariantOrUnknown(const SCEV *Expr) {
+ if (isLoopInvariant(Expr))
+ return SCEVMonotonicity(SCEVMonotonicityType::Invariant);
+ return createUnknown(Expr);
+}
+
+SCEVMonotonicity
+SCEVMonotonicityChecker::checkMonotonicity(const SCEV *Expr,
+ const Loop *OutermostLoop) {
+ assert(Expr->getType()->isIntegerTy() && "Expr must be integer type");
+ this->OutermostLoop = OutermostLoop;
+ return visit(Expr);
+}
+
+/// We only care about an affine AddRec at the moment. For an affine AddRec,
+/// the monotonicity can be inferred from its nowrap property. For example, let
+/// X and Y be loop-invariant, and assume Y is non-negative. An AddRec
+/// {X,+.Y}<nsw> implies:
+///
+/// X <=s (X + Y) <=s ((X + Y) + Y) <=s ...
+///
+/// Thus, we can conclude that the AddRec is monotonically increasing with
+/// respect to the associated loop in a signed sense. Similar reasoning
+/// applies when Y is non-positive, leading to a monotonically decreasing
+/// AddRec.
+SCEVMonotonicity
+SCEVMonotonicityChecker::visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ if (!Expr->isAffine() || !Expr->hasNoSignedWrap())
+ return createUnknown(Expr);
+
+ const SCEV *Start = Expr->getStart();
+ const SCEV *Step = Expr->getStepRecurrence(*SE);
+
+ SCEVMonotonicity StartMon = visit(Start);
+ if (StartMon.isUnknown())
+ return StartMon;
+
+ if (!isLoopInvariant(Step))
+ return createUnknown(Expr);
+
+ return SCEVMonotonicity(SCEVMonotonicityType::MultivariateSignedMonotonic);
+}
+
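The rule implemented by visitAddRecExpr can be restated as a standalone sketch. The types and names below are invented for illustration and are not the LLVM API; the point is only the chain of bail-outs:

    #include <cassert>

    enum class Monotonicity { Unknown, Invariant, MultivariateSignedMonotonic };

    // Facts visitAddRecExpr consults about a recurrence {Start,+,Step}.
    struct AffineAddRec {
      bool IsAffine;            // exactly one step operand
      bool HasNoSignedWrap;     // the <nsw> flag on the recurrence
      Monotonicity StartMon;    // result of visiting the start expression
      bool StepIsLoopInvariant; // step invariant w.r.t. the outermost loop
    };

    static Monotonicity classify(const AffineAddRec &R) {
      if (!R.IsAffine || !R.HasNoSignedWrap || !R.StepIsLoopInvariant ||
          R.StartMon == Monotonicity::Unknown)
        return Monotonicity::Unknown;
      return Monotonicity::MultivariateSignedMonotonic;
    }

    int main() {
      // {0,+,4}<nsw> with an invariant step: monotonic.
      assert(classify({true, true, Monotonicity::Invariant, true}) ==
             Monotonicity::MultivariateSignedMonotonic);
      // Without <nsw> the increments may wrap, so nothing can be concluded.
      assert(classify({true, false, Monotonicity::Invariant, true}) ==
             Monotonicity::Unknown);
      return 0;
    }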
+//===----------------------------------------------------------------------===//
// DependenceInfo methods
// For debugging purposes. Dumps a dependence to OS.
@@ -3488,10 +3759,19 @@ bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst,
// resize Pair to contain as many pairs of subscripts as the delinearization
// has found, and then initialize the pairs following the delinearization.
Pair.resize(Size);
+ SCEVMonotonicityChecker MonChecker(SE);
+ const Loop *OutermostLoop = SrcLoop ? SrcLoop->getOutermostLoop() : nullptr;
for (int I = 0; I < Size; ++I) {
Pair[I].Src = SrcSubscripts[I];
Pair[I].Dst = DstSubscripts[I];
unifySubscriptType(&Pair[I]);
+
+ if (EnableMonotonicityCheck) {
+ if (MonChecker.checkMonotonicity(Pair[I].Src, OutermostLoop).isUnknown())
+ return false;
+ if (MonChecker.checkMonotonicity(Pair[I].Dst, OutermostLoop).isUnknown())
+ return false;
+ }
}
return true;
@@ -3824,6 +4104,14 @@ DependenceInfo::depends(Instruction *Src, Instruction *Dst,
Pair[0].Src = SrcEv;
Pair[0].Dst = DstEv;
+ SCEVMonotonicityChecker MonChecker(SE);
+ const Loop *OutermostLoop = SrcLoop ? SrcLoop->getOutermostLoop() : nullptr;
+ if (EnableMonotonicityCheck)
+ if (MonChecker.checkMonotonicity(Pair[0].Src, OutermostLoop).isUnknown() ||
+ MonChecker.checkMonotonicity(Pair[0].Dst, OutermostLoop).isUnknown())
+ return std::make_unique<Dependence>(Src, Dst,
+ SCEVUnionPredicate(Assume, *SE));
+
if (Delinearize) {
if (tryDelinearize(Src, Dst, Pair)) {
LLVM_DEBUG(dbgs() << " delinearized\n");
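Both new options are hidden cl::opts, so they are exercised through opt. A plausible invocation, assuming the dependence-analysis printer keeps its usual print<da> registration:

    opt -passes='print<da>' -da-enable-monotonicity-check \
        -da-dump-monotonicity-report -disable-output test.ll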
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index e08ef60..8da51d0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5106,32 +5106,33 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
return Ptr;
// The following transforms are only safe if the ptrtoint cast
- // doesn't truncate the pointers.
- if (Indices[0]->getType()->getScalarSizeInBits() ==
- Q.DL.getPointerSizeInBits(AS)) {
+ // doesn't truncate the address of the pointers. The non-address bits
+ // must be the same, as the underlying objects are the same.
+ if (Indices[0]->getType()->getScalarSizeInBits() >=
+ Q.DL.getAddressSizeInBits(AS)) {
auto CanSimplify = [GEPTy, &P, Ptr]() -> bool {
return P->getType() == GEPTy &&
getUnderlyingObject(P) == getUnderlyingObject(Ptr);
};
// getelementptr V, (sub P, V) -> P if P points to a type of size 1.
if (TyAllocSize == 1 &&
- match(Indices[0],
- m_Sub(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Specific(Ptr)))) &&
+ match(Indices[0], m_Sub(m_PtrToIntOrAddr(m_Value(P)),
+ m_PtrToIntOrAddr(m_Specific(Ptr)))) &&
CanSimplify())
return P;
// getelementptr V, (ashr (sub P, V), C) -> P if P points to a type of
// size 1 << C.
- if (match(Indices[0], m_AShr(m_Sub(m_PtrToInt(m_Value(P)),
- m_PtrToInt(m_Specific(Ptr))),
+ if (match(Indices[0], m_AShr(m_Sub(m_PtrToIntOrAddr(m_Value(P)),
+ m_PtrToIntOrAddr(m_Specific(Ptr))),
m_ConstantInt(C))) &&
TyAllocSize == 1ULL << C && CanSimplify())
return P;
// getelementptr V, (sdiv (sub P, V), C) -> P if P points to a type of
// size C.
- if (match(Indices[0], m_SDiv(m_Sub(m_PtrToInt(m_Value(P)),
- m_PtrToInt(m_Specific(Ptr))),
+ if (match(Indices[0], m_SDiv(m_Sub(m_PtrToIntOrAddr(m_Value(P)),
+ m_PtrToIntOrAddr(m_Specific(Ptr))),
m_SpecificInt(TyAllocSize))) &&
CanSimplify())
return P;
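The relaxed guard above only requires that the index type be wide enough to hold the address bits, rather than exactly matching the full pointer width. A toy standalone check of the old condition against the new one, using made-up target parameters for a pointer representation that carries non-address bits:

    #include <cassert>

    int main() {
      unsigned IndexBits = 64;    // Indices[0]->getType()->getScalarSizeInBits()
      unsigned PointerBits = 128; // hypothetical fat-pointer representation
      unsigned AddressBits = 64;  // what DL.getAddressSizeInBits(AS) reports

      bool OldCondition = (IndexBits == PointerBits); // used to reject the fold
      bool NewCondition = (IndexBits >= AddressBits); // no address truncation
      assert(!OldCondition && NewCondition);
      return 0;
    }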
@@ -5440,9 +5441,10 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
// ptrtoint (ptradd (Ptr, X - ptrtoint(Ptr))) -> X
Value *Ptr, *X;
- if (CastOpc == Instruction::PtrToInt &&
- match(Op, m_PtrAdd(m_Value(Ptr),
- m_Sub(m_Value(X), m_PtrToInt(m_Deferred(Ptr))))) &&
+ if ((CastOpc == Instruction::PtrToInt || CastOpc == Instruction::PtrToAddr) &&
+ match(Op,
+ m_PtrAdd(m_Value(Ptr),
+ m_Sub(m_Value(X), m_PtrToIntOrAddr(m_Deferred(Ptr))))) &&
X->getType() == Ty && Ty == Q.DL.getIndexType(Ptr->getType()))
return X;
@@ -6987,8 +6989,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
switch (IID) {
case Intrinsic::masked_load:
case Intrinsic::masked_gather: {
- Value *MaskArg = Args[2];
- Value *PassthruArg = Args[3];
+ Value *MaskArg = Args[1];
+ Value *PassthruArg = Args[2];
// If the mask is all zeros or undef, the "passthru" argument is the result.
if (maskIsAllZeroOrUndef(MaskArg))
return PassthruArg;
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 0e5bc48..df75999 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -947,9 +947,8 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
/*UseBlockValue*/ false));
}
- ValueLatticeElement Result = TrueVal;
- Result.mergeIn(FalseVal);
- return Result;
+ TrueVal.mergeIn(FalseVal);
+ return TrueVal;
}
std::optional<ConstantRange>
@@ -1778,9 +1777,8 @@ ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
assert(OptResult && "Value not available after solving");
}
- ValueLatticeElement Result = *OptResult;
- LLVM_DEBUG(dbgs() << " Result = " << Result << "\n");
- return Result;
+ LLVM_DEBUG(dbgs() << " Result = " << *OptResult << "\n");
+ return *OptResult;
}
ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) {
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 1d1a5560..9a5ae2a 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -324,32 +324,44 @@ void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
FAM.invalidate(*Caller, PA);
}
Advice.updateCachedCallerFPI(FAM);
- int64_t IRSizeAfter =
- getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
- CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+ if (Caller == Callee) {
+ assert(!CalleeWasDeleted);
+    // We double-counted CallerAndCalleeEdges, since the caller and the
+    // callee are the same.
+ assert(Advice.CallerAndCalleeEdges % 2 == 0);
+ CurrentIRSize += getIRSize(*Caller) - Advice.CallerIRSize;
+ EdgeCount += getCachedFPI(*Caller).DirectCallsToDefinedFunctions -
+ Advice.CallerAndCalleeEdges / 2;
+ // The NodeCount would stay the same.
+ } else {
+ int64_t IRSizeAfter =
+ getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
+ CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+
+ // We can delta-update module-wide features. We know the inlining only
+ // changed the caller, and maybe the callee (by deleting the latter). Nodes
+ // are simple to update. For edges, we 'forget' the edges that the caller
+ // and callee used to have before inlining, and add back what they currently
+ // have together.
+ int64_t NewCallerAndCalleeEdges =
+ getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
+
+ // A dead function's node is not actually removed from the call graph until
+ // the end of the call graph walk, but the node no longer belongs to any
+ // valid SCC.
+ if (CalleeWasDeleted) {
+ --NodeCount;
+ NodesInLastSCC.erase(CG.lookup(*Callee));
+ DeadFunctions.insert(Callee);
+ } else {
+ NewCallerAndCalleeEdges +=
+ getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
+ }
+ EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
+ }
if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
ForceStop = true;
- // We can delta-update module-wide features. We know the inlining only changed
- // the caller, and maybe the callee (by deleting the latter).
- // Nodes are simple to update.
- // For edges, we 'forget' the edges that the caller and callee used to have
- // before inlining, and add back what they currently have together.
- int64_t NewCallerAndCalleeEdges =
- getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
-
- // A dead function's node is not actually removed from the call graph until
- // the end of the call graph walk, but the node no longer belongs to any valid
- // SCC.
- if (CalleeWasDeleted) {
- --NodeCount;
- NodesInLastSCC.erase(CG.lookup(*Callee));
- DeadFunctions.insert(Callee);
- } else {
- NewCallerAndCalleeEdges +=
- getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
- }
- EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
}
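For the Caller == Callee path above, dividing CallerAndCalleeEdges by two undoes the double count. A toy standalone check of that bookkeeping with hypothetical numbers:

    #include <cassert>

    int main() {
      // Before inlining a recursive call F -> F, the advice recorded the
      // caller's edges plus the callee's edges; with Caller == Callee that
      // counts every call edge of F twice.
      int EdgesOfFBefore = 3;
      int CallerAndCalleeEdges = 2 * EdgesOfFBefore; // the double count
      int EdgesOfFAfter = 4; // DirectCallsToDefinedFunctions after inlining

      int EdgeCountDelta = EdgesOfFAfter - CallerAndCalleeEdges / 2;
      assert(EdgeCountDelta == 1); // net effect: one extra direct call edge
      return 0;
    }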
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index dcc5117..1c5f08e 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -245,7 +245,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
assert(ArgIdx == 0 && "Invalid argument index");
auto *Ty = cast<VectorType>(II->getType());
- if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty))
+ if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(1), Ty))
return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags);
return MemoryLocation(
@@ -255,7 +255,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
assert(ArgIdx == 1 && "Invalid argument index");
auto *Ty = cast<VectorType>(II->getArgOperand(0)->getType());
- if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(3), Ty))
+ if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty))
return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags);
return MemoryLocation(
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e06b095..6f7dd79 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15473,6 +15473,38 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
}
}
+// Return a new SCEV that modifies \p Expr to the closest number divisible by
+// \p Divisor that is less than or equal to \p Expr. For now, only constant
+// \p Expr is handled.
+static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
+ const APInt &DivisorVal,
+ ScalarEvolution &SE) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
+ return Expr;
+ APInt Rem = ExprVal->urem(DivisorVal);
+ // return the SCEV: Expr - Expr % Divisor
+ return SE.getConstant(*ExprVal - Rem);
+}
+
+// Return a new SCEV that modifies \p Expr to the closest number divisible by
+// \p Divisor that is greater than or equal to \p Expr. For now, only constant
+// \p Expr is handled.
+static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
+ const APInt &DivisorVal,
+ ScalarEvolution &SE) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
+ return Expr;
+ APInt Rem = ExprVal->urem(DivisorVal);
+ if (Rem.isZero())
+ return Expr;
+ // return the SCEV: Expr + Divisor - Expr % Divisor
+ return SE.getConstant(*ExprVal + DivisorVal - Rem);
+}
+
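The arithmetic of the two hoisted helpers, as a standalone sketch restricted to the constant, non-negative case they actually handle:

    #include <cassert>
    #include <cstdint>

    // Closest multiple of D at or below X, and at or above X.
    static uint64_t previousMultiple(uint64_t X, uint64_t D) { return X - X % D; }
    static uint64_t nextMultiple(uint64_t X, uint64_t D) {
      uint64_t Rem = X % D;
      return Rem == 0 ? X : X + (D - Rem);
    }

    int main() {
      assert(previousMultiple(10, 4) == 8);
      assert(nextMultiple(10, 4) == 12);
      // The NE-guard rewrite below relies on this to sharpen "S != 0": when S
      // is also known to be a multiple of 8, S >= nextMultiple(1, 8) == 8.
      assert(nextMultiple(1, 8) == 8);
      return 0;
    }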
void ScalarEvolution::LoopGuards::collectFromBlock(
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
@@ -15540,36 +15572,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
match(LHS, m_scev_APInt(C)) && C->isNonNegative();
};
- // Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and greater or equal than Expr. For now, only handle constant
- // Expr.
- auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
- const APInt &DivisorVal) {
- const APInt *ExprVal;
- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
- DivisorVal.isNonPositive())
- return Expr;
- APInt Rem = ExprVal->urem(DivisorVal);
- if (Rem.isZero())
- return Expr;
- // return the SCEV: Expr + Divisor - Expr % Divisor
- return SE.getConstant(*ExprVal + DivisorVal - Rem);
- };
-
- // Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and less or equal than Expr. For now, only handle constant
- // Expr.
- auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
- const APInt &DivisorVal) {
- const APInt *ExprVal;
- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
- DivisorVal.isNonPositive())
- return Expr;
- APInt Rem = ExprVal->urem(DivisorVal);
- // return the SCEV: Expr - Expr % Divisor
- return SE.getConstant(*ExprVal - Rem);
- };
-
// Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
// recursively. This is done by aligning up/down the constant value to the
// Divisor.
@@ -15591,8 +15593,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
assert(SE.isKnownNonNegative(MinMaxLHS) &&
"Expected non-negative operand!");
auto *DivisibleExpr =
- IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
- : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal);
+ IsMin
+ ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
+ : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
SmallVector<const SCEV *> Ops = {
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
return SE.getMinMaxExpr(SCTy, Ops);
@@ -15669,21 +15672,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
[[fallthrough]];
case CmpInst::ICMP_SLT: {
RHS = SE.getMinusSCEV(RHS, One);
- RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
}
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_SGT:
RHS = SE.getAddExpr(RHS, One);
- RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
case CmpInst::ICMP_ULE:
case CmpInst::ICMP_SLE:
- RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
case CmpInst::ICMP_UGE:
case CmpInst::ICMP_SGE:
- RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
default:
break;
@@ -15737,22 +15740,29 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
case CmpInst::ICMP_NE:
if (match(RHS, m_scev_Zero())) {
const SCEV *OneAlignedUp =
- GetNextSCEVDividesByDivisor(One, DividesBy);
+ getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
} else {
+          // LHS != RHS can be rewritten as LHS - RHS -> UMax(1, LHS - RHS),
+          // but creating the subtraction eagerly is expensive. Track the
+          // inequalities in a separate set, and materialize the rewrite
+          // lazily when encountering a suitable subtraction while rewriting.
if (LHS->getType()->isPointerTy()) {
LHS = SE.getLosslessPtrToIntExpr(LHS);
RHS = SE.getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
break;
}
- auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
- const SCEV *Sub = SE.getMinusSCEV(A, B);
- AddRewrite(Sub, Sub,
- SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
- };
- AddSubRewrite(LHS, RHS);
- AddSubRewrite(RHS, LHS);
+ const SCEVConstant *C;
+ const SCEV *A, *B;
+ if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
+ match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
+ RHS = A;
+ LHS = B;
+ }
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ Guards.NotEqual.insert({LHS, RHS});
continue;
}
break;
@@ -15886,13 +15896,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
class SCEVLoopGuardRewriter
: public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
+ const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
const ScalarEvolution::LoopGuards &Guards)
- : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
+ : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
+ NotEqual(Guards.NotEqual) {
if (Guards.PreserveNUW)
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
if (Guards.PreserveNSW)
@@ -15947,14 +15959,39 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+ // Helper to check if S is a subtraction (A - B) where A != B, and if so,
+ // return UMax(S, 1).
+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
+ const SCEV *LHS, *RHS;
+ if (MatchBinarySub(S, LHS, RHS)) {
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ if (NotEqual.contains({LHS, RHS})) {
+ const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
+ SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
+ return SE.getUMaxExpr(OneAlignedUp, S);
+ }
+ }
+ return nullptr;
+ };
+
+ // Check if Expr itself is a subtraction pattern with guard info.
+ if (const SCEV *Rewritten = RewriteSubtraction(Expr))
+ return Rewritten;
+
// Trip count expressions sometimes consist of adding 3 operands, i.e.
// (Const + A + B). There may be guard info for A + B, and if so, apply
// it.
// TODO: Could more generally apply guards to Add sub-expressions.
if (isa<SCEVConstant>(Expr->getOperand(0)) &&
Expr->getNumOperands() == 3) {
- if (const SCEV *S = Map.lookup(
- SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
+ const SCEV *Add =
+ SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
+ if (const SCEV *Rewritten = RewriteSubtraction(Add))
+ return SE.getAddExpr(
+ Expr->getOperand(0), Rewritten,
+ ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
+ if (const SCEV *S = Map.lookup(Add))
return SE.getAddExpr(Expr->getOperand(0), S);
}
SmallVector<const SCEV *, 2> Operands;
@@ -15989,7 +16026,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
};
- if (RewriteMap.empty())
+ if (RewriteMap.empty() && NotEqual.empty())
return Expr;
SCEVLoopGuardRewriter Rewriter(SE, *this);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9655c88..0a72076 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7695,6 +7695,11 @@ static bool isGuaranteedNotToBeUndefOrPoison(
}
if (IsWellDefined)
return true;
+ } else if (auto *Splat = isa<ShuffleVectorInst>(Opr) ? getSplatValue(Opr)
+ : nullptr) {
+ // For splats we only need to check the value being splatted.
+ if (OpCheck(Splat))
+ return true;
} else if (all_of(Opr->operands(), OpCheck))
return true;
}
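The splat case added above checks a single scalar instead of every operand lane. A hedged sketch of the same idea in isolation; getSplatValue is the VectorUtils helper the patch calls, while splatOperandWellDefined is an invented name:

    #include "llvm/ADT/STLFunctionalExtras.h"
    #include "llvm/Analysis/VectorUtils.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Every lane of a splat shuffle is the splatted scalar, so the shuffle is
    // guaranteed not-undef/not-poison whenever that one scalar is.
    static bool splatOperandWellDefined(const Value *V,
                                        function_ref<bool(const Value *)> OpCheck) {
      if (isa<ShuffleVectorInst>(V))
        if (const Value *Splat = getSplatValue(V))
          return OpCheck(Splat);
      return false;
    }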