Diffstat (limited to 'llvm/lib/Analysis')
-rwxr-xr-x | llvm/lib/Analysis/ConstantFolding.cpp | 60
-rw-r--r-- | llvm/lib/Analysis/DXILResource.cpp | 47
-rw-r--r-- | llvm/lib/Analysis/InstructionSimplify.cpp | 11
-rw-r--r-- | llvm/lib/Analysis/LazyValueInfo.cpp | 10
-rw-r--r-- | llvm/lib/Analysis/LoopInfo.cpp | 4
-rw-r--r-- | llvm/lib/Analysis/MLInlineAdvisor.cpp | 76
-rw-r--r-- | llvm/lib/Analysis/MemoryLocation.cpp | 4
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 273
-rw-r--r-- | llvm/lib/Analysis/StaticDataProfileInfo.cpp | 70
-rw-r--r-- | llvm/lib/Analysis/ValueTracking.cpp | 5
10 files changed, 346 insertions, 214 deletions
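A recurring change in the diffs below: the mask and passthru operands of the masked load/gather intrinsics shift down by one index (mask 2 -> 1, passthru 3 -> 2) in ConstantFolding.cpp, InstructionSimplify.cpp, and MemoryLocation.cpp. This is consistent with the explicit alignment operand being dropped from the intrinsic signature. A minimal sketch of the new operand layout, assuming that signature change; the helper name is hypothetical and not part of this patch:

    // Assumed new form: @llvm.masked.load(ptr %p, <N x i1> %mask, <N x T> %passthru)
    // The old form carried an alignment immarg as operand 1, so the mask used
    // to be operand 2 and the passthru operand 3.
    #include "llvm/IR/IntrinsicInst.h"
    using namespace llvm;

    static Value *getMaskedLoadMask(const IntrinsicInst &II) {
      assert(II.getIntrinsicID() == Intrinsic::masked_load &&
             "expects llvm.masked.load");
      return II.getArgOperand(1); // Was getArgOperand(2) before this change.
    }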
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 45c889c..e9e2e7d 100755 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -2177,16 +2177,13 @@ Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) { return PoisonValue::get(VT->getElementType()); // TODO: Handle undef. - if (!isa<ConstantVector>(Op) && !isa<ConstantDataVector>(Op)) - return nullptr; - - auto *EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(0U)); + auto *EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(0U)); if (!EltC) return nullptr; APInt Acc = EltC->getValue(); for (unsigned I = 1, E = VT->getNumElements(); I != E; I++) { - if (!(EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(I)))) + if (!(EltC = dyn_cast_or_null<ConstantInt>(Op->getAggregateElement(I)))) return nullptr; const APInt &X = EltC->getValue(); switch (IID) { @@ -3059,35 +3056,25 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, Val = Val | Val << 1; return ConstantInt::get(Ty, Val); } - - default: - return nullptr; } } - switch (IntrinsicID) { - default: break; - case Intrinsic::vector_reduce_add: - case Intrinsic::vector_reduce_mul: - case Intrinsic::vector_reduce_and: - case Intrinsic::vector_reduce_or: - case Intrinsic::vector_reduce_xor: - case Intrinsic::vector_reduce_smin: - case Intrinsic::vector_reduce_smax: - case Intrinsic::vector_reduce_umin: - case Intrinsic::vector_reduce_umax: - if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0])) - return C; - break; - } - - // Support ConstantVector in case we have an Undef in the top. - if (isa<ConstantVector>(Operands[0]) || - isa<ConstantDataVector>(Operands[0]) || - isa<ConstantAggregateZero>(Operands[0])) { + if (Operands[0]->getType()->isVectorTy()) { auto *Op = cast<Constant>(Operands[0]); switch (IntrinsicID) { default: break; + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_umax: + if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0])) + return C; + break; case Intrinsic::x86_sse_cvtss2si: case Intrinsic::x86_sse_cvtss2si64: case Intrinsic::x86_sse2_cvtsd2si: @@ -3116,10 +3103,15 @@ static Constant *ConstantFoldScalarCall1(StringRef Name, case Intrinsic::wasm_alltrue: // Check each element individually unsigned E = cast<FixedVectorType>(Op->getType())->getNumElements(); - for (unsigned I = 0; I != E; ++I) - if (Constant *Elt = Op->getAggregateElement(I)) - if (Elt->isZeroValue()) - return ConstantInt::get(Ty, 0); + for (unsigned I = 0; I != E; ++I) { + Constant *Elt = Op->getAggregateElement(I); + // Return false as soon as we find a non-true element. + if (Elt && Elt->isZeroValue()) + return ConstantInt::get(Ty, 0); + // Bail as soon as we find an element we cannot prove to be true. 
+ if (!Elt || !isa<ConstantInt>(Elt)) + return nullptr; + } return ConstantInt::get(Ty, 1); } @@ -4064,8 +4056,8 @@ static Constant *ConstantFoldFixedVectorCall( switch (IntrinsicID) { case Intrinsic::masked_load: { auto *SrcPtr = Operands[0]; - auto *Mask = Operands[2]; - auto *Passthru = Operands[3]; + auto *Mask = Operands[1]; + auto *Passthru = Operands[2]; Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, FVTy, DL); diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp index b78cc03e..f9bf092 100644 --- a/llvm/lib/Analysis/DXILResource.cpp +++ b/llvm/lib/Analysis/DXILResource.cpp @@ -281,6 +281,38 @@ static StructType *getOrCreateElementStruct(Type *ElemType, StringRef Name) { return StructType::create(ElemType, Name); } +static Type *getTypeWithoutPadding(Type *Ty) { + // Recursively remove padding from structures. + if (auto *ST = dyn_cast<StructType>(Ty)) { + LLVMContext &Ctx = Ty->getContext(); + SmallVector<Type *> ElementTypes; + ElementTypes.reserve(ST->getNumElements()); + for (Type *ElTy : ST->elements()) { + if (isa<PaddingExtType>(ElTy)) + continue; + ElementTypes.push_back(getTypeWithoutPadding(ElTy)); + } + + // Handle explicitly padded cbuffer arrays like { [ n x paddedty ], ty } + if (ElementTypes.size() == 2) + if (auto *AT = dyn_cast<ArrayType>(ElementTypes[0])) + if (ElementTypes[1] == AT->getElementType()) + return ArrayType::get(ElementTypes[1], AT->getNumElements() + 1); + + // If we only have a single element, don't wrap it in a struct. + if (ElementTypes.size() == 1) + return ElementTypes[0]; + + return StructType::get(Ctx, ElementTypes, /*IsPacked=*/false); + } + // Arrays just need to have their element type adjusted. + if (auto *AT = dyn_cast<ArrayType>(Ty)) + return ArrayType::get(getTypeWithoutPadding(AT->getElementType()), + AT->getNumElements()); + // Anything else should be good as is. + return Ty; +} + StructType *ResourceTypeInfo::createElementStruct(StringRef CBufferName) { SmallString<64> TypeName; @@ -334,14 +366,21 @@ StructType *ResourceTypeInfo::createElementStruct(StringRef CBufferName) { } case ResourceKind::CBuffer: { auto *RTy = cast<CBufferExtType>(HandleTy); - LayoutExtType *LayoutType = cast<LayoutExtType>(RTy->getResourceType()); - StructType *Ty = cast<StructType>(LayoutType->getWrappedType()); SmallString<64> Name = getResourceKindName(Kind); if (!CBufferName.empty()) { Name.append("."); Name.append(CBufferName); } - return StructType::create(Ty->elements(), Name); + + // TODO: Remove this when we update the frontend to use explicit padding. + if (LayoutExtType *LayoutType = + dyn_cast<LayoutExtType>(RTy->getResourceType())) { + StructType *Ty = cast<StructType>(LayoutType->getWrappedType()); + return StructType::create(Ty->elements(), Name); + } + + return getOrCreateElementStruct( + getTypeWithoutPadding(RTy->getResourceType()), Name); } case ResourceKind::Sampler: { auto *RTy = cast<SamplerExtType>(HandleTy); @@ -454,10 +493,10 @@ uint32_t ResourceTypeInfo::getCBufferSize(const DataLayout &DL) const { Type *ElTy = cast<CBufferExtType>(HandleTy)->getResourceType(); + // TODO: Remove this when we update the frontend to use explicit padding. if (auto *LayoutTy = dyn_cast<LayoutExtType>(ElTy)) return LayoutTy->getSize(); - // TODO: What should we do with unannotated arrays? 
return DL.getTypeAllocSize(ElTy); } diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index e08ef60..dc813f6 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -5440,9 +5440,10 @@ static Value *simplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, // ptrtoint (ptradd (Ptr, X - ptrtoint(Ptr))) -> X Value *Ptr, *X; - if (CastOpc == Instruction::PtrToInt && - match(Op, m_PtrAdd(m_Value(Ptr), - m_Sub(m_Value(X), m_PtrToInt(m_Deferred(Ptr))))) && + if ((CastOpc == Instruction::PtrToInt || CastOpc == Instruction::PtrToAddr) && + match(Op, + m_PtrAdd(m_Value(Ptr), + m_Sub(m_Value(X), m_PtrToIntOrAddr(m_Deferred(Ptr))))) && X->getType() == Ty && Ty == Q.DL.getIndexType(Ptr->getType())) return X; @@ -6987,8 +6988,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee, switch (IID) { case Intrinsic::masked_load: case Intrinsic::masked_gather: { - Value *MaskArg = Args[2]; - Value *PassthruArg = Args[3]; + Value *MaskArg = Args[1]; + Value *PassthruArg = Args[2]; // If the mask is all zeros or undef, the "passthru" argument is the result. if (maskIsAllZeroOrUndef(MaskArg)) return PassthruArg; diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp index 0e5bc48..df75999 100644 --- a/llvm/lib/Analysis/LazyValueInfo.cpp +++ b/llvm/lib/Analysis/LazyValueInfo.cpp @@ -947,9 +947,8 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) { /*UseBlockValue*/ false)); } - ValueLatticeElement Result = TrueVal; - Result.mergeIn(FalseVal); - return Result; + TrueVal.mergeIn(FalseVal); + return TrueVal; } std::optional<ConstantRange> @@ -1778,9 +1777,8 @@ ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB, assert(OptResult && "Value not available after solving"); } - ValueLatticeElement Result = *OptResult; - LLVM_DEBUG(dbgs() << " Result = " << Result << "\n"); - return Result; + LLVM_DEBUG(dbgs() << " Result = " << *OptResult << "\n"); + return *OptResult; } ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) { diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp index a8c3173..d84721b 100644 --- a/llvm/lib/Analysis/LoopInfo.cpp +++ b/llvm/lib/Analysis/LoopInfo.cpp @@ -986,8 +986,8 @@ PreservedAnalyses LoopPrinterPass::run(Function &F, return PreservedAnalyses::all(); } -void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) { - +void llvm::printLoop(const Loop &L, raw_ostream &OS, + const std::string &Banner) { if (forcePrintModuleIR()) { // handling -print-module-scope OS << Banner << " (loop: "; diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp index f90717d..9a5ae2a 100644 --- a/llvm/lib/Analysis/MLInlineAdvisor.cpp +++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp @@ -61,6 +61,9 @@ static cl::opt<SkipMLPolicyCriteria> SkipPolicy( static cl::opt<std::string> ModelSelector("ml-inliner-model-selector", cl::Hidden, cl::init("")); +static cl::opt<bool> StopImmediatelyForTest("ml-inliner-stop-immediately", + cl::Hidden); + #if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL) // codegen-ed file #include "InlinerSizeModel.h" // NOLINT @@ -214,6 +217,7 @@ MLInlineAdvisor::MLInlineAdvisor( return; } ModelRunner->switchContext(""); + ForceStop = StopImmediatelyForTest; } unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const { @@ -320,32 +324,44 @@ void MLInlineAdvisor::onSuccessfulInlining(const 
MLInlineAdvice &Advice, FAM.invalidate(*Caller, PA); } Advice.updateCachedCallerFPI(FAM); - int64_t IRSizeAfter = - getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize); - CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize); + if (Caller == Callee) { + assert(!CalleeWasDeleted); + // We double-counted CallerAndCalleeEdges - since the caller and callee + // would be the same + assert(Advice.CallerAndCalleeEdges % 2 == 0); + CurrentIRSize += getIRSize(*Caller) - Advice.CallerIRSize; + EdgeCount += getCachedFPI(*Caller).DirectCallsToDefinedFunctions - + Advice.CallerAndCalleeEdges / 2; + // The NodeCount would stay the same. + } else { + int64_t IRSizeAfter = + getIRSize(*Caller) + (CalleeWasDeleted ? 0 : Advice.CalleeIRSize); + CurrentIRSize += IRSizeAfter - (Advice.CallerIRSize + Advice.CalleeIRSize); + + // We can delta-update module-wide features. We know the inlining only + // changed the caller, and maybe the callee (by deleting the latter). Nodes + // are simple to update. For edges, we 'forget' the edges that the caller + // and callee used to have before inlining, and add back what they currently + // have together. + int64_t NewCallerAndCalleeEdges = + getCachedFPI(*Caller).DirectCallsToDefinedFunctions; + + // A dead function's node is not actually removed from the call graph until + // the end of the call graph walk, but the node no longer belongs to any + // valid SCC. + if (CalleeWasDeleted) { + --NodeCount; + NodesInLastSCC.erase(CG.lookup(*Callee)); + DeadFunctions.insert(Callee); + } else { + NewCallerAndCalleeEdges += + getCachedFPI(*Callee).DirectCallsToDefinedFunctions; + } + EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges); + } if (CurrentIRSize > SizeIncreaseThreshold * InitialIRSize) ForceStop = true; - // We can delta-update module-wide features. We know the inlining only changed - // the caller, and maybe the callee (by deleting the latter). - // Nodes are simple to update. - // For edges, we 'forget' the edges that the caller and callee used to have - // before inlining, and add back what they currently have together. - int64_t NewCallerAndCalleeEdges = - getCachedFPI(*Caller).DirectCallsToDefinedFunctions; - - // A dead function's node is not actually removed from the call graph until - // the end of the call graph walk, but the node no longer belongs to any valid - // SCC. - if (CalleeWasDeleted) { - --NodeCount; - NodesInLastSCC.erase(CG.lookup(*Callee)); - DeadFunctions.insert(Callee); - } else { - NewCallerAndCalleeEdges += - getCachedFPI(*Callee).DirectCallsToDefinedFunctions; - } - EdgeCount += (NewCallerAndCalleeEdges - Advice.CallerAndCalleeEdges); assert(CurrentIRSize >= 0 && EdgeCount >= 0 && NodeCount >= 0); } @@ -379,9 +395,17 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) { auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(Caller); if (SkipPolicy == SkipMLPolicyCriteria::IfCallerIsNotCold) { - if (!PSI.isFunctionEntryCold(&Caller)) - return std::make_unique<InlineAdvice>(this, CB, ORE, - GetDefaultAdvice(CB)); + if (!PSI.isFunctionEntryCold(&Caller)) { + // Return a MLInlineAdvice, despite delegating to the default advice, + // because we need to keep track of the internal state. This is different + // from the other instances where we return a "default" InlineAdvice, + // which happen at points we won't come back to the MLAdvisor for + // decisions requiring that state. + return ForceStop ? 
std::make_unique<InlineAdvice>(this, CB, ORE, + GetDefaultAdvice(CB)) + : std::make_unique<MLInlineAdvice>(this, CB, ORE, + GetDefaultAdvice(CB)); + } } auto MandatoryKind = InlineAdvisor::getMandatoryKind(CB, FAM, ORE); // If this is a "never inline" case, there won't be any changes to internal diff --git a/llvm/lib/Analysis/MemoryLocation.cpp b/llvm/lib/Analysis/MemoryLocation.cpp index dcc5117..1c5f08e 100644 --- a/llvm/lib/Analysis/MemoryLocation.cpp +++ b/llvm/lib/Analysis/MemoryLocation.cpp @@ -245,7 +245,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call, assert(ArgIdx == 0 && "Invalid argument index"); auto *Ty = cast<VectorType>(II->getType()); - if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty)) + if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(1), Ty)) return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags); return MemoryLocation( @@ -255,7 +255,7 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call, assert(ArgIdx == 1 && "Invalid argument index"); auto *Ty = cast<VectorType>(II->getArgOperand(0)->getType()); - if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(3), Ty)) + if (auto KnownType = getKnownTypeFromMaskedOp(II->getOperand(2), Ty)) return MemoryLocation(Arg, DL.getTypeStoreSize(*KnownType), AATags); return MemoryLocation( diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index a64b93d..6f7dd79 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -1840,19 +1840,19 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty, // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>. // - if (SM->getNumOperands() == 2) - if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0))) - if (MulLHS->getAPInt().isPowerOf2()) - if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) { - int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) - - MulLHS->getAPInt().logBase2(); - Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); - return getMulExpr( - getZeroExtendExpr(MulLHS, Ty), - getZeroExtendExpr( - getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty), - SCEV::FlagNUW, Depth + 1); - } + const APInt *C; + const SCEV *TruncRHS; + if (match(SM, + m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) && + C->isPowerOf2()) { + int NewTruncBits = + getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2(); + Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); + return getMulExpr( + getZeroExtendExpr(SM->getOperand(0), Ty), + getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty), + SCEV::FlagNUW, Depth + 1); + } } // zext(umin(x, y)) -> umin(zext(x), zext(y)) @@ -3144,20 +3144,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { if (Ops.size() == 2) { // C1*(C2+V) -> C1*C2 + C1*V - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) - // If any of Add's ops are Adds or Muls with a constant, apply this - // transformation as well. - // - // TODO: There are some cases where this transformation is not - // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of - // this transformation should be narrowed down. 
- if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) { - const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0), - SCEV::FlagAnyWrap, Depth + 1); - const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1), - SCEV::FlagAnyWrap, Depth + 1); - return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); - } + // If any of Add's ops are Adds or Muls with a constant, apply this + // transformation as well. + // + // TODO: There are some cases where this transformation is not + // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of + // this transformation should be narrowed down. + const SCEV *Op0, *Op1; + if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) && + containsConstantInAddMulChain(Ops[1])) { + const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1); + const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1); + return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1); + } if (Ops[0]->isAllOnesValue()) { // If we have a mul by -1 of an add, try distributing the -1 among the @@ -3578,20 +3577,12 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C. - if (const auto *AE = dyn_cast<SCEVAddExpr>(LHS); - AE && AE->getNumOperands() == 2) { - if (const auto *VC = dyn_cast<SCEVConstant>(AE->getOperand(0))) { - const APInt &NegC = VC->getAPInt(); - if (NegC.isNegative() && !NegC.isMinSignedValue()) { - const auto *MME = dyn_cast<SCEVSMaxExpr>(AE->getOperand(1)); - if (MME && MME->getNumOperands() == 2 && - isa<SCEVConstant>(MME->getOperand(0)) && - cast<SCEVConstant>(MME->getOperand(0))->getAPInt() == -NegC && - MME->getOperand(1) == RHS) - return getZero(LHS->getType()); - } - } - } + const APInt *NegC, *C; + if (match(LHS, + m_scev_Add(m_scev_APInt(NegC), + m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) && + NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC) + return getZero(LHS->getType()); // TODO: Generalize to handle any common factors. 
// udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b @@ -4623,17 +4614,11 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, /// If Expr computes ~A, return A else return nullptr static const SCEV *MatchNotExpr(const SCEV *Expr) { - const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr); - if (!Add || Add->getNumOperands() != 2 || - !Add->getOperand(0)->isAllOnesValue()) - return nullptr; - - const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); - if (!AddRHS || AddRHS->getNumOperands() != 2 || - !AddRHS->getOperand(0)->isAllOnesValue()) - return nullptr; - - return AddRHS->getOperand(1); + const SCEV *MulOp; + if (match(Expr, m_scev_Add(m_scev_AllOnes(), + m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp))))) + return MulOp; + return nullptr; } /// Return a SCEV corresponding to ~V = -1-V @@ -10797,19 +10782,15 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { } static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) { - const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S); - if (!Add || Add->getNumOperands() != 2) + const SCEV *Op0, *Op1; + if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1)))) return false; - if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0)); - ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { - LHS = Add->getOperand(1); - RHS = ME->getOperand(1); + if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) { + LHS = Op1; return true; } - if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); - ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) { - LHS = Add->getOperand(0); - RHS = ME->getOperand(1); + if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) { + LHS = Op0; return true; } return false; @@ -12172,13 +12153,10 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R, SCEV::NoWrapFlags &Flags) { - const auto *AE = dyn_cast<SCEVAddExpr>(Expr); - if (!AE || AE->getNumOperands() != 2) + if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R)))) return false; - L = AE->getOperand(0); - R = AE->getOperand(1); - Flags = AE->getNoWrapFlags(); + Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags(); return true; } @@ -12220,12 +12198,11 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) { // Try to match a common constant multiply. auto MatchConstMul = [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> { - auto *M = dyn_cast<SCEVMulExpr>(S); - if (!M || M->getNumOperands() != 2 || - !isa<SCEVConstant>(M->getOperand(0))) - return std::nullopt; - return { - {M->getOperand(1), cast<SCEVConstant>(M->getOperand(0))->getAPInt()}}; + const APInt *C; + const SCEV *Op; + if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op)))) + return {{Op, *C}}; + return std::nullopt; }; if (auto MatchedMore = MatchConstMul(More)) { if (auto MatchedLess = MatchConstMul(Less)) { @@ -15496,6 +15473,38 @@ void ScalarEvolution::LoopGuards::collectFromPHI( } } +// Return a new SCEV that modifies \p Expr to the closest number divides by +// \p Divisor and less or equal than Expr. For now, only handle constant +// Expr. 
+static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, + const APInt &DivisorVal, + ScalarEvolution &SE) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) + return Expr; + APInt Rem = ExprVal->urem(DivisorVal); + // return the SCEV: Expr - Expr % Divisor + return SE.getConstant(*ExprVal - Rem); +} + +// Return a new SCEV that modifies \p Expr to the closest number divides by +// \p Divisor and greater or equal than Expr. For now, only handle constant +// Expr. +static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr, + const APInt &DivisorVal, + ScalarEvolution &SE) { + const APInt *ExprVal; + if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || + DivisorVal.isNonPositive()) + return Expr; + APInt Rem = ExprVal->urem(DivisorVal); + if (Rem.isZero()) + return Expr; + // return the SCEV: Expr + Divisor - Expr % Divisor + return SE.getConstant(*ExprVal + DivisorVal - Rem); +} + void ScalarEvolution::LoopGuards::collectFromBlock( ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, const BasicBlock *Block, const BasicBlock *Pred, @@ -15557,51 +15566,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock( auto IsMinMaxSCEVWithNonNegativeConstant = [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, const SCEV *&RHS) { - if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) { - if (MinMax->getNumOperands() != 2) - return false; - if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) { - if (C->getAPInt().isNegative()) - return false; - SCTy = MinMax->getSCEVType(); - LHS = MinMax->getOperand(0); - RHS = MinMax->getOperand(1); - return true; - } - } - return false; + const APInt *C; + SCTy = Expr->getSCEVType(); + return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) && + match(LHS, m_scev_APInt(C)) && C->isNonNegative(); }; - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and greater or equal than Expr. For now, only handle constant - // Expr. - auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, - const APInt &DivisorVal) { - const APInt *ExprVal; - if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || - DivisorVal.isNonPositive()) - return Expr; - APInt Rem = ExprVal->urem(DivisorVal); - if (Rem.isZero()) - return Expr; - // return the SCEV: Expr + Divisor - Expr % Divisor - return SE.getConstant(*ExprVal + DivisorVal - Rem); - }; - - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and less or equal than Expr. For now, only handle constant - // Expr. - auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, - const APInt &DivisorVal) { - const APInt *ExprVal; - if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() || - DivisorVal.isNonPositive()) - return Expr; - APInt Rem = ExprVal->urem(DivisorVal); - // return the SCEV: Expr - Expr % Divisor - return SE.getConstant(*ExprVal - Rem); - }; - // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, // recursively. This is done by aligning up/down the constant value to the // Divisor. @@ -15623,8 +15593,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock( assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!"); auto *DivisibleExpr = - IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal) - : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal); + IsMin + ? 
getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE) + : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE); SmallVector<const SCEV *> Ops = { ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; return SE.getMinMaxExpr(SCTy, Ops); @@ -15701,21 +15672,21 @@ void ScalarEvolution::LoopGuards::collectFromBlock( [[fallthrough]]; case CmpInst::ICMP_SLT: { RHS = SE.getMinusSCEV(RHS, One); - RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); + RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; } case CmpInst::ICMP_UGT: case CmpInst::ICMP_SGT: RHS = SE.getAddExpr(RHS, One); - RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); + RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; case CmpInst::ICMP_ULE: case CmpInst::ICMP_SLE: - RHS = GetPreviousSCEVDividesByDivisor(RHS, DividesBy); + RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; case CmpInst::ICMP_UGE: case CmpInst::ICMP_SGE: - RHS = GetNextSCEVDividesByDivisor(RHS, DividesBy); + RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE); break; default: break; @@ -15769,22 +15740,29 @@ void ScalarEvolution::LoopGuards::collectFromBlock( case CmpInst::ICMP_NE: if (match(RHS, m_scev_Zero())) { const SCEV *OneAlignedUp = - GetNextSCEVDividesByDivisor(One, DividesBy); + getNextSCEVDivisibleByDivisor(One, DividesBy, SE); To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } else { + // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS), + // but creating the subtraction eagerly is expensive. Track the + // inequalities in a separate map, and materialize the rewrite lazily + // when encountering a suitable subtraction while re-writing. if (LHS->getType()->isPointerTy()) { LHS = SE.getLosslessPtrToIntExpr(LHS); RHS = SE.getLosslessPtrToIntExpr(RHS); if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS)) break; } - auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) { - const SCEV *Sub = SE.getMinusSCEV(A, B); - AddRewrite(Sub, Sub, - SE.getUMaxExpr(Sub, SE.getOne(From->getType()))); - }; - AddSubRewrite(LHS, RHS); - AddSubRewrite(RHS, LHS); + const SCEVConstant *C; + const SCEV *A, *B; + if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) && + match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) { + RHS = A; + LHS = B; + } + if (LHS > RHS) + std::swap(LHS, RHS); + Guards.NotEqual.insert({LHS, RHS}); continue; } break; @@ -15918,13 +15896,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> { const DenseMap<const SCEV *, const SCEV *> &Map; + const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual; SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap; public: SCEVLoopGuardRewriter(ScalarEvolution &SE, const ScalarEvolution::LoopGuards &Guards) - : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) { + : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap), + NotEqual(Guards.NotEqual) { if (Guards.PreserveNUW) FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW); if (Guards.PreserveNSW) @@ -15979,14 +15959,39 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + // Helper to check if S is a subtraction (A - B) where A != B, and if so, + // return UMax(S, 1).
+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * { + const SCEV *LHS, *RHS; + if (MatchBinarySub(S, LHS, RHS)) { + if (LHS > RHS) + std::swap(LHS, RHS); + if (NotEqual.contains({LHS, RHS})) { + const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor( + SE.getOne(S->getType()), SE.getConstantMultiple(S), SE); + return SE.getUMaxExpr(OneAlignedUp, S); + } + } + return nullptr; + }; + + // Check if Expr itself is a subtraction pattern with guard info. + if (const SCEV *Rewritten = RewriteSubtraction(Expr)) + return Rewritten; + // Trip count expressions sometimes consist of adding 3 operands, i.e. // (Const + A + B). There may be guard info for A + B, and if so, apply // it. // TODO: Could more generally apply guards to Add sub-expressions. if (isa<SCEVConstant>(Expr->getOperand(0)) && Expr->getNumOperands() == 3) { - if (const SCEV *S = Map.lookup( - SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)))) + const SCEV *Add = + SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)); + if (const SCEV *Rewritten = RewriteSubtraction(Add)) + return SE.getAddExpr( + Expr->getOperand(0), Rewritten, + ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask)); + if (const SCEV *S = Map.lookup(Add)) return SE.getAddExpr(Expr->getOperand(0), S); } SmallVector<const SCEV *, 2> Operands; @@ -16021,7 +16026,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } }; - if (RewriteMap.empty()) + if (RewriteMap.empty() && NotEqual.empty()) return Expr; SCEVLoopGuardRewriter Rewriter(SE, *this); diff --git a/llvm/lib/Analysis/StaticDataProfileInfo.cpp b/llvm/lib/Analysis/StaticDataProfileInfo.cpp index e7f0b2c..61d4935 100644 --- a/llvm/lib/Analysis/StaticDataProfileInfo.cpp +++ b/llvm/lib/Analysis/StaticDataProfileInfo.cpp @@ -1,10 +1,14 @@ #include "llvm/Analysis/StaticDataProfileInfo.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/ProfileData/InstrProf.h" +#define DEBUG_TYPE "static-data-profile-info" + using namespace llvm; namespace llvm { @@ -79,6 +83,17 @@ StaticDataProfileInfo::getConstantHotnessUsingProfileCount( return StaticDataHotness::LukewarmOrUnknown; } +StaticDataProfileInfo::StaticDataHotness +StaticDataProfileInfo::getSectionHotnessUsingDataAccessProfile( + std::optional<StringRef> MaybeSectionPrefix) const { + if (!MaybeSectionPrefix) + return StaticDataHotness::LukewarmOrUnknown; + StringRef Prefix = *MaybeSectionPrefix; + assert((Prefix == "hot" || Prefix == "unlikely") && + "Expect section_prefix to be one of hot or unlikely"); + return Prefix == "hot" ? StaticDataHotness::Hot : StaticDataHotness::Cold; +} + StringRef StaticDataProfileInfo::hotnessToStr(StaticDataHotness Hotness) const { switch (Hotness) { case StaticDataHotness::Cold: @@ -101,13 +116,66 @@ StaticDataProfileInfo::getConstantProfileCount(const Constant *C) const { StringRef StaticDataProfileInfo::getConstantSectionPrefix( const Constant *C, const ProfileSummaryInfo *PSI) const { std::optional<uint64_t> Count = getConstantProfileCount(C); + +#ifndef NDEBUG + auto DbgPrintPrefix = [](StringRef Prefix) { + return Prefix.empty() ? "<empty>" : Prefix; + }; +#endif + + if (EnableDataAccessProf) { + // Module flag `HasDataAccessProf` is 1 -> empty section prefix means + // unknown hotness except for string literals. 
+ if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C); + GV && llvm::memprof::IsAnnotationOK(*GV) && + !GV->getName().starts_with(".str")) { + auto HotnessFromDataAccessProf = + getSectionHotnessUsingDataAccessProfile(GV->getSectionPrefix()); + + if (!Count) { + StringRef Prefix = hotnessToStr(HotnessFromDataAccessProf); + LLVM_DEBUG(dbgs() << GV->getName() << " has section prefix " + << DbgPrintPrefix(Prefix) + << ", solely from data access profiles\n"); + return Prefix; + } + + // Both data access profiles and PGO counters are available. Use the + // hotter one. + auto HotnessFromPGO = getConstantHotnessUsingProfileCount(C, PSI, *Count); + StaticDataHotness GlobalVarHotness = StaticDataHotness::LukewarmOrUnknown; + if (HotnessFromDataAccessProf == StaticDataHotness::Hot || + HotnessFromPGO == StaticDataHotness::Hot) { + GlobalVarHotness = StaticDataHotness::Hot; + } else if (HotnessFromDataAccessProf == + StaticDataHotness::LukewarmOrUnknown || + HotnessFromPGO == StaticDataHotness::LukewarmOrUnknown) { + GlobalVarHotness = StaticDataHotness::LukewarmOrUnknown; + } else { + GlobalVarHotness = StaticDataHotness::Cold; + } + StringRef Prefix = hotnessToStr(GlobalVarHotness); + LLVM_DEBUG( + dbgs() << GV->getName() << " has section prefix " + << DbgPrintPrefix(Prefix) + << ", the max from data access profiles as " + << DbgPrintPrefix(hotnessToStr(HotnessFromDataAccessProf)) + << " and PGO counters as " + << DbgPrintPrefix(hotnessToStr(HotnessFromPGO)) << "\n"); + return Prefix; + } + } if (!Count) return ""; return hotnessToStr(getConstantHotnessUsingProfileCount(C, PSI, *Count)); } bool StaticDataProfileInfoWrapperPass::doInitialization(Module &M) { - Info.reset(new StaticDataProfileInfo()); + bool EnableDataAccessProf = false; + if (auto *MD = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("EnableDataAccessProf"))) + EnableDataAccessProf = MD->getZExtValue(); + Info.reset(new StaticDataProfileInfo(EnableDataAccessProf)); return false; } diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 9655c88..0a72076 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -7695,6 +7695,11 @@ static bool isGuaranteedNotToBeUndefOrPoison( } if (IsWellDefined) return true; + } else if (auto *Splat = isa<ShuffleVectorInst>(Opr) ? getSplatValue(Opr) + : nullptr) { + // For splats we only need to check the value being splatted. + if (OpCheck(Splat)) + return true; } else if (all_of(Opr->operands(), OpCheck)) return true; } |
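The ScalarEvolution changes above consistently replace hand-rolled dyn_cast chains and operand-count checks with SCEV pattern matchers (m_scev_Add, m_scev_Mul, m_scev_APInt, m_scev_AllOnes, ...). A minimal sketch of that style, assuming the combinators come from llvm/Analysis/ScalarEvolutionPatternMatch.h and mirror the IR-level PatternMatch API; this is a standalone example, not part of the patch:

    #include "llvm/Analysis/ScalarEvolution.h"
    #include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
    using namespace llvm;
    using namespace llvm::SCEVPatternMatch;

    // Recognize Expr == -1 * X, as the rewritten MatchNotExpr does, without
    // manually checking that the SCEVMulExpr has exactly two operands and
    // that operand 0 is the all-ones constant.
    static const SCEV *matchNegation(const SCEV *Expr) {
      const SCEV *X;
      if (match(Expr, m_scev_Mul(m_scev_AllOnes(), m_SCEV(X))))
        return X;
      return nullptr;
    }

The combinators nest, so a compound pattern such as ~A = (-1 + (-1 * A)) collapses into a single match call, which is exactly how the new MatchNotExpr in this patch is written.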