Diffstat (limited to 'llvm/lib/Analysis')
-rwxr-xr-x  llvm/lib/Analysis/ConstantFolding.cpp          60
-rw-r--r--  llvm/lib/Analysis/DXILResource.cpp             47
-rw-r--r--  llvm/lib/Analysis/InstructionSimplify.cpp      11
-rw-r--r--  llvm/lib/Analysis/LazyValueInfo.cpp            10
-rw-r--r--  llvm/lib/Analysis/LoopInfo.cpp                  4
-rw-r--r--  llvm/lib/Analysis/MLInlineAdvisor.cpp          76
-rw-r--r--  llvm/lib/Analysis/MemoryLocation.cpp            4
-rw-r--r--  llvm/lib/Analysis/ScalarEvolution.cpp         273
-rw-r--r--  llvm/lib/Analysis/StaticDataProfileInfo.cpp    70
-rw-r--r--  llvm/lib/Analysis/ValueTracking.cpp             5
10 files changed, 346 insertions, 214 deletions
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 45c889c..e9e2e7d 100755
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -2177,16 +2177,13 @@ Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
return PoisonValue::get(VT->getElementType());
// TODO: Handle undef.
- if (!isa<ConstantVector>(Op) && !isa<ConstantDataVector>(Op))
- return nullptr;
-
- auto *EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(0U));
+ auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U));
if (!EltC)
return nullptr;
APInt Acc = EltC->getValue();
for (unsigned I = 1, E = VT->getNumElements(); I != E; I++) {
- if (!(EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(I))))
+ if (!(EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(I))))
return nullptr;
const APInt &X = EltC->getValue();
switch (IID) {
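
Note on the dyn_cast_or_null change above: Constant::getAggregateElement returns nullptr when it cannot produce an element (for example on scalable vectors), and plain dyn_cast asserts on a null pointer, so the null-tolerant cast is what makes the early bail-out safe. A minimal standalone sketch of the idiom (hypothetical helper, not part of this patch):

    #include "llvm/IR/Constants.h"
    #include "llvm/Support/Casting.h"
    #include <optional>
    using namespace llvm;

    // Sum the elements of a constant vector; bail with std::nullopt if any
    // element is unavailable (getAggregateElement returned nullptr) or is
    // not a ConstantInt.
    static std::optional<APInt> sumElements(Constant *Vec, unsigned NumElts) {
      auto *Elt = dyn_cast_or_null<ConstantInt>(Vec->getAggregateElement(0U));
      if (!Elt)
        return std::nullopt;
      APInt Acc = Elt->getValue();
      for (unsigned I = 1; I != NumElts; ++I) {
        Elt = dyn_cast_or_null<ConstantInt>(Vec->getAggregateElement(I));
        if (!Elt)
          return std::nullopt;
        Acc += Elt->getValue();
      }
      return Acc;
    }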
@@ -3059,35 +3056,25 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
Val = Val | Val << 1;
return ConstantInt::get(Ty, Val);
}
-
- default:
- return nullptr;
}
}
- switch (IntrinsicID) {
- default: break;
- case Intrinsic::vector_reduce_add:
- case Intrinsic::vector_reduce_mul:
- case Intrinsic::vector_reduce_and:
- case Intrinsic::vector_reduce_or:
- case Intrinsic::vector_reduce_xor:
- case Intrinsic::vector_reduce_smin:
- case Intrinsic::vector_reduce_smax:
- case Intrinsic::vector_reduce_umin:
- case Intrinsic::vector_reduce_umax:
- if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0]))
- return C;
- break;
- }
-
- // Support ConstantVector in case we have an Undef in the top.
- if (isa<ConstantVector>(Operands[0]) ||
- isa<ConstantDataVector>(Operands[0]) ||
- isa<ConstantAggregateZero>(Operands[0])) {
+ if (Operands[0]->getType()->isVectorTy()) {
auto *Op = cast<Constant>(Operands[0]);
switch (IntrinsicID) {
default: break;
+ case Intrinsic::vector_reduce_add:
+ case Intrinsic::vector_reduce_mul:
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_xor:
+ case Intrinsic::vector_reduce_smin:
+ case Intrinsic::vector_reduce_smax:
+ case Intrinsic::vector_reduce_umin:
+ case Intrinsic::vector_reduce_umax:
+ if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0]))
+ return C;
+ break;
case Intrinsic::x86_sse_cvtss2si:
case Intrinsic::x86_sse_cvtss2si64:
case Intrinsic::x86_sse2_cvtsd2si:
@@ -3116,10 +3103,15 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
case Intrinsic::wasm_alltrue:
// Check each element individually
unsigned E = cast<FixedVectorType>(Op->getType())->getNumElements();
- for (unsigned I = 0; I != E; ++I)
- if (Constant *Elt = Op->getAggregateElement(I))
- if (Elt->isZeroValue())
- return ConstantInt::get(Ty, 0);
+ for (unsigned I = 0; I != E; ++I) {
+ Constant *Elt = Op->getAggregateElement(I);
+ // Return false as soon as we find a non-true element.
+ if (Elt && Elt->isZeroValue())
+ return ConstantInt::get(Ty, 0);
+ // Bail as soon as we find an element we cannot prove to be true.
+ if (!Elt || !isa<ConstantInt>(Elt))
+ return nullptr;
+ }
return ConstantInt::get(Ty, 1);
}
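
The rewritten alltrue loop fixes a subtle bug: the old code returned 1 for any vector in which it could not find a provably-zero element, silently treating elements it could not reason about (e.g. ConstantExprs) as true. The new control flow, reduced to a standalone sketch (hypothetical helper, same null-returning getAggregateElement contract):

    // Fold an "all lanes true" check over a fixed-width constant vector.
    // Returns 0 on the first provably-zero element, bails (nullptr) on any
    // element that is not a ConstantInt, and folds to 1 only when every
    // element is a known non-zero integer.
    static Constant *foldAllTrue(Constant *Op, Type *Ty, unsigned NumElts) {
      for (unsigned I = 0; I != NumElts; ++I) {
        Constant *Elt = Op->getAggregateElement(I);
        if (Elt && Elt->isZeroValue())
          return ConstantInt::get(Ty, 0);
        if (!Elt || !isa<ConstantInt>(Elt))
          return nullptr;
      }
      return ConstantInt::get(Ty, 1);
    }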
@@ -4064,8 +4056,8 @@ static Constant *ConstantFoldFixedVectorCall(
switch (IntrinsicID) {
case Intrinsic::masked_load: {
auto *SrcPtr = Operands[0];
- auto *Mask = Operands[2];
- auto *Passthru = Operands[3];
+ auto *Mask = Operands[1];
+ auto *Passthru = Operands[2];
Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, FVTy, DL);
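
The mask and passthru indices shift down by one here; the same reindexing recurs in InstructionSimplify.cpp and MemoryLocation.cpp below, consistent with the alignment immediate being dropped from the masked load/gather operand list. Assuming that new three-operand layout:

    // llvm.masked.load operand layout assumed by this patch:
    //   Operands[0] = pointer to load from
    //   Operands[1] = <N x i1> lane mask   (was Operands[2])
    //   Operands[2] = passthru vector      (was Operands[3])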
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index b78cc03e..f9bf092 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -281,6 +281,38 @@ static StructType *getOrCreateElementStruct(Type *ElemType, StringRef Name) {
return StructType::create(ElemType, Name);
}
+static Type *getTypeWithoutPadding(Type *Ty) {
+ // Recursively remove padding from structures.
+ if (auto *ST = dyn_cast<StructType>(Ty)) {
+ LLVMContext &Ctx = Ty->getContext();
+ SmallVector<Type *> ElementTypes;
+ ElementTypes.reserve(ST->getNumElements());
+ for (Type *ElTy : ST->elements()) {
+ if (isa<PaddingExtType>(ElTy))
+ continue;
+ ElementTypes.push_back(getTypeWithoutPadding(ElTy));
+ }
+
+ // Handle explicitly padded cbuffer arrays like { [ n x paddedty ], ty }
+ if (ElementTypes.size() == 2)
+ if (auto *AT = dyn_cast<ArrayType>(ElementTypes[0]))
+ if (ElementTypes[1] == AT->getElementType())
+ return ArrayType::get(ElementTypes[1], AT->getNumElements() + 1);
+
+ // If we only have a single element, don't wrap it in a struct.
+ if (ElementTypes.size() == 1)
+ return ElementTypes[0];
+
+ return StructType::get(Ctx, ElementTypes, /*IsPacked=*/false);
+ }
+ // Arrays just need to have their element type adjusted.
+ if (auto *AT = dyn_cast<ArrayType>(Ty))
+ return ArrayType::get(getTypeWithoutPadding(AT->getElementType()),
+ AT->getNumElements());
+ // Anything else should be good as is.
+ return Ty;
+}
+
StructType *ResourceTypeInfo::createElementStruct(StringRef CBufferName) {
SmallString<64> TypeName;
@@ -334,14 +366,21 @@ StructType *ResourceTypeInfo::createElementStruct(StringRef CBufferName) {
}
case ResourceKind::CBuffer: {
auto *RTy = cast<CBufferExtType>(HandleTy);
- LayoutExtType *LayoutType = cast<LayoutExtType>(RTy->getResourceType());
- StructType *Ty = cast<StructType>(LayoutType->getWrappedType());
SmallString<64> Name = getResourceKindName(Kind);
if (!CBufferName.empty()) {
Name.append(".");
Name.append(CBufferName);
}
- return StructType::create(Ty->elements(), Name);
+
+ // TODO: Remove this when we update the frontend to use explicit padding.
+ if (LayoutExtType *LayoutType =
+ dyn_cast<LayoutExtType>(RTy->getResourceType())) {
+ StructType *Ty = cast<StructType>(LayoutType->getWrappedType());
+ return StructType::create(Ty->elements(), Name);
+ }
+
+ return getOrCreateElementStruct(
+ getTypeWithoutPadding(RTy->getResourceType()), Name);
}
case ResourceKind::Sampler: {
auto *RTy = cast<SamplerExtType>(HandleTy);
@@ -454,10 +493,10 @@ uint32_t ResourceTypeInfo::getCBufferSize(const DataLayout &DL) const {
Type *ElTy = cast<CBufferExtType>(HandleTy)->getResourceType();
+ // TODO: Remove this when we update the frontend to use explicit padding.
if (auto *LayoutTy = dyn_cast<LayoutExtType>(ElTy))
return LayoutTy->getSize();
- // TODO: What should we do with unannotated arrays?
return DL.getTypeAllocSize(ElTy);
}
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index e08ef60..dc813f6 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -5440,9 +5440,10 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty,
// ptrtoint (ptradd (Ptr, X - ptrtoint(Ptr))) -> X
Value *Ptr, *X;
- if (CastOpc == Instruction::PtrToInt &&
- match(Op, m_PtrAdd(m_Value(Ptr),
- m_Sub(m_Value(X), m_PtrToInt(m_Deferred(Ptr))))) &&
+ if ((CastOpc == Instruction::PtrToInt || CastOpc == Instruction::PtrToAddr) &&
+ match(Op,
+ m_PtrAdd(m_Value(Ptr),
+ m_Sub(m_Value(X), m_PtrToIntOrAddr(m_Deferred(Ptr))))) &&
X->getType() == Ty && Ty == Q.DL.getIndexType(Ptr->getType()))
return X;
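
The guard now accepts ptrtoaddr as well: a round-trip like ptrtoint/ptrtoaddr of (ptradd Ptr, (X - ptrtoint Ptr)) folds back to X when the types line up. A minimal sketch of the matcher shape using llvm/IR/PatternMatch.h (hypothetical helper; the real code additionally checks X's type against the pointer's index type):

    #include "llvm/IR/PatternMatch.h"
    using namespace llvm;
    using namespace llvm::PatternMatch;

    // Recognize "ptradd(Ptr, X - ptrtoint/ptrtoaddr(Ptr))", the shape whose
    // ptrtoint/ptrtoaddr simplifies to X.
    static Value *matchPtrIntRoundTrip(Value *Op) {
      Value *Ptr, *X;
      if (match(Op, m_PtrAdd(m_Value(Ptr),
                             m_Sub(m_Value(X),
                                   m_PtrToIntOrAddr(m_Deferred(Ptr))))))
        return X;
      return nullptr;
    }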
@@ -6987,8 +6988,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
switch (IID) {
case Intrinsic::masked_load:
case Intrinsic::masked_gather: {
- Value *MaskArg = Args[2];
- Value *PassthruArg = Args[3];
+ Value *MaskArg = Args[1];
+ Value *PassthruArg = Args[2];
// If the mask is all zeros or undef, the "passthru" argument is the result.
if (maskIsAllZeroOrUndef(MaskArg))
return PassthruArg;
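
With the shifted indices, the fold itself is unchanged: if no lane can be enabled, the load reads nothing and the passthru is the answer. Illustrated at the IR level (hypothetical snippet, assuming the new signature without the alignment immediate):

    // %r = call <4 x i32> @llvm.masked.load.v4i32.p0(
    //          ptr %p, <4 x i1> zeroinitializer, <4 x i32> %passthru)
    // simplifies to %passthru: the all-zero mask disables every lane.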
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 0e5bc48..df75999 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -947,9 +947,8 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
/*UseBlockValue*/ false));
}
- ValueLatticeElement Result = TrueVal;
- Result.mergeIn(FalseVal);
- return Result;
+ TrueVal.mergeIn(FalseVal);
+ return TrueVal;
}
std::optional<ConstantRange>
@@ -1778,9 +1777,8 @@ ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
assert(OptResult && "Value not available after solving");
}
- ValueLatticeElement Result = *OptResult;
- LLVM_DEBUG(dbgs() << " Result = " << Result << "\n");
- return Result;
+ LLVM_DEBUG(dbgs() << " Result = " << *OptResult << "\n");
+ return *OptResult;
}
ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) {
diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
index a8c3173..d84721b 100644
--- a/llvm/lib/Analysis/LoopInfo.cpp
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -986,8 +986,8 @@ PreservedAnalyses LoopPrinterPass::run(Function &F,
return PreservedAnalyses::all();
}
-void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) {
-
+void llvm::printLoop(const Loop &L, raw_ostream &OS,
+ const std::string &Banner) {
if (forcePrintModuleIR()) {
// handling -print-module-scope
OS << Banner << " (loop: ";
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index f90717d..9a5ae2a 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -61,6 +61,9 @@ static cl::opt<SkipMLPolicyCriteria> SkipPolicy(
static cl::opt<std::string> ModelSelector("ml-inliner-model-selector",
cl::Hidden, cl::init(""));
+static cl::opt<bool> StopImmediatelyForTest("ml-inliner-stop-immediately",
+ cl::Hidden);
+
#if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
// codegen-ed file
#include "InlinerSizeModel.h" // NOLINT
@@ -214,6 +217,7 @@ MLInlineAdvisor::MLInlineAdvisor(
return;
}
ModelRunner->switchContext("");
+ ForceStop = StopImmediatelyForTest;
}
unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
@@ -320,32 +324,44 @@ void MLInlineAdvisor::onSuccessfulInlining(const MLInlineAdvice &Advice,
FAM.invalidate(*Caller, PA);
}
Advice.updateCachedCallerFPI(FAM);
- int64_t IRSizeAfter =
- getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
- CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+ if (Caller == Callee) {
+ assert(!CalleeWasDeleted);
+ // We double-counted CallerAndCalleeEdges, since the caller and callee
+ // are the same.
+ assert(Advice.CallerAndCalleeEdges % 2 == 0);
+ CurrentIRSize += getIRSize(*Caller) - Advice.CallerIRSize;
+ EdgeCount += getCachedFPI(*Caller).DirectCallsToDefinedFunctions -
+ Advice.CallerAndCalleeEdges / 2;
+ // The NodeCount would stay the same.
+ } else {
+ int64_t IRSizeAfter =
+ getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize);
+ CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize);
+
+ // We can delta-update module-wide features. We know the inlining only
+ // changed the caller, and maybe the callee (by deleting the latter). Nodes
+ // are simple to update. For edges, we 'forget' the edges that the caller
+ // and callee used to have before inlining, and add back what they currently
+ // have together.
+ int64_t NewCallerAndCalleeEdges =
+ getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
+
+ // A dead function's node is not actually removed from the call graph until
+ // the end of the call graph walk, but the node no longer belongs to any
+ // valid SCC.
+ if (CalleeWasDeleted) {
+ --NodeCount;
+ NodesInLastSCC.erase(CG.lookup(*Callee));
+ DeadFunctions.insert(Callee);
+ } else {
+ NewCallerAndCalleeEdges +=
+ getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
+ }
+ EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
+ }
if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize)
ForceStop = true;
- // We can delta-update module-wide features. We know the inlining only changed
- // the caller, and maybe the callee (by deleting the latter).
- // Nodes are simple to update.
- // For edges, we 'forget' the edges that the caller and callee used to have
- // before inlining, and add back what they currently have together.
- int64_t NewCallerAndCalleeEdges =
- getCachedFPI(*Caller).DirectCallsToDefinedFunctions;
-
- // A dead function's node is not actually removed from the call graph until
- // the end of the call graph walk, but the node no longer belongs to any valid
- // SCC.
- if (CalleeWasDeleted) {
- --NodeCount;
- NodesInLastSCC.erase(CG.lookup(*Callee));
- DeadFunctions.insert(Callee);
- } else {
- NewCallerAndCalleeEdges +=
- getCachedFPI(*Callee).DirectCallsToDefinedFunctions;
- }
- EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges);
assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0);
}
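
A worked example for the new Caller == Callee branch: suppose F has 5 direct calls to defined functions and inlines a recursive call to itself. The advice recorded CallerAndCalleeEdges by summing caller and callee edges, i.e. the same 5 twice (hence the evenness assert). If F ends up with 9 such calls after inlining, the delta is (hypothetical numbers):

    int64_t RecordedEdges = 10;   // Advice.CallerAndCalleeEdges = 5 + 5
    int64_t EdgesAfter = 9;       // F's DirectCallsToDefinedFunctions now
    int64_t Delta = EdgesAfter - RecordedEdges / 2;  // 9 - 5 = +4
    // NodeCount is untouched: no function was deleted, and the callee
    // cannot have been deleted when it is also the caller.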
@@ -379,9 +395,17 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller);
if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) {
- if (!PSI.isFunctionEntryCold(&Caller))
- return std::make_unique<InlineAdvice>(this, CB, ORE,
- GetDefaultAdvice(CB));
+ if (!PSI.isFunctionEntryCold(&Caller)) {
+ // Return an MLInlineAdvice, despite delegating to the default advice,
+ // because we need to keep track of the internal state. This is different
+ // from the other instances where we return a "default" InlineAdvice,
+ // which happen at points we won't come back to the MLAdvisor for
+ // decisions requiring that state.
+ return ForceStop ? std::make_unique<InlineAdvice>(this, CB, ORE,
+ GetDefaultAdvice(CB))
+ : std::make_unique<MLInlineAdvice>(this, CB, ORE,
+ GetDefaultAdvice(CB));
+ }
}
auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE);
// If this is a "never inline" case, there won't be any changes to internal
diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp
index dcc5117..1c5f08e 100644
--- a/llvm/lib/Analysis/MemoryLocation.cpp
+++ b/llvm/lib/Analysis/MemoryLocation.cpp
@@ -245,7 +245,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
assert(ArgIdx == 0 && "Invalid argument index");
auto *Ty = cast<VectorType>(II->getType());
- if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty))
+ if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(1), Ty))
return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags);
return MemoryLocation(
@@ -255,7 +255,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
assert(ArgIdx == 1 && "Invalid argument index");
auto *Ty = cast<VectorType>(II->getArgOperand(0)->getType());
- if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(3), Ty))
+ if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty))
return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags);
return MemoryLocation(
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a64b93d..6f7dd79 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
// = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
// = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
//
- if (SM->getNumOperands() == 2)
- if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
- if (MulLHS->getAPInt().isPowerOf2())
- if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
- int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
- MulLHS->getAPInt().logBase2();
- Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
- return getMulExpr(
- getZeroExtendExpr(MulLHS, Ty),
- getZeroExtendExpr(
- getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
- SCEV::FlagNUW, Depth + 1);
- }
+ const APInt *C;
+ const SCEV *TruncRHS;
+ if (match(SM,
+ m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
+ C->isPowerOf2()) {
+ int NewTruncBits =
+ getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
+ Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
+ return getMulExpr(
+ getZeroExtendExpr(SM->getOperand(0), Ty),
+ getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
+ SCEV::FlagNUW, Depth + 1);
+ }
}
// zext(umin(x, y)) -> umin(zext(x), zext(y))
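
Most hunks in this file swap nested dyn_cast/getNumOperands/getOperand chains for the SCEV pattern matchers in llvm/Analysis/ScalarEvolutionPatternMatch.h; the binary matchers imply the two-operand check, which is why the explicit getNumOperands() == 2 guards disappear. A minimal sketch of the style (hypothetical helper matching the shape folded above):

    #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
    using namespace llvm;
    using namespace llvm::SCEVPatternMatch;

    // Match S == C * trunc(X) with C a power of two, binding C and the
    // operand of the truncate.
    static bool matchPow2MulOfTrunc(const SCEV *S, const APInt *&C,
                                    const SCEV *&TruncOp) {
      return match(S, m_scev_Mul(m_scev_APInt(C),
                                 m_scev_Trunc(m_SCEV(TruncOp)))) &&
             C->isPowerOf2();
    }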
@@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
if (Ops.size() == 2) {
// C1*(C2+V) -> C1*C2 + C1*V
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
- // If any of Add's ops are Adds or Muls with a constant, apply this
- // transformation as well.
- //
- // TODO: There are some cases where this transformation is not
- // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
- // this transformation should be narrowed down.
- if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
- const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
- SCEV::FlagAnyWrap, Depth + 1);
- const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
- SCEV::FlagAnyWrap, Depth + 1);
- return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
- }
+ // If any of Add's ops are Adds or Muls with a constant, apply this
+ // transformation as well.
+ //
+ // TODO: There are some cases where this transformation is not
+ // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
+ // this transformation should be narrowed down.
+ const SCEV *Op0, *Op1;
+ if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
+ containsConstantInAddMulChain(Ops[1])) {
+ const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
+ const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
+ return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
+ }
if (Ops[0]->isAllOnesValue()) {
// If we have a mul by -1 of an add, try distributing the -1 among the
@@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
}
// ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
- if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS);
- AE && AE->getNumOperands() == 2) {
- if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) {
- const APInt &NegC = VC->getAPInt();
- if (NegC.isNegative() && !NegC.isMinSignedValue()) {
- const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1));
- if (MME && MME->getNumOperands() == 2 &&
- isa<SCEVConstant>(MME->getOperand(0)) &&
- cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC &&
- MME->getOperand(1) == RHS)
- return getZero(LHS->getType());
- }
- }
- }
+ const APInt *NegC, *C;
+ if (match(LHS,
+ m_scev_Add(m_scev_APInt(NegC),
+ m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) &&
+ NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
+ return getZero(LHS->getType());
// TODO: Generalize to handle any common factors.
// udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
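
A quick numeric check of the ((-C + (C smax %x)) /u %x) == 0 identity being matched: with C = 8, if %x = 3 then smax(8, 3) - 8 = 0 and 0 /u 3 = 0; if %x = 20 then smax(8, 20) - 8 = 12 and 12 /u 20 = 0. In general the numerator is max(0, %x - C), which is strictly less than %x for any positive C, so the unsigned division always yields zero.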
@@ -4623,17 +4614,11 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
/// If Expr computes ~A, return A else return nullptr
static const SCEV *MatchNotExpr(const SCEV *Expr) {
- const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
- if (!Add || Add->getNumOperands() != 2 ||
- !Add->getOperand(0)->isAllOnesValue())
- return nullptr;
-
- const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
- if (!AddRHS || AddRHS->getNumOperands() != 2 ||
- !AddRHS->getOperand(0)->isAllOnesValue())
- return nullptr;
-
- return AddRHS->getOperand(1);
+ const SCEV *MulOp;
+ if (match(Expr, m_scev_Add(m_scev_AllOnes(),
+ m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
+ return MulOp;
+ return nullptr;
}
/// Return a SCEV corresponding to ~V = -1-V
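
The matcher encodes the identity ~A = -1 - A, which SCEV canonicalizes as (-1) + (-1)*A. A standalone APInt check of the identity (illustration only):

    #include "llvm/ADT/APInt.h"
    #include <cassert>
    using namespace llvm;

    void checkNotIdentity() {
      APInt A(8, 5);
      APInt AllOnes = APInt::getAllOnes(8); // -1 in two's complement
      // ~5 == 0xFA == -6 == -1 - 5
      assert((~A) == (AllOnes - A));
    }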
@@ -10797,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
}
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
- const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
- if (!Add || Add->getNumOperands() != 2)
+ const SCEV *Op0, *Op1;
+ if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
return false;
- if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
- ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
- LHS = Add->getOperand(1);
- RHS = ME->getOperand(1);
+ if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
+ LHS = Op1;
return true;
}
- if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
- ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
- LHS = Add->getOperand(0);
- RHS = ME->getOperand(1);
+ if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
+ LHS = Op0;
return true;
}
return false;
@@ -12172,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes(
bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
const SCEV *&L, const SCEV *&R,
SCEV::NoWrapFlags &Flags) {
- const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
- if (!AE || AE->getNumOperands() != 2)
+ if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
return false;
- L = AE->getOperand(0);
- R = AE->getOperand(1);
- Flags = AE->getNoWrapFlags();
+ Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
return true;
}
@@ -12220,12 +12198,11 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
// Try to match a common constant multiply.
auto MatchConstMul =
[](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
- auto *M = dyn_cast<SCEVMulExpr>(S);
- if (!M || M->getNumOperands() != 2 ||
- !isa<SCEVConstant>(M->getOperand(0)))
- return std::nullopt;
- return {
- {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}};
+ const APInt *C;
+ const SCEV *Op;
+ if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
+ return {{Op, *C}};
+ return std::nullopt;
};
if (auto MatchedMore = MatchConstMul(More)) {
if (auto MatchedLess = MatchConstMul(Less)) {
@@ -15496,6 +15473,38 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
}
}
+// Return a new SCEV that modifies \p Expr to the closest number that is
+// divisible by \p Divisor and less than or equal to \p Expr. For now, only
+// handle constant Expr.
+static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
+ const APInt &DivisorVal,
+ ScalarEvolution &SE) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
+ return Expr;
+ APInt Rem = ExprVal->urem(DivisorVal);
+ // return the SCEV: Expr - Expr % Divisor
+ return SE.getConstant(*ExprVal - Rem);
+}
+
+// Return a new SCEV that modifies \p Expr to the closest number that is
+// divisible by \p Divisor and greater than or equal to \p Expr. For now,
+// only handle constant Expr.
+static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
+ const APInt &DivisorVal,
+ ScalarEvolution &SE) {
+ const APInt *ExprVal;
+ if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
+ DivisorVal.isNonPositive())
+ return Expr;
+ APInt Rem = ExprVal->urem(DivisorVal);
+ if (Rem.isZero())
+ return Expr;
+ // return the SCEV: Expr + Divisor - Expr % Divisor
+ return SE.getConstant(*ExprVal + DivisorVal - Rem);
+}
+
void ScalarEvolution::LoopGuards::collectFromBlock(
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
const BasicBlock *Block, const BasicBlock *Pred,
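
For example, with Expr = 17 and Divisor = 8: 17 urem 8 = 1, so the "previous" helper yields 16 and the "next" helper yields 24; a value already divisible is returned unchanged by the "next" form. The same arithmetic with APInt:

    #include "llvm/ADT/APInt.h"
    using namespace llvm;

    void alignExample() {
      APInt Expr(32, 17), Divisor(32, 8);
      APInt Rem = Expr.urem(Divisor);     // 1
      APInt Prev = Expr - Rem;            // 16 = previous multiple of 8
      APInt Next = Expr + Divisor - Rem;  // 24 = next multiple of 8
    }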
@@ -15557,51 +15566,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
auto IsMinMaxSCEVWithNonNegativeConstant =
[&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
const SCEV *&RHS) {
- if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
- if (MinMax->getNumOperands() != 2)
- return false;
- if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
- if (C->getAPInt().isNegative())
- return false;
- SCTy = MinMax->getSCEVType();
- LHS = MinMax->getOperand(0);
- RHS = MinMax->getOperand(1);
- return true;
- }
- }
- return false;
+ const APInt *C;
+ SCTy = Expr->getSCEVType();
+ return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
+ match(LHS, m_scev_APInt(C)) && C->isNonNegative();
};
- // Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and greater or equal than Expr. For now, only handle constant
- // Expr.
- auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
- const APInt &DivisorVal) {
- const APInt *ExprVal;
- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
- DivisorVal.isNonPositive())
- return Expr;
- APInt Rem = ExprVal->urem(DivisorVal);
- if (Rem.isZero())
- return Expr;
- // return the SCEV: Expr + Divisor - Expr % Divisor
- return SE.getConstant(*ExprVal + DivisorVal - Rem);
- };
-
- // Return a new SCEV that modifies \p Expr to the closest number divides by
- // \p Divisor and less or equal than Expr. For now, only handle constant
- // Expr.
- auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
- const APInt &DivisorVal) {
- const APInt *ExprVal;
- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
- DivisorVal.isNonPositive())
- return Expr;
- APInt Rem = ExprVal->urem(DivisorVal);
- // return the SCEV: Expr - Expr % Divisor
- return SE.getConstant(*ExprVal - Rem);
- };
-
// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
// recursively. This is done by aligning up/down the constant value to the
// Divisor.
@@ -15623,8 +15593,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
assert(SE.isKnownNonNegative(MinMaxLHS) &&
"Expected non-negative operand!");
auto *DivisibleExpr =
- IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
- : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal);
+ IsMin
+ ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
+ : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
SmallVector<const SCEV *> Ops = {
ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
return SE.getMinMaxExpr(SCTy, Ops);
@@ -15701,21 +15672,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
[[fallthrough]];
case CmpInst::ICMP_SLT: {
RHS = SE.getMinusSCEV(RHS, One);
- RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
}
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_SGT:
RHS = SE.getAddExpr(RHS, One);
- RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
case CmpInst::ICMP_ULE:
case CmpInst::ICMP_SLE:
- RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
case CmpInst::ICMP_UGE:
case CmpInst::ICMP_SGE:
- RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy);
+ RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
break;
default:
break;
@@ -15769,22 +15740,29 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
case CmpInst::ICMP_NE:
if (match(RHS, m_scev_Zero())) {
const SCEV *OneAlignedUp =
- GetNextSCEVDividesByDivisor(One, DividesBy);
+ getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
} else {
+ // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
+ // but creating the subtraction eagerly is expensive. Track the
+ // inequalities in a separate set, and materialize the rewrite lazily
+ // when encountering a suitable subtraction while rewriting.
if (LHS->getType()->isPointerTy()) {
LHS = SE.getLosslessPtrToIntExpr(LHS);
RHS = SE.getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
break;
}
- auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
- const SCEV *Sub = SE.getMinusSCEV(A, B);
- AddRewrite(Sub, Sub,
- SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
- };
- AddSubRewrite(LHS, RHS);
- AddSubRewrite(RHS, LHS);
+ const SCEVConstant *C;
+ const SCEV *A, *B;
+ if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
+ match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
+ RHS = A;
+ LHS = B;
+ }
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ Guards.NotEqual.insert({LHS, RHS});
continue;
}
break;
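
The recorded pair is canonicalized first: a shared constant addend is stripped (so (C + A) != (C + B) is stored as {A, B}) and the two operands are put in pointer order, letting the rewriter probe the set with the same normalization. End-to-end, under a guard %a != %b (sketch of the intended effect):

    // rewrite(%a - %b)  ==>  umax(1, %a - %b)
    // MatchBinarySub recognizes "%a + (-1 * %b)", the canonical pair
    // {%a, %b} is found in NotEqual, and the 1 is aligned up to the
    // expression's known constant multiple before taking the umax.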
@@ -15918,13 +15896,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
class SCEVLoopGuardRewriter
: public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
+ const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
const ScalarEvolution::LoopGuards &Guards)
- : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
+ : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
+ NotEqual(Guards.NotEqual) {
if (Guards.PreserveNUW)
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
if (Guards.PreserveNSW)
@@ -15979,14 +15959,39 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+ // Helper to check if S is a subtraction (A - B) where A != B, and if so,
+ // return UMax(S, 1).
+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
+ const SCEV *LHS, *RHS;
+ if (MatchBinarySub(S, LHS, RHS)) {
+ if (LHS > RHS)
+ std::swap(LHS, RHS);
+ if (NotEqual.contains({LHS, RHS})) {
+ const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
+ SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
+ return SE.getUMaxExpr(OneAlignedUp, S);
+ }
+ }
+ return nullptr;
+ };
+
+ // Check if Expr itself is a subtraction pattern with guard info.
+ if (const SCEV *Rewritten = RewriteSubtraction(Expr))
+ return Rewritten;
+
// Trip count expressions sometimes consist of adding 3 operands, i.e.
// (Const + A + B). There may be guard info for A + B, and if so, apply
// it.
// TODO: Could more generally apply guards to Add sub-expressions.
if (isa<SCEVConstant>(Expr->getOperand(0)) &&
Expr->getNumOperands() == 3) {
- if (const SCEV *S = Map.lookup(
- SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
+ const SCEV *Add =
+ SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
+ if (const SCEV *Rewritten = RewriteSubtraction(Add))
+ return SE.getAddExpr(
+ Expr->getOperand(0), Rewritten,
+ ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
+ if (const SCEV *S = Map.lookup(Add))
return SE.getAddExpr(Expr->getOperand(0), S);
}
SmallVector<const SCEV *, 2> Operands;
@@ -16021,7 +16026,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
};
- if (RewriteMap.empty())
+ if (RewriteMap.empty() && NotEqual.empty())
return Expr;
SCEVLoopGuardRewriter Rewriter(SE, *this);
diff --git a/llvm/lib/Analysis/StaticDataProfileInfo.cpp b/llvm/lib/Analysis/StaticDataProfileInfo.cpp
index e7f0b2c..61d4935 100644
--- a/llvm/lib/Analysis/StaticDataProfileInfo.cpp
+++ b/llvm/lib/Analysis/StaticDataProfileInfo.cpp
@@ -1,10 +1,14 @@
#include "llvm/Analysis/StaticDataProfileInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/ProfileData/InstrProf.h"
+#define DEBUG_TYPE "static-data-profile-info"
+
using namespace llvm;
namespace llvm {
@@ -79,6 +83,17 @@ StaticDataProfileInfo::getConstantHotnessUsingProfileCount(
return StaticDataHotness::LukewarmOrUnknown;
}
+StaticDataProfileInfo::StaticDataHotness
+StaticDataProfileInfo::getSectionHotnessUsingDataAccessProfile(
+ std::optional<StringRef> MaybeSectionPrefix) const {
+ if (!MaybeSectionPrefix)
+ return StaticDataHotness::LukewarmOrUnknown;
+ StringRef Prefix = *MaybeSectionPrefix;
+ assert((Prefix == "hot" || Prefix == "unlikely") &&
+ "Expect section_prefix to be one of hot or unlikely");
+ return Prefix == "hot" ? StaticDataHotness::Hot : StaticDataHotness::Cold;
+}
+
StringRef StaticDataProfileInfo::hotnessToStr(StaticDataHotness Hotness) const {
switch (Hotness) {
case StaticDataHotness::Cold:
@@ -101,13 +116,66 @@ StaticDataProfileInfo::getConstantProfileCount(const Constant *C) const {
StringRef StaticDataProfileInfo::getConstantSectionPrefix(
const Constant *C, const ProfileSummaryInfo *PSI) const {
std::optional<uint64_t> Count = getConstantProfileCount(C);
+
+#ifndef NDEBUG
+ auto DbgPrintPrefix = [](StringRef Prefix) {
+ return Prefix.empty() ? "<empty>" : Prefix;
+ };
+#endif
+
+ if (EnableDataAccessProf) {
+ // Module flag `EnableDataAccessProf` is 1 -> an empty section prefix
+ // means unknown hotness, except for string literals.
+ if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C);
+ GV && llvm::memprof::IsAnnotationOK(*GV) &&
+ !GV->getName().starts_with(".str")) {
+ auto HotnessFromDataAccessProf =
+ getSectionHotnessUsingDataAccessProfile(GV->getSectionPrefix());
+
+ if (!Count) {
+ StringRef Prefix = hotnessToStr(HotnessFromDataAccessProf);
+ LLVM_DEBUG(dbgs() << GV->getName() << " has section prefix "
+ << DbgPrintPrefix(Prefix)
+ << ", solely from data access profiles\n");
+ return Prefix;
+ }
+
+ // Both data access profiles and PGO counters are available. Use the
+ // hotter one.
+ auto HotnessFromPGO = getConstantHotnessUsingProfileCount(C, PSI, *Count);
+ StaticDataHotness GlobalVarHotness = StaticDataHotness::LukewarmOrUnknown;
+ if (HotnessFromDataAccessProf == StaticDataHotness::Hot ||
+ HotnessFromPGO == StaticDataHotness::Hot) {
+ GlobalVarHotness = StaticDataHotness::Hot;
+ } else if (HotnessFromDataAccessProf ==
+ StaticDataHotness::LukewarmOrUnknown ||
+ HotnessFromPGO == StaticDataHotness::LukewarmOrUnknown) {
+ GlobalVarHotness = StaticDataHotness::LukewarmOrUnknown;
+ } else {
+ GlobalVarHotness = StaticDataHotness::Cold;
+ }
+ StringRef Prefix = hotnessToStr(GlobalVarHotness);
+ LLVM_DEBUG(
+ dbgs() << GV->getName() << " has section prefix "
+ << DbgPrintPrefix(Prefix)
+ << ", the max from data access profiles as "
+ << DbgPrintPrefix(hotnessToStr(HotnessFromDataAccessProf))
+ << " and PGO counters as "
+ << DbgPrintPrefix(hotnessToStr(HotnessFromPGO)) << "\n");
+ return Prefix;
+ }
+ }
if (!Count)
return "";
return hotnessToStr(getConstantHotnessUsingProfileCount(C, PSI, *Count));
}
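
When both signals are available the code keeps the hotter verdict, with LukewarmOrUnknown dominating Cold: Hot on either side wins, an unknown on either side blocks a Cold classification, and only two Cold verdicts stay Cold. A compact equivalent of the if/else chain above (hypothetical helper over the same enum):

    using H = StaticDataProfileInfo::StaticDataHotness;

    // Keep the hotter of two verdicts; unknown beats cold so that a lone
    // cold signal never demotes data the other profile knows nothing about.
    static H combineHotness(H A, H B) {
      if (A == H::Hot || B == H::Hot)
        return H::Hot;
      if (A == H::LukewarmOrUnknown || B == H::LukewarmOrUnknown)
        return H::LukewarmOrUnknown;
      return H::Cold;
    }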
bool StaticDataProfileInfoWrapperPass::doInitialization(Module &M) {
- Info.reset(new StaticDataProfileInfo());
+ bool EnableDataAccessProf = false;
+ if (auto *MD = mdconst::extract_or_null<ConstantInt>(
+ M.getModuleFlag("EnableDataAccessProf")))
+ EnableDataAccessProf = MD->getZExtValue();
+ Info.reset(new StaticDataProfileInfo(EnableDataAccessProf));
return false;
}
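
The flag is read back from module metadata, so a producer opts in with the matching module flag; a sketch of the writer side (assuming the standard Module::addModuleFlag API):

    #include "llvm/IR/Module.h"
    using namespace llvm;

    // Mark a module so StaticDataProfileInfoWrapperPass enables data
    // access profiling; read back via mdconst::extract_or_null<ConstantInt>.
    void enableDataAccessProf(Module &M) {
      M.addModuleFlag(Module::Error, "EnableDataAccessProf", 1u);
    }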
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9655c88..0a72076 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -7695,6 +7695,11 @@ static bool isGuaranteedNotToBeUndefOrPoison(
}
if (IsWellDefined)
return true;
+ } else if (auto *Splat = isa<ShuffleVectorInst>(Opr) ? getSplatValue(Opr)
+ : nullptr) {
+ // For splats we only need to check the value being splatted.
+ if (OpCheck(Splat))
+ return true;
} else if (all_of(Opr->operands(), OpCheck))
return true;
}
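
getSplatValue (llvm/Analysis/VectorUtils.h) returns the broadcast scalar of a splat shuffle, or nullptr otherwise; checking only that scalar avoids the all_of over the shuffle's operands, where the second operand is typically poison and would make the per-operand check fail spuriously. A minimal usage sketch:

    #include "llvm/Analysis/VectorUtils.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // For a splat such as
    //   %s = shufflevector <4 x i32> %v, <4 x i32> poison,
    //                      <4 x i32> zeroinitializer
    // return the single scalar being broadcast, or nullptr if %s is not
    // a splat shuffle.
    static const Value *splatScalarOrNull(const Value *V) {
      if (isa<ShuffleVectorInst>(V))
        return getSplatValue(V);
      return nullptr;
    }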