diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
5 files changed, 105 insertions, 55 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index e1e24a9..dab200d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -289,12 +289,11 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { // * Narrow width by halfs excluding zero/undef lanes Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { Value *LoadPtr = II.getArgOperand(0); - const Align Alignment = - cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); + const Align Alignment = II.getParamAlign(0).valueOrOne(); // If the mask is all ones or undefs, this is a plain vector load of the 1st // argument. - if (maskIsAllOneOrUndef(II.getArgOperand(2))) { + if (maskIsAllOneOrUndef(II.getArgOperand(1))) { LoadInst *L = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); L->copyMetadata(II); @@ -308,7 +307,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); LI->copyMetadata(II); - return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3)); + return Builder.CreateSelect(II.getArgOperand(1), LI, II.getArgOperand(2)); } return nullptr; @@ -319,8 +318,8 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { // * Narrow width by halfs excluding zero/undef lanes Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { Value *StorePtr = II.getArgOperand(1); - Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + Align Alignment = II.getParamAlign(1).valueOrOne(); + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); if (!ConstMask) return nullptr; @@ -356,7 +355,7 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) { // * Narrow width by halfs excluding zero/undef lanes // * Vector incrementing address -> vector masked load Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(1)); if (!ConstMask) return nullptr; @@ -366,8 +365,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { if (ConstMask->isAllOnesValue()) if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) { auto *VecTy = cast<VectorType>(II.getType()); - const Align Alignment = - cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); + const Align Alignment = II.getParamAlign(0).valueOrOne(); LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr, Alignment, "load.scalar"); Value *Shuf = @@ -384,7 +382,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) { // * Narrow store width by halfs excluding zero/undef lanes // * Vector incrementing address -> vector masked store Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { - auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3)); + auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2)); if (!ConstMask) return nullptr; @@ -397,8 +395,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) { if (maskContainsAllOneOrUndef(ConstMask)) { - Align Alignment = - cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + Align Alignment = II.getParamAlign(1).valueOrOne(); StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment); S->copyMetadata(II); @@ -408,7 +405,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) { // scatter(vector, splat(ptr), splat(true)) -> store extract(vector, // lastlane), ptr if (ConstMask->isAllOnesValue()) { - Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); + Align Alignment = II.getParamAlign(1).valueOrOne(); VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType()); ElementCount VF = WideLoadTy->getElementCount(); Value *RunTimeVF = Builder.CreateElementCount(Builder.getInt32Ty(), VF); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 4c9b10a..9b9fe26 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -156,9 +156,9 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); Type *Ty = CI.getType(); - if (auto *SrcC = dyn_cast<Constant>(Src)) - if (Constant *Res = ConstantFoldCastOperand(CI.getOpcode(), SrcC, Ty, DL)) - return replaceInstUsesWith(CI, Res); + if (Value *Res = + simplifyCastInst(CI.getOpcode(), Src, Ty, SQ.getWithInstruction(&CI))) + return replaceInstUsesWith(CI, Res); // Try to eliminate a cast of a cast. if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast @@ -1643,33 +1643,46 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { /// Return a Constant* for the specified floating-point constant if it fits /// in the specified FP type without changing its value. -static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { +static bool fitsInFPType(APFloat F, const fltSemantics &Sem) { bool losesInfo; - APFloat F = CFP->getValueAPF(); (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo); return !losesInfo; } -static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { - if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) - return nullptr; // No constant folding of this. +static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F, + bool PreferBFloat) { // See if the value can be truncated to bfloat and then reextended. - if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat())) - return Type::getBFloatTy(CFP->getContext()); + if (PreferBFloat && fitsInFPType(F, APFloat::BFloat())) + return Type::getBFloatTy(Ctx); // See if the value can be truncated to half and then reextended. - if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf())) - return Type::getHalfTy(CFP->getContext()); + if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf())) + return Type::getHalfTy(Ctx); // See if the value can be truncated to float and then reextended. - if (fitsInFPType(CFP, APFloat::IEEEsingle())) - return Type::getFloatTy(CFP->getContext()); - if (CFP->getType()->isDoubleTy()) - return nullptr; // Won't shrink. - if (fitsInFPType(CFP, APFloat::IEEEdouble())) - return Type::getDoubleTy(CFP->getContext()); + if (fitsInFPType(F, APFloat::IEEEsingle())) + return Type::getFloatTy(Ctx); + if (&F.getSemantics() == &APFloat::IEEEdouble()) + return nullptr; // Won't shrink. + // See if the value can be truncated to double and then reextended. + if (fitsInFPType(F, APFloat::IEEEdouble())) + return Type::getDoubleTy(Ctx); // Don't try to shrink to various long double types. return nullptr; } +static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { + Type *Ty = CFP->getType(); + if (Ty->getScalarType()->isPPC_FP128Ty()) + return nullptr; // No constant folding of this. + + Type *ShrinkTy = + shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat); + if (ShrinkTy) + if (auto *VecTy = dyn_cast<VectorType>(Ty)) + ShrinkTy = VectorType::get(ShrinkTy, VecTy); + + return ShrinkTy; +} + // Determine if this is a vector of ConstantFPs and if so, return the minimal // type we can safely truncate all elements to. static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) { @@ -1720,10 +1733,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) { // Try to shrink scalable and fixed splat vectors. if (auto *FPC = dyn_cast<Constant>(V)) - if (isa<VectorType>(V->getType())) + if (auto *VTy = dyn_cast<VectorType>(V->getType())) if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue())) if (Type *T = shrinkFPConstant(Splat, PreferBFloat)) - return T; + return VectorType::get(T, VTy); // Try to shrink a vector of FP constants. This returns nullptr on scalable // vectors @@ -1796,10 +1809,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { Type *Ty = FPT.getType(); auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0)); if (BO && BO->hasOneUse()) { - Type *LHSMinType = - getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy()); - Type *RHSMinType = - getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy()); + bool PreferBFloat = Ty->getScalarType()->isBFloatTy(); + Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat); + Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat); unsigned OpWidth = BO->getType()->getFPMantissaWidth(); unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 07ad65c..fba1ccf 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1481,13 +1481,13 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2())); } - if (Cmp.isEquality() && Trunc->hasOneUse()) { + if (Cmp.isEquality() && (Trunc->hasOneUse() || Trunc->hasNoUnsignedWrap())) { // Canonicalize to a mask and wider compare if the wide type is suitable: // (trunc X to i8) == C --> (X & 0xff) == (zext C) if (!SrcTy->isVectorTy() && shouldChangeType(DstBits, SrcBits)) { Constant *Mask = ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcBits, DstBits)); - Value *And = Builder.CreateAnd(X, Mask); + Value *And = Trunc->hasNoUnsignedWrap() ? X : Builder.CreateAnd(X, Mask); Constant *WideC = ConstantInt::get(SrcTy, C.zext(SrcBits)); return new ICmpInst(Pred, And, WideC); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 09cb225..975498f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3757,6 +3757,10 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder, // (x < y) ? -1 : zext(x > y) // (x > y) ? 1 : sext(x != y) // (x > y) ? 1 : sext(x < y) +// (x == y) ? 0 : (x > y ? 1 : -1) +// (x == y) ? 0 : (x < y ? -1 : 1) +// Special case: x == C ? 0 : (x > C - 1 ? 1 : -1) +// Special case: x == C ? 0 : (x < C + 1 ? -1 : 1) // Into ucmp/scmp(x, y), where signedness is determined by the signedness // of the comparison in the original sequence. Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { @@ -3849,6 +3853,44 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) { } } + // Special cases with constants: x == C ? 0 : (x > C-1 ? 1 : -1) + if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero())) { + const APInt *C; + if (match(RHS, m_APInt(C))) { + CmpPredicate InnerPred; + Value *InnerRHS; + const APInt *InnerTV, *InnerFV; + if (match(FV, + m_Select(m_ICmp(InnerPred, m_Specific(LHS), m_Value(InnerRHS)), + m_APInt(InnerTV), m_APInt(InnerFV)))) { + + // x == C ? 0 : (x > C-1 ? 1 : -1) + if (ICmpInst::isGT(InnerPred) && InnerTV->isOne() && + InnerFV->isAllOnes()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanSubOne = IsSigned ? !C->isMinSignedValue() : !C->isMinValue(); + if (CanSubOne) { + APInt Cminus1 = *C - 1; + if (match(InnerRHS, m_SpecificInt(Cminus1))) + Replace = true; + } + } + + // x == C ? 0 : (x < C+1 ? -1 : 1) + if (ICmpInst::isLT(InnerPred) && InnerTV->isAllOnes() && + InnerFV->isOne()) { + IsSigned = ICmpInst::isSigned(InnerPred); + bool CanAddOne = IsSigned ? !C->isMaxSignedValue() : !C->isMaxValue(); + if (CanAddOne) { + APInt Cplus1 = *C + 1; + if (match(InnerRHS, m_SpecificInt(Cplus1))) + Replace = true; + } + } + } + } + } + Intrinsic::ID IID = IsSigned ? Intrinsic::scmp : Intrinsic::ucmp; if (Replace) return replaceInstUsesWith( @@ -4459,24 +4501,24 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Value *V = foldSelectIntoAddConstant(SI, Builder)) return replaceInstUsesWith(SI, V); - // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0) + // select(mask, mload(ptr,mask,0), 0) -> mload(ptr,mask,0) // Load inst is intentionally not checked for hasOneUse() if (match(FalseVal, m_Zero()) && - (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal), + (match(TrueVal, m_MaskedLoad(m_Value(), m_Specific(CondVal), m_CombineOr(m_Undef(), m_Zero()))) || - match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal), + match(TrueVal, m_MaskedGather(m_Value(), m_Specific(CondVal), m_CombineOr(m_Undef(), m_Zero()))))) { auto *MaskedInst = cast<IntrinsicInst>(TrueVal); - if (isa<UndefValue>(MaskedInst->getArgOperand(3))) - MaskedInst->setArgOperand(3, FalseVal /* Zero */); + if (isa<UndefValue>(MaskedInst->getArgOperand(2))) + MaskedInst->setArgOperand(2, FalseVal /* Zero */); return replaceInstUsesWith(SI, MaskedInst); } Value *Mask; if (match(TrueVal, m_Zero()) && - (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), + (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(Mask), m_CombineOr(m_Undef(), m_Zero()))) || - match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask), + match(FalseVal, m_MaskedGather(m_Value(), m_Value(Mask), m_CombineOr(m_Undef(), m_Zero())))) && (CondVal->getType() == Mask->getType())) { // We can remove the select by ensuring the load zeros all lanes the @@ -4489,8 +4531,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (CanMergeSelectIntoLoad) { auto *MaskedInst = cast<IntrinsicInst>(FalseVal); - if (isa<UndefValue>(MaskedInst->getArgOperand(3))) - MaskedInst->setArgOperand(3, TrueVal /* Zero */); + if (isa<UndefValue>(MaskedInst->getArgOperand(2))) + MaskedInst->setArgOperand(2, TrueVal /* Zero */); return replaceInstUsesWith(SI, MaskedInst); } } @@ -4629,14 +4671,13 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } Value *MaskedLoadPtr; - const APInt *MaskedLoadAlignment; if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr), - m_APInt(MaskedLoadAlignment), m_Specific(CondVal), m_Value())))) return replaceInstUsesWith( - SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr, - Align(MaskedLoadAlignment->getZExtValue()), - CondVal, FalseVal)); + SI, Builder.CreateMaskedLoad( + TrueVal->getType(), MaskedLoadPtr, + cast<IntrinsicInst>(TrueVal)->getParamAlign(0).valueOrOne(), + CondVal, FalseVal)); return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a330bb7..651e305 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1892,7 +1892,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // segfaults which didn't exist in the original program. APInt DemandedPtrs(APInt::getAllOnes(VWidth)), DemandedPassThrough(DemandedElts); - if (auto *CMask = dyn_cast<Constant>(II->getOperand(2))) { + if (auto *CMask = dyn_cast<Constant>(II->getOperand(1))) { for (unsigned i = 0; i < VWidth; i++) { if (Constant *CElt = CMask->getAggregateElement(i)) { if (CElt->isNullValue()) @@ -1905,7 +1905,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, if (II->getIntrinsicID() == Intrinsic::masked_gather) simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2); - simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3); + simplifyAndSetOp(II, 2, DemandedPassThrough, PoisonElts3); // Output elements are undefined if the element from both sources are. // TODO: can strengthen via mask as well. |