diff options
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Transforms/Utils/SCCPSolver.cpp | 70 |
1 files changed, 48 insertions, 22 deletions
diff --git a/llvm/lib/Transforms/Utils/SCCPSolver.cpp b/llvm/lib/Transforms/Utils/SCCPSolver.cpp index 4535f86..1a2e422 100644 --- a/llvm/lib/Transforms/Utils/SCCPSolver.cpp +++ b/llvm/lib/Transforms/Utils/SCCPSolver.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/ValueLatticeUtils.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -30,6 +31,7 @@ #include <vector> using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "sccp" @@ -83,20 +85,28 @@ bool SCCPSolver::tryToReplaceWithConstant(Value *V) { return true; } +/// Helper for getting ranges from \p Solver. Instructions inserted during +/// simplification are unavailable in the solver, so we return a full range for +/// them. +static ConstantRange getRange(Value *Op, SCCPSolver &Solver, + const SmallPtrSetImpl<Value *> &InsertedValues) { + if (auto *Const = dyn_cast<Constant>(Op)) + return Const->toConstantRange(); + if (InsertedValues.contains(Op)) { + unsigned Bitwidth = Op->getType()->getScalarSizeInBits(); + return ConstantRange::getFull(Bitwidth); + } + return Solver.getLatticeValueFor(Op).asConstantRange(Op->getType(), + /*UndefAllowed=*/false); +} + /// Try to use \p Inst's value range from \p Solver to infer the NUW flag. static bool refineInstruction(SCCPSolver &Solver, const SmallPtrSetImpl<Value *> &InsertedValues, Instruction &Inst) { bool Changed = false; auto GetRange = [&Solver, &InsertedValues](Value *Op) { - if (auto *Const = dyn_cast<Constant>(Op)) - return Const->toConstantRange(); - if (InsertedValues.contains(Op)) { - unsigned Bitwidth = Op->getType()->getScalarSizeInBits(); - return ConstantRange::getFull(Bitwidth); - } - return Solver.getLatticeValueFor(Op).asConstantRange( - Op->getType(), /*UndefAllowed=*/false); + return getRange(Op, Solver, InsertedValues); }; if (isa<OverflowingBinaryOperator>(Inst)) { @@ -167,16 +177,8 @@ static bool replaceSignedInst(SCCPSolver &Solver, SmallPtrSetImpl<Value *> &InsertedValues, Instruction &Inst) { // Determine if a signed value is known to be >= 0. - auto isNonNegative = [&Solver](Value *V) { - // If this value was constant-folded, it may not have a solver entry. - // Handle integers. Otherwise, return false. - if (auto *C = dyn_cast<Constant>(V)) { - auto *CInt = dyn_cast<ConstantInt>(C); - return CInt && !CInt->isNegative(); - } - const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); - return IV.isConstantRange(/*UndefAllowed=*/false) && - IV.getConstantRange().isAllNonNegative(); + auto isNonNegative = [&Solver, &InsertedValues](Value *V) { + return getRange(V, Solver, InsertedValues).isAllNonNegative(); }; Instruction *NewInst = nullptr; @@ -185,7 +187,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, case Instruction::SExt: { // If the source value is not negative, this is a zext/uitofp. Value *Op0 = Inst.getOperand(0); - if (InsertedValues.count(Op0) || !isNonNegative(Op0)) + if (!isNonNegative(Op0)) return false; NewInst = CastInst::Create(Inst.getOpcode() == Instruction::SExt ? Instruction::ZExt @@ -197,7 +199,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, case Instruction::AShr: { // If the shifted value is not negative, this is a logical shift right. Value *Op0 = Inst.getOperand(0); - if (InsertedValues.count(Op0) || !isNonNegative(Op0)) + if (!isNonNegative(Op0)) return false; NewInst = BinaryOperator::CreateLShr(Op0, Inst.getOperand(1), "", Inst.getIterator()); NewInst->setIsExact(Inst.isExact()); @@ -207,8 +209,7 @@ static bool replaceSignedInst(SCCPSolver &Solver, case Instruction::SRem: { // If both operands are not negative, this is the same as udiv/urem. Value *Op0 = Inst.getOperand(0), *Op1 = Inst.getOperand(1); - if (InsertedValues.count(Op0) || InsertedValues.count(Op1) || - !isNonNegative(Op0) || !isNonNegative(Op1)) + if (!isNonNegative(Op0) || !isNonNegative(Op1)) return false; auto NewOpcode = Inst.getOpcode() == Instruction::SDiv ? Instruction::UDiv : Instruction::URem; @@ -232,6 +233,26 @@ static bool replaceSignedInst(SCCPSolver &Solver, return true; } +/// Try to use \p Inst's value range from \p Solver to simplify it. +static Value *simplifyInstruction(SCCPSolver &Solver, + SmallPtrSetImpl<Value *> &InsertedValues, + Instruction &Inst) { + auto GetRange = [&Solver, &InsertedValues](Value *Op) { + return getRange(Op, Solver, InsertedValues); + }; + + Value *X; + const APInt *RHSC; + // Remove masking operations. + if (match(&Inst, m_And(m_Value(X), m_LowBitMask(RHSC)))) { + ConstantRange LRange = GetRange(Inst.getOperand(0)); + if (LRange.getUnsignedMax().ule(*RHSC)) + return X; + } + + return nullptr; +} + bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB, SmallPtrSetImpl<Value *> &InsertedValues, Statistic &InstRemovedStat, @@ -251,6 +272,11 @@ bool SCCPSolver::simplifyInstsInBlock(BasicBlock &BB, ++InstReplacedStat; } else if (refineInstruction(*this, InsertedValues, Inst)) { MadeChanges = true; + } else if (auto *V = simplifyInstruction(*this, InsertedValues, Inst)) { + Inst.replaceAllUsesWith(V); + Inst.eraseFromParent(); + ++InstRemovedStat; + MadeChanges = true; } } return MadeChanges; |