diff options
author | Philip Reames <preames@rivosinc.com> | 2023-05-11 08:10:49 -0700 |
---|---|---|
committer | Philip Reames <listmail@philipreames.com> | 2023-05-11 08:32:56 -0700 |
commit | d5b840131223f2ffef4e48ca769ad1eb7bb1869a (patch) | |
tree | 58e5c8eaaf69561cc630c0bf830ed39df4f7bad1 /llvm/lib/Analysis/LoopAccessAnalysis.cpp | |
parent | 28dc5f4cdd5e552c87ec72d39bd0f9d7378ab747 (diff) | |
download | llvm-d5b840131223f2ffef4e48ca769ad1eb7bb1869a.zip llvm-d5b840131223f2ffef4e48ca769ad1eb7bb1869a.tar.gz llvm-d5b840131223f2ffef4e48ca769ad1eb7bb1869a.tar.bz2 |
[LAA/LV] Simplify stride speculation logic [NFC]
The existing code makes it hard to tell that collectStridedAccess is really about identifying some loop invariant SCEV which is *profitable* to speculate is equal to one. The odd dual usage structure of Value and SCEV confuses this point.
We could choose to loosen the profitability analysis if desired. I'm not proposing doing so at this time as it exposes too many cases where the speculation is unprofitable.
Differential Revision: https://reviews.llvm.org/D147750
Diffstat (limited to 'llvm/lib/Analysis/LoopAccessAnalysis.cpp')
-rw-r--r-- | llvm/lib/Analysis/LoopAccessAnalysis.cpp | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 351e090..7ac51c3 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -154,21 +154,19 @@ Value *llvm::stripIntegerCast(Value *V) { } const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE, - const ValueToValueMap &PtrToStride, + const DenseMap<Value *, const SCEV *> &PtrToStride, Value *Ptr) { const SCEV *OrigSCEV = PSE.getSCEV(Ptr); // If there is an entry in the map return the SCEV of the pointer with the // symbolic stride replaced by one. - ValueToValueMap::const_iterator SI = PtrToStride.find(Ptr); + DenseMap<Value *, const SCEV *>::const_iterator SI = PtrToStride.find(Ptr); if (SI == PtrToStride.end()) // For a non-symbolic stride, just return the original expression. return OrigSCEV; - Value *StrideVal = stripIntegerCast(SI->second); - ScalarEvolution *SE = PSE.getSE(); - const SCEV *StrideSCEV = SE->getSCEV(StrideVal); + const SCEV *StrideSCEV = SI->second; assert(isa<SCEVUnknown>(StrideSCEV) && "shouldn't be in map"); const auto *CT = SE->getOne(StrideSCEV->getType()); @@ -658,7 +656,7 @@ public: /// the bounds of the pointer. bool createCheckForAccess(RuntimePointerChecking &RtCheck, MemAccessInfo Access, Type *AccessTy, - const ValueToValueMap &Strides, + const DenseMap<Value *, const SCEV *> &Strides, DenseMap<Value *, unsigned> &DepSetId, Loop *TheLoop, unsigned &RunningDepId, unsigned ASId, bool ShouldCheckStride, bool Assume); @@ -669,7 +667,7 @@ public: /// Returns true if we need no check or if we do and we can generate them /// (i.e. the pointers have computable bounds). bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, - Loop *TheLoop, const ValueToValueMap &Strides, + Loop *TheLoop, const DenseMap<Value *, const SCEV *> &Strides, Value *&UncomputablePtr, bool ShouldCheckWrap = false); /// Goes over all memory accesses, checks whether a RT check is needed @@ -764,7 +762,7 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr, /// Check whether a pointer address cannot wrap. static bool isNoWrap(PredicatedScalarEvolution &PSE, - const ValueToValueMap &Strides, Value *Ptr, Type *AccessTy, + const DenseMap<Value *, const SCEV *> &Strides, Value *Ptr, Type *AccessTy, Loop *L) { const SCEV *PtrScev = PSE.getSCEV(Ptr); if (PSE.getSE()->isLoopInvariant(PtrScev, L)) @@ -957,7 +955,7 @@ static void findForkedSCEVs( static SmallVector<PointerIntPair<const SCEV *, 1, bool>> findForkedPointer(PredicatedScalarEvolution &PSE, - const ValueToValueMap &StridesMap, Value *Ptr, + const DenseMap<Value *, const SCEV *> &StridesMap, Value *Ptr, const Loop *L) { ScalarEvolution *SE = PSE.getSE(); assert(SE->isSCEVable(Ptr->getType()) && "Value is not SCEVable!"); @@ -982,7 +980,7 @@ findForkedPointer(PredicatedScalarEvolution &PSE, bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck, MemAccessInfo Access, Type *AccessTy, - const ValueToValueMap &StridesMap, + const DenseMap<Value *, const SCEV *> &StridesMap, DenseMap<Value *, unsigned> &DepSetId, Loop *TheLoop, unsigned &RunningDepId, unsigned ASId, bool ShouldCheckWrap, @@ -1043,7 +1041,7 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck, bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, Loop *TheLoop, - const ValueToValueMap &StridesMap, + const DenseMap<Value *, const SCEV *> &StridesMap, Value *&UncomputablePtr, bool ShouldCheckWrap) { // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. @@ -1373,7 +1371,7 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, std::optional<int64_t> llvm::getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr, const Loop *Lp, - const ValueToValueMap &StridesMap, + const DenseMap<Value *, const SCEV *> &StridesMap, bool Assume, bool ShouldCheckWrap) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); @@ -1822,7 +1820,7 @@ static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride, MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, const MemAccessInfo &B, unsigned BIdx, - const ValueToValueMap &Strides) { + const DenseMap<Value *, const SCEV *> &Strides) { assert (AIdx < BIdx && "Must pass arguments in program order"); auto [APtr, AIsWrite] = A; @@ -2016,7 +2014,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps, - const ValueToValueMap &Strides) { + const DenseMap<Value *, const SCEV *> &Strides) { MaxSafeDepDistBytes = -1; SmallPtrSet<MemAccessInfo, 8> Visited; @@ -2691,6 +2689,12 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { if (!Ptr) return; + // Note: getStrideFromPointer is a *profitability* heuristic. We + // could broaden the scope of values returned here - to anything + // which happens to be loop invariant and contributes to the + // computation of an interesting IV - but we chose not to as we + // don't have a cost model here, and broadening the scope exposes + // far too many unprofitable cases. Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop); if (!Stride) return; @@ -2746,7 +2750,7 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { } LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n"); - SymbolicStrides[Ptr] = Stride; + SymbolicStrides[Ptr] = StrideExpr; StrideSet.insert(Stride); } |