aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp83
1 files changed, 73 insertions, 10 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 6e0c195..1b2b371 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1628,42 +1628,92 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
struct PointerBounds {
TrackingVH<Value> Start;
TrackingVH<Value> End;
+ Value *StrideToCheck;
};
/// Expand code for the lower and upper bound of the pointer group \p CG
/// in \p TheLoop. \return the values for the bounds.
static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
Loop *TheLoop, Instruction *Loc,
- SCEVExpander &Exp) {
+ SCEVExpander &Exp, bool HoistRuntimeChecks) {
LLVMContext &Ctx = Loc->getContext();
Type *PtrArithTy = PointerType::get(Ctx, CG->AddressSpace);
Value *Start = nullptr, *End = nullptr;
LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
- Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
- End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+ const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr;
+
+ // If the Low and High values are themselves loop-variant, then we may want
+ // to expand the range to include those covered by the outer loop as well.
+ // There is a trade-off here with the advantage being that creating checks
+ // using the expanded range permits the runtime memory checks to be hoisted
+ // out of the outer loop. This reduces the cost of entering the inner loop,
+ // which can be significant for low trip counts. The disadvantage is that
+ // there is a chance we may now never enter the vectorized inner loop,
+ // whereas using a restricted range check could have allowed us to enter at
+ // least once. This is why the behaviour is not currently the default and is
+ // controlled by the parameter 'HoistRuntimeChecks'.
+ if (HoistRuntimeChecks && TheLoop->getParentLoop() &&
+ isa<SCEVAddRecExpr>(High) && isa<SCEVAddRecExpr>(Low)) {
+ auto *HighAR = cast<SCEVAddRecExpr>(High);
+ auto *LowAR = cast<SCEVAddRecExpr>(Low);
+ const Loop *OuterLoop = TheLoop->getParentLoop();
+ const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE());
+ if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) &&
+ HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
+ BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
+ const SCEV *OuterExitCount =
+ Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch);
+ if (!isa<SCEVCouldNotCompute>(OuterExitCount) &&
+ OuterExitCount->getType()->isIntegerTy()) {
+ const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration(
+ OuterExitCount, *Exp.getSE());
+ if (!isa<SCEVCouldNotCompute>(NewHigh)) {
+ LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
+ "outer loop in order to permit hoisting\n");
+ High = NewHigh;
+ Low = cast<SCEVAddRecExpr>(Low)->getStart();
+ // If there is a possibility that the stride is negative then we have
+ // to generate extra checks to ensure the stride is positive.
+ if (!Exp.getSE()->isKnownNonNegative(Recur)) {
+ Stride = Recur;
+ LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
+ "positive: "
+ << *Stride << '\n');
+ }
+ }
+ }
+ }
+ }
+
+ Start = Exp.expandCodeFor(Low, PtrArithTy, Loc);
+ End = Exp.expandCodeFor(High, PtrArithTy, Loc);
if (CG->NeedsFreeze) {
IRBuilder<> Builder(Loc);
Start = Builder.CreateFreeze(Start, Start->getName() + ".fr");
End = Builder.CreateFreeze(End, End->getName() + ".fr");
}
- LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n");
- return {Start, End};
+ Value *StrideVal =
+ Stride ? Exp.expandCodeFor(Stride, Type::getInt64Ty(Ctx), Loc) : nullptr;
+ LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n");
+ return {Start, End, StrideVal};
}
/// Turns a collection of checks into a collection of expanded upper and
/// lower bounds for both pointers in the check.
static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
- Instruction *Loc, SCEVExpander &Exp) {
+ Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) {
SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
// Here we're relying on the SCEV Expander's cache to only emit code for the
// same bounds once.
transform(PointerChecks, std::back_inserter(ChecksWithBounds),
[&](const RuntimePointerCheck &Check) {
- PointerBounds First = expandBounds(Check.first, L, Loc, Exp),
- Second = expandBounds(Check.second, L, Loc, Exp);
+ PointerBounds First = expandBounds(Check.first, L, Loc, Exp,
+ HoistRuntimeChecks),
+ Second = expandBounds(Check.second, L, Loc, Exp,
+ HoistRuntimeChecks);
return std::make_pair(First, Second);
});
@@ -1673,10 +1723,11 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
Value *llvm::addRuntimeChecks(
Instruction *Loc, Loop *TheLoop,
const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
- SCEVExpander &Exp) {
+ SCEVExpander &Exp, bool HoistRuntimeChecks) {
// TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
// TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
- auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp);
+ auto ExpandedChecks =
+ expandBounds(PointerChecks, TheLoop, Loc, Exp, HoistRuntimeChecks);
LLVMContext &Ctx = Loc->getContext();
IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx,
@@ -1707,6 +1758,18 @@ Value *llvm::addRuntimeChecks(
Value *Cmp0 = ChkBuilder.CreateICmpULT(A.Start, B.End, "bound0");
Value *Cmp1 = ChkBuilder.CreateICmpULT(B.Start, A.End, "bound1");
Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+ if (A.StrideToCheck) {
+ Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
+ A.StrideToCheck, ConstantInt::get(A.StrideToCheck->getType(), 0),
+ "stride.check");
+ IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride);
+ }
+ if (B.StrideToCheck) {
+ Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
+ B.StrideToCheck, ConstantInt::get(B.StrideToCheck->getType(), 0),
+ "stride.check");
+ IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride);
+ }
if (MemoryRuntimeCheck) {
IsConflict =
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");