diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 92 |
1 files changed, 4 insertions, 88 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 6e57faa..b13d3ed 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -14,6 +14,7 @@ #include "InstCombineInternal.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" using namespace llvm; using namespace PatternMatch; @@ -60,91 +61,6 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder, return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B); } -/// MatchSelectPattern - Pattern match integer [SU]MIN, [SU]MAX, and ABS idioms, -/// returning the kind and providing the out parameter results if we -/// successfully match. -static SelectPatternFlavor -MatchSelectPattern(Value *V, Value *&LHS, Value *&RHS) { - SelectInst *SI = dyn_cast<SelectInst>(V); - if (!SI) return SPF_UNKNOWN; - - ICmpInst *ICI = dyn_cast<ICmpInst>(SI->getCondition()); - if (!ICI) return SPF_UNKNOWN; - - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - - LHS = CmpLHS; - RHS = CmpRHS; - - // (icmp X, Y) ? X : Y - if (TrueVal == CmpLHS && FalseVal == CmpRHS) { - switch (Pred) { - default: return SPF_UNKNOWN; // Equality. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: return SPF_UMAX; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: return SPF_SMAX; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: return SPF_UMIN; - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: return SPF_SMIN; - } - } - - // (icmp X, Y) ? Y : X - if (TrueVal == CmpRHS && FalseVal == CmpLHS) { - switch (Pred) { - default: return SPF_UNKNOWN; // Equality. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: return SPF_UMIN; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: return SPF_SMIN; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: return SPF_UMAX; - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: return SPF_SMAX; - } - } - - if (ConstantInt *C1 = dyn_cast<ConstantInt>(CmpRHS)) { - if ((CmpLHS == TrueVal && match(FalseVal, m_Neg(m_Specific(CmpLHS)))) || - (CmpLHS == FalseVal && match(TrueVal, m_Neg(m_Specific(CmpLHS))))) { - - // ABS(X) ==> (X >s 0) ? X : -X and (X >s -1) ? X : -X - // NABS(X) ==> (X >s 0) ? -X : X and (X >s -1) ? -X : X - if (Pred == ICmpInst::ICMP_SGT && (C1->isZero() || C1->isMinusOne())) { - return (CmpLHS == TrueVal) ? SPF_ABS : SPF_NABS; - } - - // ABS(X) ==> (X <s 0) ? -X : X and (X <s 1) ? -X : X - // NABS(X) ==> (X <s 0) ? X : -X and (X <s 1) ? X : -X - if (Pred == ICmpInst::ICMP_SLT && (C1->isZero() || C1->isOne())) { - return (CmpLHS == FalseVal) ? SPF_ABS : SPF_NABS; - } - } - - // Y >s C ? ~Y : ~C == ~Y <s ~C ? ~Y : ~C = SMIN(~Y, ~C) - if (const auto *C2 = dyn_cast<ConstantInt>(FalseVal)) { - if (C1->getType() == C2->getType() && ~C1->getValue() == C2->getValue() && - (match(TrueVal, m_Not(m_Specific(CmpLHS))) || - match(CmpLHS, m_Not(m_Specific(TrueVal))))) { - LHS = TrueVal; - RHS = FalseVal; - return SPF_SMIN; - } - } - } - - // TODO: (X > 4) ? X : 5 --> (X >= 5) ? X : 5 --> MAX(X, 5) - - return SPF_UNKNOWN; -} - - /// GetSelectFoldableOperands - We want to turn code that looks like this: /// %C = or %A, %B /// %D = select %cond, %C, %A @@ -1243,18 +1159,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return FoldI; Value *LHS, *RHS, *LHS2, *RHS2; - SelectPatternFlavor SPF = MatchSelectPattern(&SI, LHS, RHS); + SelectPatternFlavor SPF = matchSelectPattern(&SI, LHS, RHS); // MAX(MAX(a, b), a) -> MAX(a, b) // MIN(MIN(a, b), a) -> MIN(a, b) // MAX(MIN(a, b), a) -> a // MIN(MAX(a, b), a) -> a if (SPF) { - if (SelectPatternFlavor SPF2 = MatchSelectPattern(LHS, LHS2, RHS2)) + if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2)) if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, SI, SPF, RHS)) return R; - if (SelectPatternFlavor SPF2 = MatchSelectPattern(RHS, LHS2, RHS2)) + if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2)) if (Instruction *R = FoldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, SI, SPF, LHS)) return R; |