aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
authorPaul Walker <paul.walker@arm.com>2023-03-02 12:23:52 +0000
committerPaul Walker <paul.walker@arm.com>2023-03-02 13:14:07 +0000
commit62d11b2ccaeec8abc12e07aa009c922d72fd6377 (patch)
tree2b0cbf2ae047496f0bde854381c8dd3dd9080171 /llvm/lib/Analysis/ScalarEvolution.cpp
parentc396073a0de6bc156514c34c0ffbdd227178c11b (diff)
downloadllvm-62d11b2ccaeec8abc12e07aa009c922d72fd6377.zip
llvm-62d11b2ccaeec8abc12e07aa009c922d72fd6377.tar.gz
llvm-62d11b2ccaeec8abc12e07aa009c922d72fd6377.tar.bz2
Revert "Revert "[SCEV] Add SCEVType to represent `vscale`.""
Relanding after fixing Polly related build error. This reverts commit 7b26dcae9eaf8cdcba7fef032fd83d060dffd4b4.
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp66
1 files changed, 43 insertions, 23 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index f997b19..b074294 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -271,6 +271,9 @@ void SCEV::print(raw_ostream &OS) const {
case scConstant:
cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
return;
+ case scVScale:
+ OS << "vscale";
+ return;
case scPtrToInt: {
const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
const SCEV *Op = PtrToInt->getOperand();
@@ -366,17 +369,9 @@ void SCEV::print(raw_ostream &OS) const {
OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
return;
}
- case scUnknown: {
- const SCEVUnknown *U = cast<SCEVUnknown>(this);
- if (U->isVScale()) {
- OS << "vscale";
- return;
- }
-
- // Otherwise just print it normally.
- U->getValue()->printAsOperand(OS, false);
+ case scUnknown:
+ cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
return;
- }
case scCouldNotCompute:
OS << "***COULDNOTCOMPUTE***";
return;
@@ -388,6 +383,8 @@ Type *SCEV::getType() const {
switch (getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(this)->getType();
+ case scVScale:
+ return cast<SCEVVScale>(this)->getType();
case scPtrToInt:
case scTruncate:
case scZeroExtend:
@@ -419,6 +416,7 @@ Type *SCEV::getType() const {
ArrayRef<const SCEV *> SCEV::operands() const {
switch (getSCEVType()) {
case scConstant:
+ case scVScale:
case scUnknown:
return {};
case scPtrToInt:
@@ -501,6 +499,18 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
return getConstant(ConstantInt::get(ITy, V, isSigned));
}
+const SCEV *ScalarEvolution::getVScale(Type *Ty) {
+ FoldingSetNodeID ID;
+ ID.AddInteger(scVScale);
+ ID.AddPointer(Ty);
+ void *IP = nullptr;
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+ return S;
+ SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
+ UniqueSCEVs.InsertNode(S, IP);
+ return S;
+}
+
SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
const SCEV *op, Type *ty)
: SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
@@ -560,10 +570,6 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) {
setValPtr(New);
}
-bool SCEVUnknown::isVScale() const {
- return match(getValue(), m_VScale());
-}
-
//===----------------------------------------------------------------------===//
// SCEV Utilities
//===----------------------------------------------------------------------===//
@@ -714,6 +720,12 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
return LA.ult(RA) ? -1 : 1;
}
+ case scVScale: {
+ const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
+ const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
+ return LTy->getBitWidth() - RTy->getBitWidth();
+ }
+
case scAddRecExpr: {
const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
@@ -4015,6 +4027,8 @@ public:
RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
+ RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
+
RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
@@ -4061,6 +4075,7 @@ public:
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
switch (Kind) {
case scConstant:
+ case scVScale:
case scTruncate:
case scZeroExtend:
case scSignExtend:
@@ -4104,6 +4119,7 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
if (!scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) {
switch (S->getSCEVType()) {
case scConstant:
+ case scVScale:
case scTruncate:
case scZeroExtend:
case scSignExtend:
@@ -4315,15 +4331,8 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
const SCEV *
ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
- if (Size.isScalable()) {
- // TODO: Why is there no ConstantExpr::getVScale()?
- Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1);
- Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo());
- Constant *One = ConstantInt::get(IntTy, 1);
- Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One);
- Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy);
- Res = getMulExpr(Res, getUnknown(VScale));
- }
+ if (Size.isScalable())
+ Res = getMulExpr(Res, getVScale(IntTy));
return Res;
}
@@ -5887,6 +5896,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
bool follow(const SCEV *S) {
switch (S->getSCEVType()) {
case scConstant:
+ case scVScale:
case scPtrToInt:
case scTruncate:
case scZeroExtend:
@@ -6274,6 +6284,8 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
switch (S->getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(S)->getAPInt().countr_zero();
+ case scVScale:
+ return 0;
case scTruncate: {
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
return std::min(GetMinTrailingZeros(T->getOperand()),
@@ -6504,6 +6516,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S,
break;
[[fallthrough]];
case scConstant:
+ case scVScale:
case scTruncate:
case scZeroExtend:
case scSignExtend:
@@ -6607,6 +6620,8 @@ const ConstantRange &ScalarEvolution::getRangeRef(
switch (S->getSCEVType()) {
case scConstant:
llvm_unreachable("Already handled above.");
+ case scVScale:
+ return setRange(S, SignHint, std::move(ConservativeResult));
case scTruncate: {
const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
@@ -9711,6 +9726,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
switch (V->getSCEVType()) {
case scCouldNotCompute:
case scAddRecExpr:
+ case scVScale:
return nullptr;
case scConstant:
return cast<SCEVConstant>(V)->getValue();
@@ -9794,6 +9810,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
switch (V->getSCEVType()) {
case scConstant:
+ case scVScale:
return V;
case scAddRecExpr: {
// If this is a loop recurrence for a loop that does not contain L, then we
@@ -9892,6 +9909,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
case scSequentialUMinExpr:
return getSequentialMinMaxExpr(V->getSCEVType(), NewOps);
case scConstant:
+ case scVScale:
case scAddRecExpr:
case scUnknown:
case scCouldNotCompute:
@@ -13677,6 +13695,7 @@ ScalarEvolution::LoopDisposition
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
switch (S->getSCEVType()) {
case scConstant:
+ case scVScale:
return LoopInvariant;
case scAddRecExpr: {
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
@@ -13775,6 +13794,7 @@ ScalarEvolution::BlockDisposition
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
switch (S->getSCEVType()) {
case scConstant:
+ case scVScale:
return ProperlyDominatesBlock;
case scAddRecExpr: {
// This uses a "dominates" query instead of "properly dominates" query