diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index a5fbdb5..5f54685 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1534,3 +1534,152 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, return &New; } + +/// IR Values for the lower and upper bounds of a pointer evolution. We +/// need to use value-handles because SCEV expansion can invalidate previously +/// expanded values. Thus expansion of a pointer can invalidate the bounds for +/// a previous one. +struct PointerBounds { + TrackingVH<Value> Start; + TrackingVH<Value> End; +}; + +/// Expand code for the lower and upper bound of the pointer group \p CG +/// in \p TheLoop. \return the values for the bounds. +static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, + Loop *TheLoop, Instruction *Loc, + SCEVExpander &Exp, ScalarEvolution *SE) { + // TODO: Add helper to retrieve pointers to CG. + Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue; + const SCEV *Sc = SE->getSCEV(Ptr); + + unsigned AS = Ptr->getType()->getPointerAddressSpace(); + LLVMContext &Ctx = Loc->getContext(); + + // Use this type for pointer arithmetic. + Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); + + if (SE->isLoopInvariant(Sc, TheLoop)) { + LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" + << *Ptr << "\n"); + // Ptr could be in the loop body. If so, expand a new one at the correct + // location. + Instruction *Inst = dyn_cast<Instruction>(Ptr); + Value *NewPtr = (Inst && TheLoop->contains(Inst)) + ? Exp.expandCodeFor(Sc, PtrArithTy, Loc) + : Ptr; + // We must return a half-open range, which means incrementing Sc. + const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy)); + Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc); + return {NewPtr, NewPtrPlusOne}; + } else { + Value *Start = nullptr, *End = nullptr; + LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); + Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); + End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); + LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High + << "\n"); + return {Start, End}; + } +} + +/// Turns a collection of checks into a collection of expanded upper and +/// lower bounds for both pointers in the check. +static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> +expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, + Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) { + SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds; + + // Here we're relying on the SCEV Expander's cache to only emit code for the + // same bounds once. + transform(PointerChecks, std::back_inserter(ChecksWithBounds), + [&](const RuntimePointerCheck &Check) { + PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE), + Second = + expandBounds(Check.second, L, Loc, Exp, SE); + return std::make_pair(First, Second); + }); + + return ChecksWithBounds; +} + +std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks( + Instruction *Loc, Loop *TheLoop, + const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, + ScalarEvolution *SE) { + // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible. + // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible + const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + SCEVExpander Exp(*SE, DL, "induction"); + auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp); + + LLVMContext &Ctx = Loc->getContext(); + Instruction *FirstInst = nullptr; + IRBuilder<> ChkBuilder(Loc); + // Our instructions might fold to a constant. + Value *MemoryRuntimeCheck = nullptr; + + // FIXME: this helper is currently a duplicate of the one in + // LoopVectorize.cpp. + auto GetFirstInst = [](Instruction *FirstInst, Value *V, + Instruction *Loc) -> Instruction * { + if (FirstInst) + return FirstInst; + if (Instruction *I = dyn_cast<Instruction>(V)) + return I->getParent() == Loc->getParent() ? I : nullptr; + return nullptr; + }; + + for (const auto &Check : ExpandedChecks) { + const PointerBounds &A = Check.first, &B = Check.second; + // Check if two pointers (A and B) conflict where conflict is computed as: + // start(A) <= end(B) && start(B) <= end(A) + unsigned AS0 = A.Start->getType()->getPointerAddressSpace(); + unsigned AS1 = B.Start->getType()->getPointerAddressSpace(); + + assert((AS0 == B.End->getType()->getPointerAddressSpace()) && + (AS1 == A.End->getType()->getPointerAddressSpace()) && + "Trying to bounds check pointers with different address spaces"); + + Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); + Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); + + Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc"); + Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc"); + Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc"); + Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc"); + + // [A|B].Start points to the first accessed byte under base [A|B]. + // [A|B].End points to the last accessed byte, plus one. + // There is no conflict when the intervals are disjoint: + // NoConflict = (B.Start >= A.End) || (A.Start >= B.End) + // + // bound0 = (B.Start < A.End) + // bound1 = (A.Start < B.End) + // IsConflict = bound0 & bound1 + Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0"); + FirstInst = GetFirstInst(FirstInst, Cmp0, Loc); + Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1"); + FirstInst = GetFirstInst(FirstInst, Cmp1, Loc); + Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); + FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); + if (MemoryRuntimeCheck) { + IsConflict = + ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); + FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); + } + MemoryRuntimeCheck = IsConflict; + } + + if (!MemoryRuntimeCheck) + return std::make_pair(nullptr, nullptr); + + // We have to do this trickery because the IRBuilder might fold the check to a + // constant expression in which case there is no Instruction anchored in a + // the block. + Instruction *Check = + BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx)); + ChkBuilder.Insert(Check, "memcheck.conflict"); + FirstInst = GetFirstInst(FirstInst, Check, Loc); + return std::make_pair(FirstInst, Check); +} |