diff options
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | llvm/lib/Analysis/ScalarEvolution.cpp | 66 |
1 files changed, 43 insertions, 23 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index f997b19..b074294 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -271,6 +271,9 @@ void SCEV::print(raw_ostream &OS) const { case scConstant: cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false); return; + case scVScale: + OS << "vscale"; + return; case scPtrToInt: { const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this); const SCEV *Op = PtrToInt->getOperand(); @@ -366,17 +369,9 @@ void SCEV::print(raw_ostream &OS) const { OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; return; } - case scUnknown: { - const SCEVUnknown *U = cast<SCEVUnknown>(this); - if (U->isVScale()) { - OS << "vscale"; - return; - } - - // Otherwise just print it normally. - U->getValue()->printAsOperand(OS, false); + case scUnknown: + cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false); return; - } case scCouldNotCompute: OS << "***COULDNOTCOMPUTE***"; return; @@ -388,6 +383,8 @@ Type *SCEV::getType() const { switch (getSCEVType()) { case scConstant: return cast<SCEVConstant>(this)->getType(); + case scVScale: + return cast<SCEVVScale>(this)->getType(); case scPtrToInt: case scTruncate: case scZeroExtend: @@ -419,6 +416,7 @@ Type *SCEV::getType() const { ArrayRef<const SCEV *> SCEV::operands() const { switch (getSCEVType()) { case scConstant: + case scVScale: case scUnknown: return {}; case scPtrToInt: @@ -501,6 +499,18 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { return getConstant(ConstantInt::get(ITy, V, isSigned)); } +const SCEV *ScalarEvolution::getVScale(Type *Ty) { + FoldingSetNodeID ID; + ID.AddInteger(scVScale); + ID.AddPointer(Ty); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; + SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; +} + SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} @@ -560,10 +570,6 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) { setValPtr(New); } -bool SCEVUnknown::isVScale() const { - return match(getValue(), m_VScale()); -} - //===----------------------------------------------------------------------===// // SCEV Utilities //===----------------------------------------------------------------------===// @@ -714,6 +720,12 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV, return LA.ult(RA) ? -1 : 1; } + case scVScale: { + const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType()); + const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType()); + return LTy->getBitWidth() - RTy->getBitWidth(); + } + case scAddRecExpr: { const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); @@ -4015,6 +4027,8 @@ public: RetVal visitConstant(const SCEVConstant *Constant) { return Constant; } + RetVal visitVScale(const SCEVVScale *VScale) { return VScale; } + RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; } RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; } @@ -4061,6 +4075,7 @@ public: static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) { switch (Kind) { case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -4104,6 +4119,7 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { if (!scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) { switch (S->getSCEVType()) { case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -4315,15 +4331,8 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEV * ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue()); - if (Size.isScalable()) { - // TODO: Why is there no ConstantExpr::getVScale()? - Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1); - Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo()); - Constant *One = ConstantInt::get(IntTy, 1); - Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One); - Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy); - Res = getMulExpr(Res, getUnknown(VScale)); - } + if (Size.isScalable()) + Res = getMulExpr(Res, getVScale(IntTy)); return Res; } @@ -5887,6 +5896,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, bool follow(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: + case scVScale: case scPtrToInt: case scTruncate: case scZeroExtend: @@ -6274,6 +6284,8 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: return cast<SCEVConstant>(S)->getAPInt().countr_zero(); + case scVScale: + return 0; case scTruncate: { const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S); return std::min(GetMinTrailingZeros(T->getOperand()), @@ -6504,6 +6516,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, break; [[fallthrough]]; case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -6607,6 +6620,8 @@ const ConstantRange &ScalarEvolution::getRangeRef( switch (S->getSCEVType()) { case scConstant: llvm_unreachable("Already handled above."); + case scVScale: + return setRange(S, SignHint, std::move(ConservativeResult)); case scTruncate: { const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S); ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1); @@ -9711,6 +9726,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { switch (V->getSCEVType()) { case scCouldNotCompute: case scAddRecExpr: + case scVScale: return nullptr; case scConstant: return cast<SCEVConstant>(V)->getValue(); @@ -9794,6 +9810,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { switch (V->getSCEVType()) { case scConstant: + case scVScale: return V; case scAddRecExpr: { // If this is a loop recurrence for a loop that does not contain L, then we @@ -9892,6 +9909,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { case scSequentialUMinExpr: return getSequentialMinMaxExpr(V->getSCEVType(), NewOps); case scConstant: + case scVScale: case scAddRecExpr: case scUnknown: case scCouldNotCompute: @@ -13677,6 +13695,7 @@ ScalarEvolution::LoopDisposition ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return LoopInvariant; case scAddRecExpr: { const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); @@ -13775,6 +13794,7 @@ ScalarEvolution::BlockDisposition ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return ProperlyDominatesBlock; case scAddRecExpr: { // This uses a "dominates" query instead of "properly dominates" query |