Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
3 files changed, 67 insertions, 29 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e1e24a9..dab200d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -289,12 +289,11 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
 // * Narrow width by halfs excluding zero/undef lanes
 Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
   Value *LoadPtr = II.getArgOperand(0);
-  const Align Alignment =
-      cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+  const Align Alignment = II.getParamAlign(0).valueOrOne();
 
   // If the mask is all ones or undefs, this is a plain vector load of the 1st
   // argument.
-  if (maskIsAllOneOrUndef(II.getArgOperand(2))) {
+  if (maskIsAllOneOrUndef(II.getArgOperand(1))) {
     LoadInst *L = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
                                             "unmaskedload");
     L->copyMetadata(II);
@@ -308,7 +307,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
     LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
                                              "unmaskedload");
     LI->copyMetadata(II);
-    return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
+    return Builder.CreateSelect(II.getArgOperand(1), LI, II.getArgOperand(2));
   }
 
   return nullptr;
@@ -319,8 +318,8 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
 // * Narrow width by halfs excluding zero/undef lanes
 Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
   Value *StorePtr = II.getArgOperand(1);
-  Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
-  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
+  Align Alignment = II.getParamAlign(1).valueOrOne();
+  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
   if (!ConstMask)
     return nullptr;
 
@@ -356,7 +355,7 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
 // * Narrow width by halfs excluding zero/undef lanes
 // * Vector incrementing address -> vector masked load
 Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
-  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
+  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(1));
   if (!ConstMask)
     return nullptr;
 
@@ -366,8 +365,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
   if (ConstMask->isAllOnesValue())
     if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) {
       auto *VecTy = cast<VectorType>(II.getType());
-      const Align Alignment =
-          cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+      const Align Alignment = II.getParamAlign(0).valueOrOne();
       LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr,
                                               Alignment, "load.scalar");
       Value *Shuf =
@@ -384,7 +382,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
 // * Narrow store width by halfs excluding zero/undef lanes
 // * Vector incrementing address -> vector masked store
 Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
-  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
+  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
   if (!ConstMask)
     return nullptr;
 
@@ -397,8 +395,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
   if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
     if (maskContainsAllOneOrUndef(ConstMask)) {
-      Align Alignment =
-          cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+      Align Alignment = II.getParamAlign(1).valueOrOne();
       StoreInst *S =
           new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
       S->copyMetadata(II);
@@ -408,7 +405,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
     // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
     // lastlane), ptr
     if (ConstMask->isAllOnesValue()) {
-      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+      Align Alignment = II.getParamAlign(1).valueOrOne();
       VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
       ElementCount VF = WideLoadTy->getElementCount();
       Value *RunTimeVF = Builder.CreateElementCount(Builder.getInt32Ty(), VF);
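
Every InstCombineCalls.cpp hunk above is the same mechanical change: the masked load/store/gather/scatter intrinsics no longer carry an explicit alignment operand, so alignment is read from the pointer parameter attribute via getParamAlign() (defaulting to 1 when absent), and every later operand index shifts down by one. A minimal standalone sketch of the operand layout this assumes for llvm.masked.load; the MaskedLoadOps struct and getMaskedLoadOps helper are illustrative only, not part of the patch or of LLVM:

// Sketch: reading llvm.masked.load operands after the alignment operand
// was dropped, i.e. assuming the signature
//   <N x T> @llvm.masked.load(ptr align A %p, <N x i1> %mask, <N x T> %pt)
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/Alignment.h"
#include <cassert>

using namespace llvm;

struct MaskedLoadOps { // illustrative helper, not an LLVM type
  Value *Ptr;
  Align Alignment;
  Value *Mask;
  Value *PassThru;
};

static MaskedLoadOps getMaskedLoadOps(IntrinsicInst &II) {
  assert(II.getIntrinsicID() == Intrinsic::masked_load);
  return {II.getArgOperand(0),
          II.getParamAlign(0).valueOrOne(), // attribute; defaults to align 1
          II.getArgOperand(1),              // mask: was index 2, now 1
          II.getArgOperand(2)};             // passthru: was index 3, now 2
}
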
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 09cb225..975498f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3757,6 +3757,10 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder,
 // (x < y) ? -1 : zext(x > y)
 // (x > y) ? 1 : sext(x != y)
 // (x > y) ? 1 : sext(x < y)
+// (x == y) ? 0 : (x > y ? 1 : -1)
+// (x == y) ? 0 : (x < y ? -1 : 1)
+// Special case: x == C ? 0 : (x > C - 1 ? 1 : -1)
+// Special case: x == C ? 0 : (x < C + 1 ? -1 : 1)
 // Into ucmp/scmp(x, y), where signedness is determined by the signedness
 // of the comparison in the original sequence.
 Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
@@ -3849,6 +3853,44 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
     }
   }
 
+  // Special cases with constants: x == C ? 0 : (x > C-1 ? 1 : -1)
+  if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero())) {
+    const APInt *C;
+    if (match(RHS, m_APInt(C))) {
+      CmpPredicate InnerPred;
+      Value *InnerRHS;
+      const APInt *InnerTV, *InnerFV;
+      if (match(FV,
+                m_Select(m_ICmp(InnerPred, m_Specific(LHS), m_Value(InnerRHS)),
+                         m_APInt(InnerTV), m_APInt(InnerFV)))) {
+
+        // x == C ? 0 : (x > C-1 ? 1 : -1)
+        if (ICmpInst::isGT(InnerPred) && InnerTV->isOne() &&
+            InnerFV->isAllOnes()) {
+          IsSigned = ICmpInst::isSigned(InnerPred);
+          bool CanSubOne = IsSigned ? !C->isMinSignedValue() : !C->isMinValue();
+          if (CanSubOne) {
+            APInt Cminus1 = *C - 1;
+            if (match(InnerRHS, m_SpecificInt(Cminus1)))
+              Replace = true;
+          }
+        }
+
+        // x == C ? 0 : (x < C+1 ? -1 : 1)
+        if (ICmpInst::isLT(InnerPred) && InnerTV->isAllOnes() &&
+            InnerFV->isOne()) {
+          IsSigned = ICmpInst::isSigned(InnerPred);
+          bool CanAddOne = IsSigned ? !C->isMaxSignedValue() : !C->isMaxValue();
+          if (CanAddOne) {
+            APInt Cplus1 = *C + 1;
+            if (match(InnerRHS, m_SpecificInt(Cplus1)))
+              Replace = true;
+          }
+        }
+      }
+    }
+  }
+
   Intrinsic::ID IID = IsSigned ? Intrinsic::scmp : Intrinsic::ucmp;
   if (Replace)
     return replaceInstUsesWith(
@@ -4459,24 +4501,24 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Value *V = foldSelectIntoAddConstant(SI, Builder))
     return replaceInstUsesWith(SI, V);
 
-  // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
+  // select(mask, mload(ptr,mask,0), 0) -> mload(ptr,mask,0)
   // Load inst is intentionally not checked for hasOneUse()
   if (match(FalseVal, m_Zero()) &&
-      (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal),
+      (match(TrueVal, m_MaskedLoad(m_Value(), m_Specific(CondVal),
                                    m_CombineOr(m_Undef(), m_Zero()))) ||
-       match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal),
+       match(TrueVal, m_MaskedGather(m_Value(), m_Specific(CondVal),
                                      m_CombineOr(m_Undef(), m_Zero()))))) {
     auto *MaskedInst = cast<IntrinsicInst>(TrueVal);
-    if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
-      MaskedInst->setArgOperand(3, FalseVal /* Zero */);
+    if (isa<UndefValue>(MaskedInst->getArgOperand(2)))
+      MaskedInst->setArgOperand(2, FalseVal /* Zero */);
     return replaceInstUsesWith(SI, MaskedInst);
   }
 
   Value *Mask;
   if (match(TrueVal, m_Zero()) &&
-      (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask),
+      (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(Mask),
                                     m_CombineOr(m_Undef(), m_Zero()))) ||
-       match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask),
+       match(FalseVal, m_MaskedGather(m_Value(), m_Value(Mask),
                                       m_CombineOr(m_Undef(), m_Zero())))) &&
       (CondVal->getType() == Mask->getType())) {
     // We can remove the select by ensuring the load zeros all lanes the
@@ -4489,8 +4531,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
 
     if (CanMergeSelectIntoLoad) {
       auto *MaskedInst = cast<IntrinsicInst>(FalseVal);
-      if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
-        MaskedInst->setArgOperand(3, TrueVal /* Zero */);
+      if (isa<UndefValue>(MaskedInst->getArgOperand(2)))
+        MaskedInst->setArgOperand(2, TrueVal /* Zero */);
       return replaceInstUsesWith(SI, MaskedInst);
     }
   }
@@ -4629,14 +4671,13 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   }
 
   Value *MaskedLoadPtr;
-  const APInt *MaskedLoadAlignment;
   if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr),
-                                           m_APInt(MaskedLoadAlignment),
                                            m_Specific(CondVal), m_Value()))))
     return replaceInstUsesWith(
-        SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr,
-                                     Align(MaskedLoadAlignment->getZExtValue()),
-                                     CondVal, FalseVal));
+        SI, Builder.CreateMaskedLoad(
+                TrueVal->getType(), MaskedLoadPtr,
+                cast<IntrinsicInst>(TrueVal)->getParamAlign(0).valueOrOne(),
+                CondVal, FalseVal));
 
   return nullptr;
 }
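
The new foldSelectToCmp special case rests on a simple identity: when C - 1 does not wrap, x == C ? 0 : (x > C - 1 ? 1 : -1) computes exactly the three-way comparison of x and C, which is what llvm.scmp returns for signed predicates (llvm.ucmp for unsigned). A quick standalone check of the signed variant; scmpRef is an illustrative reference function, not an LLVM API:

// Verifies the identity behind the fold for a sample constant:
//   x == C ? 0 : (x > C - 1 ? 1 : -1)  ==  scmp(x, C)   for C != INT_MIN
// The C != INT_MIN precondition mirrors the !C->isMinSignedValue() guard
// in the hunk above, which keeps C - 1 from wrapping.
#include <cassert>

static int scmpRef(int x, int y) { return x < y ? -1 : (x > y ? 1 : 0); }

int main() {
  const int C = 42; // any signed constant except INT_MIN works
  for (int x = C - 3; x <= C + 3; ++x) {
    int folded = x == C ? 0 : (x > C - 1 ? 1 : -1);
    assert(folded == scmpRef(x, C));
  }
  return 0;
}
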
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index a330bb7..651e305 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1892,7 +1892,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
     // segfaults which didn't exist in the original program.
     APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
         DemandedPassThrough(DemandedElts);
-    if (auto *CMask = dyn_cast<Constant>(II->getOperand(2))) {
+    if (auto *CMask = dyn_cast<Constant>(II->getOperand(1))) {
       for (unsigned i = 0; i < VWidth; i++) {
         if (Constant *CElt = CMask->getAggregateElement(i)) {
           if (CElt->isNullValue())
@@ -1905,7 +1905,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
 
     if (II->getIntrinsicID() == Intrinsic::masked_gather)
       simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2);
-    simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3);
+    simplifyAndSetOp(II, 2, DemandedPassThrough, PoisonElts3);
 
     // Output elements are undefined if the element from both sources are.
     // TODO: can strengthen via mask as well.
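
The InstCombineSimplifyDemanded.cpp hunks only renumber operands; the demanded-elements logic is unchanged. With a constant mask, a gather demands a pointer lane only where the mask may be true, and demands a passthru lane only where it may be false. A standalone sketch of that partitioning using plain arrays (assumed semantics mirroring the hunk, not LLVM API calls):

// Illustration of how a constant mask splits demanded lanes between the
// pointer vector and the passthru vector of llvm.masked.gather.
#include <bitset>
#include <cstdio>

int main() {
  constexpr unsigned VWidth = 4;
  const bool Mask[VWidth] = {true, false, true, false}; // constant mask
  std::bitset<VWidth> DemandedPtrs, DemandedPassThru;
  DemandedPtrs.set();     // like APInt::getAllOnes(VWidth)
  DemandedPassThru.set(); // assume all result lanes are demanded
  for (unsigned i = 0; i < VWidth; ++i) {
    if (!Mask[i])
      DemandedPtrs.reset(i);     // disabled lane never dereferences its ptr
    else
      DemandedPassThru.reset(i); // enabled lane never reads the passthru
  }
  // Prints lanes MSB-first: ptrs=0101 passthru=1010
  std::printf("ptrs=%s passthru=%s\n", DemandedPtrs.to_string().c_str(),
              DemandedPassThru.to_string().c_str());
}
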