aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r--llvm/lib/Analysis/IR2Vec.cpp19
-rw-r--r--llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp6
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp56
3 files changed, 53 insertions, 28 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 6885351..1794a60 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -239,10 +239,21 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// If the operand is defined elsewhere, we use its embedding
if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
auto DefIt = InstVecMap.find(DefInst);
- assert(DefIt != InstVecMap.end() &&
- "Instruction should have been processed before its operands");
- ArgEmb += DefIt->second;
- continue;
+ // Fixme (#159171): Ideally we should never miss an instruction
+ // embedding here.
+ // But when we have cyclic dependencies (e.g., phi
+ // nodes), we might miss the embedding. In such cases, we fall back to
+ // using the vocabulary embedding. This can be fixed by iterating to a
+ // fixed-point, or by using a simple solver for the set of simultaneous
+ // equations.
+ // Another case when we might miss an instruction embedding is when
+ // the operand instruction is in a different basic block that has not
+ // been processed yet. This can be fixed by processing the basic blocks
+ // in a topological order.
+ if (DefIt != InstVecMap.end())
+ ArgEmb += DefIt->second;
+ else
+ ArgEmb += Vocab[*Op];
}
// If the operand is not defined by an instruction, we use the vocabulary
else {
diff --git a/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
index 0fbf082..f31d625 100644
--- a/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
+++ b/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp
@@ -43,11 +43,13 @@ static void printModuleDebugInfo(raw_ostream &O, const Module *M,
// filenames), so just print a few useful things.
for (DICompileUnit *CU : Finder.compile_units()) {
O << "Compile unit: ";
- auto Lang = dwarf::LanguageString(CU->getSourceLanguage());
+ auto Lang =
+ dwarf::LanguageString(CU->getSourceLanguage().getUnversionedName());
if (!Lang.empty())
O << Lang;
else
- O << "unknown-language(" << CU->getSourceLanguage() << ")";
+ O << "unknown-language(" << CU->getSourceLanguage().getUnversionedName()
+ << ")";
printFile(O, CU->getFilename(), CU->getDirectory());
O << '\n';
}
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 63e1b14..6f6776c 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -6351,19 +6351,20 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
return getGEPExpr(GEP, IndexExprs);
}
-APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
+APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
+ const Instruction *CtxI) {
uint64_t BitWidth = getTypeSizeInBits(S->getType());
auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
return TrailingZeros >= BitWidth
? APInt::getZero(BitWidth)
: APInt::getOneBitSet(BitWidth, TrailingZeros);
};
- auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
+ auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
// The result is GCD of all operands results.
- APInt Res = getConstantMultiple(N->getOperand(0));
+ APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
Res = APIntOps::GreatestCommonDivisor(
- Res, getConstantMultiple(N->getOperand(I)));
+ Res, getConstantMultiple(N->getOperand(I), CtxI));
return Res;
};
@@ -6371,33 +6372,33 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
case scConstant:
return cast<SCEVConstant>(S)->getAPInt();
case scPtrToInt:
- return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
+ return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
case scUDivExpr:
case scVScale:
return APInt(BitWidth, 1);
case scTruncate: {
// Only multiples that are a power of 2 will hold after truncation.
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
- uint32_t TZ = getMinTrailingZeros(T->getOperand());
+ uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
return GetShiftedByZeros(TZ);
}
case scZeroExtend: {
const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
- return getConstantMultiple(Z->getOperand()).zext(BitWidth);
+ return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
}
case scSignExtend: {
// Only multiples that are a power of 2 will hold after sext.
const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
- uint32_t TZ = getMinTrailingZeros(E->getOperand());
+ uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
return GetShiftedByZeros(TZ);
}
case scMulExpr: {
const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
if (M->hasNoUnsignedWrap()) {
// The result is the product of all operand results.
- APInt Res = getConstantMultiple(M->getOperand(0));
+ APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
for (const SCEV *Operand : M->operands().drop_front())
- Res = Res * getConstantMultiple(Operand);
+ Res = Res * getConstantMultiple(Operand, CtxI);
return Res;
}
@@ -6405,7 +6406,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
// sum of trailing zeros for all its operands.
uint32_t TZ = 0;
for (const SCEV *Operand : M->operands())
- TZ += getMinTrailingZeros(Operand);
+ TZ += getMinTrailingZeros(Operand, CtxI);
return GetShiftedByZeros(TZ);
}
case scAddExpr:
@@ -6414,9 +6415,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
if (N->hasNoUnsignedWrap())
return GetGCDMultiple(N);
// Find the trailing bits, which is the minimum of its operands.
- uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
+ uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
for (const SCEV *Operand : N->operands().drop_front())
- TZ = std::min(TZ, getMinTrailingZeros(Operand));
+ TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
return GetShiftedByZeros(TZ);
}
case scUMaxExpr:
@@ -6429,7 +6430,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
// ask ValueTracking for known bits
const SCEVUnknown *U = cast<SCEVUnknown>(S);
unsigned Known =
- computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
+ computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
.countMinTrailingZeros();
return GetShiftedByZeros(Known);
}
@@ -6439,12 +6440,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
llvm_unreachable("Unknown SCEV kind!");
}
-APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
+APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
+ const Instruction *CtxI) {
+ // Skip looking up and updating the cache if there is a context instruction,
+ // as the result will only be valid in the specified context.
+ if (CtxI)
+ return getConstantMultipleImpl(S, CtxI);
+
auto I = ConstantMultipleCache.find(S);
if (I != ConstantMultipleCache.end())
return I->second;
- APInt Result = getConstantMultipleImpl(S);
+ APInt Result = getConstantMultipleImpl(S, CtxI);
auto InsertPair = ConstantMultipleCache.insert({S, Result});
assert(InsertPair.second && "Should insert a new key");
return InsertPair.first->second;
@@ -6455,8 +6462,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
}
-uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
- return std::min(getConstantMultiple(S).countTrailingZeros(),
+uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
+ const Instruction *CtxI) {
+ return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
(unsigned)getTypeSizeInBits(S->getType()));
}
@@ -10243,8 +10251,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
static const SCEV *
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
SmallVectorImpl<const SCEVPredicate *> *Predicates,
-
- ScalarEvolution &SE) {
+ ScalarEvolution &SE, const Loop *L) {
uint32_t BW = A.getBitWidth();
assert(BW == SE.getTypeSizeInBits(B->getType()));
assert(A != 0 && "A must be non-zero.");
@@ -10260,7 +10267,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
//
// B is divisible by D if and only if the multiplicity of prime factor 2 for B
// is not less than multiplicity of this prime factor for D.
- if (SE.getMinTrailingZeros(B) < Mult2) {
+ unsigned MinTZ = SE.getMinTrailingZeros(B);
+ // Try again with the terminator of the loop predecessor for context-specific
+ // result, if MinTZ s too small.
+ if (MinTZ < Mult2 && L->getLoopPredecessor())
+ MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
+ if (MinTZ < Mult2) {
// Check if we can prove there's no remainder using URem.
const SCEV *URem =
SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
@@ -10708,7 +10720,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
return getCouldNotCompute();
const SCEV *E = SolveLinEquationWithOverflow(
StepC->getAPInt(), getNegativeSCEV(Start),
- AllowPredicates ? &Predicates : nullptr, *this);
+ AllowPredicates ? &Predicates : nullptr, *this, L);
const SCEV *M = E;
if (E != getCouldNotCompute()) {