aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Transforms/Utils/SCCPSolver.cpp70
1 files changed, 48 insertions, 22 deletions
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
index 4535f86..1a2e422 100644
--- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp
+++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp
@@ -20,6 +20,7 @@
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -30,6 +31,7 @@
#include <vector>
using namespace llvm;
+using namespace PatternMatch;
#define DEBUG_TYPE "sccp"
@@ -83,20 +85,28 @@ bool SCCPSolver::tryToReplaceWithConstant(Value *V) {
return true;
}
+/// Helper for getting ranges from \p Solver. Instructions inserted during
+/// simplification are unavailable in the solver, so we return a full range for
+/// them.
+static ConstantRange getRange(Value *Op, SCCPSolver &Solver,
+ const SmallPtrSetImpl<Value *> &InsertedValues) {
+ if (auto *Const = dyn_cast<Constant>(Op))
+ return Const->toConstantRange();
+ if (InsertedValues.contains(Op)) {
+ unsigned Bitwidth = Op->getType()->getScalarSizeInBits();
+ return ConstantRange::getFull(Bitwidth);
+ }
+ return Solver.getLatticeValueFor(Op).asConstantRange(Op->getType(),
+ /*UndefAllowed=*/false);
+}
+
/// Try to use \p Inst's value range from \p Solver to infer the NUW flag.
static bool refineInstruction(SCCPSolver &Solver,
const SmallPtrSetImpl<Value *> &InsertedValues,
Instruction &Inst) {
bool Changed = false;
auto GetRange = [&Solver, &InsertedValues](Value *Op) {
- if (auto *Const = dyn_cast<Constant>(Op))
- return Const->toConstantRange();
- if (InsertedValues.contains(Op)) {
- unsigned Bitwidth = Op->getType()->getScalarSizeInBits();
- return ConstantRange::getFull(Bitwidth);
- }
- return Solver.getLatticeValueFor(Op).asConstantRange(
- Op->getType(), /*UndefAllowed=*/false);
+ return getRange(Op, Solver, InsertedValues);
};
if (isa<OverflowingBinaryOperator>(Inst)) {
@@ -167,16 +177,8 @@ static bool replaceSignedInst(SCCPSolver &Solver,
SmallPtrSetImpl<Value *> &InsertedValues,
Instruction &Inst) {
// Determine if a signed value is known to be >= 0.
- auto isNonNegative = [&Solver](Value *V) {
- // If this value was constant-folded, it may not have a solver entry.
- // Handle integers. Otherwise, return false.
- if (auto *C = dyn_cast<Constant>(V)) {
- auto *CInt = dyn_cast<ConstantInt>(C);
- return CInt && !CInt->isNegative();
- }
- const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
- return IV.isConstantRange(/*UndefAllowed=*/false) &&
- IV.getConstantRange().isAllNonNegative();
+ auto isNonNegative = [&Solver, &InsertedValues](Value *V) {
+ return getRange(V, Solver, InsertedValues).isAllNonNegative();
};
Instruction *NewInst = nullptr;
@@ -185,7 +187,7 @@ static bool replaceSignedInst(SCCPSolver &Solver,
case Instruction::SExt: {
// If the source value is not negative, this is a zext/uitofp.
Value *Op0 = Inst.getOperand(0);
- if (InsertedValues.count(Op0) || !isNonNegative(Op0))
+ if (!isNonNegative(Op0))
return false;
NewInst = CastInst::Create(Inst.getOpcode() == Instruction::SExt
? Instruction::ZExt
@@ -197,7 +199,7 @@ static bool replaceSignedInst(SCCPSolver &Solver,
case Instruction::AShr: {
// If the shifted value is not negative, this is a logical shift right.
Value *Op0 = Inst.getOperand(0);
- if (InsertedValues.count(Op0) || !isNonNegative(Op0))
+ if (!isNonNegative(Op0))
return false;
NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", Inst.getIterator());
NewInst->setIsExact(Inst.isExact());
@@ -207,8 +209,7 @@ static bool replaceSignedInst(SCCPSolver &Solver,
case Instruction::SRem: {
// If both operands are not negative, this is the same as udiv/urem.
Value *Op0 = Inst.getOperand(0), *Op1 = Inst.getOperand(1);
- if (InsertedValues.count(Op0) || InsertedValues.count(Op1) ||
- !isNonNegative(Op0) || !isNonNegative(Op1))
+ if (!isNonNegative(Op0) || !isNonNegative(Op1))
return false;
auto NewOpcode = Inst.getOpcode() == Instruction::SDiv ? Instruction::UDiv
: Instruction::URem;
@@ -232,6 +233,26 @@ static bool replaceSignedInst(SCCPSolver &Solver,
return true;
}
+/// Try to use \p Inst's value range from \p Solver to simplify it.
+static Value *simplifyInstruction(SCCPSolver &Solver,
+ SmallPtrSetImpl<Value *> &InsertedValues,
+ Instruction &Inst) {
+ auto GetRange = [&Solver, &InsertedValues](Value *Op) {
+ return getRange(Op, Solver, InsertedValues);
+ };
+
+ Value *X;
+ const APInt *RHSC;
+ // Remove masking operations.
+ if (match(&Inst, m_And(m_Value(X), m_LowBitMask(RHSC)))) {
+ ConstantRange LRange = GetRange(Inst.getOperand(0));
+ if (LRange.getUnsignedMax().ule(*RHSC))
+ return X;
+ }
+
+ return nullptr;
+}
+
bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB,
SmallPtrSetImpl<Value *> &InsertedValues,
Statistic &InstRemovedStat,
@@ -251,6 +272,11 @@ bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB,
++InstReplacedStat;
} else if (refineInstruction(*this, InsertedValues, Inst)) {
MadeChanges = true;
+ } else if (auto *V = simplifyInstruction(*this, InsertedValues, Inst)) {
+ Inst.replaceAllUsesWith(V);
+ Inst.eraseFromParent();
+ ++InstRemovedStat;
+ MadeChanges = true;
}
}
return MadeChanges;