aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp149
1 files changed, 149 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index a5fbdb5..5f54685 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1534,3 +1534,152 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
return &New;
}
+
+/// IR Values for the lower and upper bounds of a pointer evolution. We
+/// need to use value-handles because SCEV expansion can invalidate previously
+/// expanded values. Thus expansion of a pointer can invalidate the bounds for
+/// a previous one.
+struct PointerBounds {
+ TrackingVH<Value> Start;
+ TrackingVH<Value> End;
+};
+
+/// Expand code for the lower and upper bound of the pointer group \p CG
+/// in \p TheLoop. \return the values for the bounds.
+static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
+ Loop *TheLoop, Instruction *Loc,
+ SCEVExpander &Exp, ScalarEvolution *SE) {
+ // TODO: Add helper to retrieve pointers to CG.
+ Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue;
+ const SCEV *Sc = SE->getSCEV(Ptr);
+
+ unsigned AS = Ptr->getType()->getPointerAddressSpace();
+ LLVMContext &Ctx = Loc->getContext();
+
+ // Use this type for pointer arithmetic.
+ Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+
+ if (SE->isLoopInvariant(Sc, TheLoop)) {
+ LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:"
+ << *Ptr << "\n");
+ // Ptr could be in the loop body. If so, expand a new one at the correct
+ // location.
+ Instruction *Inst = dyn_cast<Instruction>(Ptr);
+ Value *NewPtr = (Inst && TheLoop->contains(Inst))
+ ? Exp.expandCodeFor(Sc, PtrArithTy, Loc)
+ : Ptr;
+ // We must return a half-open range, which means incrementing Sc.
+ const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy));
+ Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc);
+ return {NewPtr, NewPtrPlusOne};
+ } else {
+ Value *Start = nullptr, *End = nullptr;
+ LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
+ Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
+ End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+ LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High
+ << "\n");
+ return {Start, End};
+ }
+}
+
+/// Turns a collection of checks into a collection of expanded upper and
+/// lower bounds for both pointers in the check.
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
+expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
+ Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) {
+ SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
+
+ // Here we're relying on the SCEV Expander's cache to only emit code for the
+ // same bounds once.
+ transform(PointerChecks, std::back_inserter(ChecksWithBounds),
+ [&](const RuntimePointerCheck &Check) {
+ PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE),
+ Second =
+ expandBounds(Check.second, L, Loc, Exp, SE);
+ return std::make_pair(First, Second);
+ });
+
+ return ChecksWithBounds;
+}
+
+std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks(
+ Instruction *Loc, Loop *TheLoop,
+ const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
+ ScalarEvolution *SE) {
+ // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
+ // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
+ const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+ SCEVExpander Exp(*SE, DL, "induction");
+ auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp);
+
+ LLVMContext &Ctx = Loc->getContext();
+ Instruction *FirstInst = nullptr;
+ IRBuilder<> ChkBuilder(Loc);
+ // Our instructions might fold to a constant.
+ Value *MemoryRuntimeCheck = nullptr;
+
+ // FIXME: this helper is currently a duplicate of the one in
+ // LoopVectorize.cpp.
+ auto GetFirstInst = [](Instruction *FirstInst, Value *V,
+ Instruction *Loc) -> Instruction * {
+ if (FirstInst)
+ return FirstInst;
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ return I->getParent() == Loc->getParent() ? I : nullptr;
+ return nullptr;
+ };
+
+ for (const auto &Check : ExpandedChecks) {
+ const PointerBounds &A = Check.first, &B = Check.second;
+ // Check if two pointers (A and B) conflict where conflict is computed as:
+ // start(A) <= end(B) && start(B) <= end(A)
+ unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
+ unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
+
+ assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
+ (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+ "Trying to bounds check pointers with different address spaces");
+
+ Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+ Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+
+ Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
+ Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
+ Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
+ Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
+
+ // [A|B].Start points to the first accessed byte under base [A|B].
+ // [A|B].End points to the last accessed byte, plus one.
+ // There is no conflict when the intervals are disjoint:
+ // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
+ //
+ // bound0 = (B.Start < A.End)
+ // bound1 = (A.Start < B.End)
+ // IsConflict = bound0 & bound1
+ Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
+ FirstInst = GetFirstInst(FirstInst, Cmp0, Loc);
+ Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+ FirstInst = GetFirstInst(FirstInst, Cmp1, Loc);
+ Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+ FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+ if (MemoryRuntimeCheck) {
+ IsConflict =
+ ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
+ FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+ }
+ MemoryRuntimeCheck = IsConflict;
+ }
+
+ if (!MemoryRuntimeCheck)
+ return std::make_pair(nullptr, nullptr);
+
+ // We have to do this trickery because the IRBuilder might fold the check to a
+ // constant expression in which case there is no Instruction anchored in a
+ // the block.
+ Instruction *Check =
+ BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx));
+ ChkBuilder.Insert(Check, "memcheck.conflict");
+ FirstInst = GetFirstInst(FirstInst, Check, Loc);
+ return std::make_pair(FirstInst, Check);
+}