diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 69 |
1 files changed, 55 insertions, 14 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 09cb225..975498f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3757,6 +3757,10 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder, // (x < y) ? -1 : zext(x > y) // (x > y) ? 1 : sext(x != y) // (x > y) ? 1 : sext(x < y) +// (x == y) ? 0 : (x > y ? 1 : -1) +// (x == y) ? 0 : (x < y ? -1 : 1) +// Special case: x == C ? 0 : (x > C - 1 ? 1 : -1) +// Special case: x == C ? 0 : (x < C + 1 ? -1 : 1) // Into ucmp/scmp(x, y), where signedness is determined by the signedness // of the comparison in the original sequence. Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { @@ -3849,6 +3853,44 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { } } + // Special cases with constants: x == C ? 0 : (x > C-1 ? 1 : -1) + if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero())) { + const APInt *C; + if (match(RHS, m_APInt(C))) { + CmpPredicate InnerPred; + Value *InnerRHS; + const APInt *InnerTV, *InnerFV; + if (match(FV, + m_Select(m_ICmp(InnerPred, m_Specific(LHS), m_Value(InnerRHS)), + m_APInt(InnerTV), m_APInt(InnerFV)))) { + + // x == C ? 0 : (x > C-1 ? 1 : -1) + if (ICmpInst::isGT(InnerPred) && InnerTV->isOne() && + InnerFV->isAllOnes()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanSubOne = IsSigned ? !C->isMinSignedValue() : !C->isMinValue(); + if (CanSubOne) { + APInt Cminus1 = *C - 1; + if (match(InnerRHS, m_SpecificInt(Cminus1))) + Replace = true; + } + } + + // x == C ? 0 : (x < C+1 ? -1 : 1) + if (ICmpInst::isLT(InnerPred) && InnerTV->isAllOnes() && + InnerFV->isOne()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanAddOne = IsSigned ? !C->isMaxSignedValue() : !C->isMaxValue(); + if (CanAddOne) { + APInt Cplus1 = *C + 1; + if (match(InnerRHS, m_SpecificInt(Cplus1))) + Replace = true; + } + } + } + } + } + Intrinsic::ID IID = IsSigned ? Intrinsic::scmp : Intrinsic::ucmp; if (Replace) return replaceInstUsesWith( @@ -4459,24 +4501,24 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Value *V = foldSelectIntoAddConstant(SI, Builder)) return replaceInstUsesWith(SI, V); - // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0) + // select(mask, mload(ptr,mask,0), 0) -> mload(ptr,mask,0) // Load inst is intentionally not checked for hasOneUse() if (match(FalseVal, m_Zero()) && - (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal), + (match(TrueVal, m_MaskedLoad(m_Value(), m_Specific(CondVal), m_CombineOr(m_Undef(), m_Zero()))) || - match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal), + match(TrueVal, m_MaskedGather(m_Value(), m_Specific(CondVal), m_CombineOr(m_Undef(), m_Zero()))))) { auto *MaskedInst = cast<IntrinsicInst>(TrueVal); - if (isa<UndefValue>(MaskedInst->getArgOperand(3))) - MaskedInst->setArgOperand(3, FalseVal /* Zero */); + if (isa<UndefValue>(MaskedInst->getArgOperand(2))) + MaskedInst->setArgOperand(2, FalseVal /* Zero */); return replaceInstUsesWith(SI, MaskedInst); } Value *Mask; if (match(TrueVal, m_Zero()) && - (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), + (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(Mask), m_CombineOr(m_Undef(), m_Zero()))) || - match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask), + match(FalseVal, m_MaskedGather(m_Value(), m_Value(Mask), m_CombineOr(m_Undef(), m_Zero())))) && (CondVal->getType() == Mask->getType())) { // We can remove the select by ensuring the load zeros all lanes the @@ -4489,8 +4531,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (CanMergeSelectIntoLoad) { auto *MaskedInst = cast<IntrinsicInst>(FalseVal); - if (isa<UndefValue>(MaskedInst->getArgOperand(3))) - MaskedInst->setArgOperand(3, TrueVal /* Zero */); + if (isa<UndefValue>(MaskedInst->getArgOperand(2))) + MaskedInst->setArgOperand(2, TrueVal /* Zero */); return replaceInstUsesWith(SI, MaskedInst); } } @@ -4629,14 +4671,13 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } Value *MaskedLoadPtr; - const APInt *MaskedLoadAlignment; if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr), - m_APInt(MaskedLoadAlignment), m_Specific(CondVal), m_Value())))) return replaceInstUsesWith( - SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr, - Align(MaskedLoadAlignment->getZExtValue()), - CondVal, FalseVal)); + SI, Builder.CreateMaskedLoad( + TrueVal->getType(), MaskedLoadPtr, + cast<IntrinsicInst>(TrueVal)->getParamAlign(0).valueOrOne(), + CondVal, FalseVal)); return nullptr; } |