diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 83 |
1 files changed, 73 insertions, 10 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 6e0c195..1b2b371 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1628,42 +1628,92 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, struct PointerBounds { TrackingVH<Value> Start; TrackingVH<Value> End; + Value *StrideToCheck; }; /// Expand code for the lower and upper bound of the pointer group \p CG /// in \p TheLoop. \return the values for the bounds. static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, Loop *TheLoop, Instruction *Loc, - SCEVExpander &Exp) { + SCEVExpander &Exp, bool HoistRuntimeChecks) { LLVMContext &Ctx = Loc->getContext(); Type *PtrArithTy = PointerType::get(Ctx, CG->AddressSpace); Value *Start = nullptr, *End = nullptr; LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); - Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); - End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); + const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr; + + // If the Low and High values are themselves loop-variant, then we may want + // to expand the range to include those covered by the outer loop as well. + // There is a trade-off here with the advantage being that creating checks + // using the expanded range permits the runtime memory checks to be hoisted + // out of the outer loop. This reduces the cost of entering the inner loop, + // which can be significant for low trip counts. The disadvantage is that + // there is a chance we may now never enter the vectorized inner loop, + // whereas using a restricted range check could have allowed us to enter at + // least once. This is why the behaviour is not currently the default and is + // controlled by the parameter 'HoistRuntimeChecks'. + if (HoistRuntimeChecks && TheLoop->getParentLoop() && + isa<SCEVAddRecExpr>(High) && isa<SCEVAddRecExpr>(Low)) { + auto *HighAR = cast<SCEVAddRecExpr>(High); + auto *LowAR = cast<SCEVAddRecExpr>(Low); + const Loop *OuterLoop = TheLoop->getParentLoop(); + const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE()); + if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) && + HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) { + BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); + const SCEV *OuterExitCount = + Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch); + if (!isa<SCEVCouldNotCompute>(OuterExitCount) && + OuterExitCount->getType()->isIntegerTy()) { + const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration( + OuterExitCount, *Exp.getSE()); + if (!isa<SCEVCouldNotCompute>(NewHigh)) { + LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include " + "outer loop in order to permit hoisting\n"); + High = NewHigh; + Low = cast<SCEVAddRecExpr>(Low)->getStart(); + // If there is a possibility that the stride is negative then we have + // to generate extra checks to ensure the stride is positive. + if (!Exp.getSE()->isKnownNonNegative(Recur)) { + Stride = Recur; + LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is " + "positive: " + << *Stride << '\n'); + } + } + } + } + } + + Start = Exp.expandCodeFor(Low, PtrArithTy, Loc); + End = Exp.expandCodeFor(High, PtrArithTy, Loc); if (CG->NeedsFreeze) { IRBuilder<> Builder(Loc); Start = Builder.CreateFreeze(Start, Start->getName() + ".fr"); End = Builder.CreateFreeze(End, End->getName() + ".fr"); } - LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n"); - return {Start, End}; + Value *StrideVal = + Stride ? Exp.expandCodeFor(Stride, Type::getInt64Ty(Ctx), Loc) : nullptr; + LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n"); + return {Start, End, StrideVal}; } /// Turns a collection of checks into a collection of expanded upper and /// lower bounds for both pointers in the check. static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, - Instruction *Loc, SCEVExpander &Exp) { + Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) { SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds; // Here we're relying on the SCEV Expander's cache to only emit code for the // same bounds once. transform(PointerChecks, std::back_inserter(ChecksWithBounds), [&](const RuntimePointerCheck &Check) { - PointerBounds First = expandBounds(Check.first, L, Loc, Exp), - Second = expandBounds(Check.second, L, Loc, Exp); + PointerBounds First = expandBounds(Check.first, L, Loc, Exp, + HoistRuntimeChecks), + Second = expandBounds(Check.second, L, Loc, Exp, + HoistRuntimeChecks); return std::make_pair(First, Second); }); @@ -1673,10 +1723,11 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, Value *llvm::addRuntimeChecks( Instruction *Loc, Loop *TheLoop, const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, - SCEVExpander &Exp) { + SCEVExpander &Exp, bool HoistRuntimeChecks) { // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible. // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible - auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp); + auto ExpandedChecks = + expandBounds(PointerChecks, TheLoop, Loc, Exp, HoistRuntimeChecks); LLVMContext &Ctx = Loc->getContext(); IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx, @@ -1707,6 +1758,18 @@ Value *llvm::addRuntimeChecks( Value *Cmp0 = ChkBuilder.CreateICmpULT(A.Start, B.End, "bound0"); Value *Cmp1 = ChkBuilder.CreateICmpULT(B.Start, A.End, "bound1"); Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); + if (A.StrideToCheck) { + Value *IsNegativeStride = ChkBuilder.CreateICmpSLT( + A.StrideToCheck, ConstantInt::get(A.StrideToCheck->getType(), 0), + "stride.check"); + IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride); + } + if (B.StrideToCheck) { + Value *IsNegativeStride = ChkBuilder.CreateICmpSLT( + B.StrideToCheck, ConstantInt::get(B.StrideToCheck->getType(), 0), + "stride.check"); + IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride); + } if (MemoryRuntimeCheck) { IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); |