aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp23
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp60
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp4
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp69
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp4
5 files changed, 105 insertions, 55 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e1e24a9..dab200d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -289,12 +289,11 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
// * Narrow width by halfs excluding zero/undef lanes
Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
Value *LoadPtr = II.getArgOperand(0);
- const Align Alignment =
- cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+ const Align Alignment = II.getParamAlign(0).valueOrOne();
// If the mask is all ones or undefs, this is a plain vector load of the 1st
// argument.
- if (maskIsAllOneOrUndef(II.getArgOperand(2))) {
+ if (maskIsAllOneOrUndef(II.getArgOperand(1))) {
LoadInst *L = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
"unmaskedload");
L->copyMetadata(II);
@@ -308,7 +307,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
"unmaskedload");
LI->copyMetadata(II);
- return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
+ return Builder.CreateSelect(II.getArgOperand(1), LI, II.getArgOperand(2));
}
return nullptr;
@@ -319,8 +318,8 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
// * Narrow width by halfs excluding zero/undef lanes
Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
Value *StorePtr = II.getArgOperand(1);
- Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
- auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
+ Align Alignment = II.getParamAlign(1).valueOrOne();
+ auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
if (!ConstMask)
return nullptr;
@@ -356,7 +355,7 @@ Instruction *InstCombinerImpl::simplifyMaskedStore(IntrinsicInst &II) {
// * Narrow width by halfs excluding zero/undef lanes
// * Vector incrementing address -> vector masked load
Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
- auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
+ auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(1));
if (!ConstMask)
return nullptr;
@@ -366,8 +365,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
if (ConstMask->isAllOnesValue())
if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) {
auto *VecTy = cast<VectorType>(II.getType());
- const Align Alignment =
- cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+ const Align Alignment = II.getParamAlign(0).valueOrOne();
LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr,
Alignment, "load.scalar");
Value *Shuf =
@@ -384,7 +382,7 @@ Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
// * Narrow store width by halfs excluding zero/undef lanes
// * Vector incrementing address -> vector masked store
Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
- auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
+ auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
if (!ConstMask)
return nullptr;
@@ -397,8 +395,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
if (maskContainsAllOneOrUndef(ConstMask)) {
- Align Alignment =
- cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+ Align Alignment = II.getParamAlign(1).valueOrOne();
StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
Alignment);
S->copyMetadata(II);
@@ -408,7 +405,7 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
// scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
// lastlane), ptr
if (ConstMask->isAllOnesValue()) {
- Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+ Align Alignment = II.getParamAlign(1).valueOrOne();
VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
ElementCount VF = WideLoadTy->getElementCount();
Value *RunTimeVF = Builder.CreateElementCount(Builder.getInt32Ty(), VF);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4c9b10a..9b9fe26 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -156,9 +156,9 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
Value *Src = CI.getOperand(0);
Type *Ty = CI.getType();
- if (auto *SrcC = dyn_cast<Constant>(Src))
- if (Constant *Res = ConstantFoldCastOperand(CI.getOpcode(), SrcC, Ty, DL))
- return replaceInstUsesWith(CI, Res);
+ if (Value *Res =
+ simplifyCastInst(CI.getOpcode(), Src, Ty, SQ.getWithInstruction(&CI)))
+ return replaceInstUsesWith(CI, Res);
// Try to eliminate a cast of a cast.
if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast
@@ -1643,33 +1643,46 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
/// Return a Constant* for the specified floating-point constant if it fits
/// in the specified FP type without changing its value.
-static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
+static bool fitsInFPType(APFloat F, const fltSemantics &Sem) {
bool losesInfo;
- APFloat F = CFP->getValueAPF();
(void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
return !losesInfo;
}
-static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
- if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
- return nullptr; // No constant folding of this.
+static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F,
+ bool PreferBFloat) {
// See if the value can be truncated to bfloat and then reextended.
- if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
- return Type::getBFloatTy(CFP->getContext());
+ if (PreferBFloat && fitsInFPType(F, APFloat::BFloat()))
+ return Type::getBFloatTy(Ctx);
// See if the value can be truncated to half and then reextended.
- if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
- return Type::getHalfTy(CFP->getContext());
+ if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf()))
+ return Type::getHalfTy(Ctx);
// See if the value can be truncated to float and then reextended.
- if (fitsInFPType(CFP, APFloat::IEEEsingle()))
- return Type::getFloatTy(CFP->getContext());
- if (CFP->getType()->isDoubleTy())
- return nullptr; // Won't shrink.
- if (fitsInFPType(CFP, APFloat::IEEEdouble()))
- return Type::getDoubleTy(CFP->getContext());
+ if (fitsInFPType(F, APFloat::IEEEsingle()))
+ return Type::getFloatTy(Ctx);
+ if (&F.getSemantics() == &APFloat::IEEEdouble())
+ return nullptr; // Won't shrink.
+ // See if the value can be truncated to double and then reextended.
+ if (fitsInFPType(F, APFloat::IEEEdouble()))
+ return Type::getDoubleTy(Ctx);
// Don't try to shrink to various long double types.
return nullptr;
}
+static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
+ Type *Ty = CFP->getType();
+ if (Ty->getScalarType()->isPPC_FP128Ty())
+ return nullptr; // No constant folding of this.
+
+ Type *ShrinkTy =
+ shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat);
+ if (ShrinkTy)
+ if (auto *VecTy = dyn_cast<VectorType>(Ty))
+ ShrinkTy = VectorType::get(ShrinkTy, VecTy);
+
+ return ShrinkTy;
+}
+
// Determine if this is a vector of ConstantFPs and if so, return the minimal
// type we can safely truncate all elements to.
static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
@@ -1720,10 +1733,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
// Try to shrink scalable and fixed splat vectors.
if (auto *FPC = dyn_cast<Constant>(V))
- if (isa<VectorType>(V->getType()))
+ if (auto *VTy = dyn_cast<VectorType>(V->getType()))
if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
- return T;
+ return VectorType::get(T, VTy);
// Try to shrink a vector of FP constants. This returns nullptr on scalable
// vectors
@@ -1796,10 +1809,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
Type *Ty = FPT.getType();
auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
if (BO && BO->hasOneUse()) {
- Type *LHSMinType =
- getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
- Type *RHSMinType =
- getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
+ bool PreferBFloat = Ty->getScalarType()->isBFloatTy();
+ Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat);
+ Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat);
unsigned OpWidth = BO->getType()->getFPMantissaWidth();
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 07ad65c..fba1ccf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1481,13 +1481,13 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2()));
}
- if (Cmp.isEquality() && Trunc->hasOneUse()) {
+ if (Cmp.isEquality() && (Trunc->hasOneUse() || Trunc->hasNoUnsignedWrap())) {
// Canonicalize to a mask and wider compare if the wide type is suitable:
// (trunc X to i8) == C --> (X & 0xff) == (zext C)
if (!SrcTy->isVectorTy() && shouldChangeType(DstBits, SrcBits)) {
Constant *Mask =
ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcBits, DstBits));
- Value *And = Builder.CreateAnd(X, Mask);
+ Value *And = Trunc->hasNoUnsignedWrap() ? X : Builder.CreateAnd(X, Mask);
Constant *WideC = ConstantInt::get(SrcTy, C.zext(SrcBits));
return new ICmpInst(Pred, And, WideC);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 09cb225..975498f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3757,6 +3757,10 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder,
// (x < y) ? -1 : zext(x > y)
// (x > y) ? 1 : sext(x != y)
// (x > y) ? 1 : sext(x < y)
+// (x == y) ? 0 : (x > y ? 1 : -1)
+// (x == y) ? 0 : (x < y ? -1 : 1)
+// Special case: x == C ? 0 : (x > C - 1 ? 1 : -1)
+// Special case: x == C ? 0 : (x < C + 1 ? -1 : 1)
// Into ucmp/scmp(x, y), where signedness is determined by the signedness
// of the comparison in the original sequence.
Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
@@ -3849,6 +3853,44 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
}
}
+ // Special cases with constants: x == C ? 0 : (x > C-1 ? 1 : -1)
+ if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero())) {
+ const APInt *C;
+ if (match(RHS, m_APInt(C))) {
+ CmpPredicate InnerPred;
+ Value *InnerRHS;
+ const APInt *InnerTV, *InnerFV;
+ if (match(FV,
+ m_Select(m_ICmp(InnerPred, m_Specific(LHS), m_Value(InnerRHS)),
+ m_APInt(InnerTV), m_APInt(InnerFV)))) {
+
+ // x == C ? 0 : (x > C-1 ? 1 : -1)
+ if (ICmpInst::isGT(InnerPred) && InnerTV->isOne() &&
+ InnerFV->isAllOnes()) {
+ IsSigned = ICmpInst::isSigned(InnerPred);
+ bool CanSubOne = IsSigned ? !C->isMinSignedValue() : !C->isMinValue();
+ if (CanSubOne) {
+ APInt Cminus1 = *C - 1;
+ if (match(InnerRHS, m_SpecificInt(Cminus1)))
+ Replace = true;
+ }
+ }
+
+ // x == C ? 0 : (x < C+1 ? -1 : 1)
+ if (ICmpInst::isLT(InnerPred) && InnerTV->isAllOnes() &&
+ InnerFV->isOne()) {
+ IsSigned = ICmpInst::isSigned(InnerPred);
+ bool CanAddOne = IsSigned ? !C->isMaxSignedValue() : !C->isMaxValue();
+ if (CanAddOne) {
+ APInt Cplus1 = *C + 1;
+ if (match(InnerRHS, m_SpecificInt(Cplus1)))
+ Replace = true;
+ }
+ }
+ }
+ }
+ }
+
Intrinsic::ID IID = IsSigned ? Intrinsic::scmp : Intrinsic::ucmp;
if (Replace)
return replaceInstUsesWith(
@@ -4459,24 +4501,24 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Value *V = foldSelectIntoAddConstant(SI, Builder))
return replaceInstUsesWith(SI, V);
- // select(mask, mload(,,mask,0), 0) -> mload(,,mask,0)
+ // select(mask, mload(ptr,mask,0), 0) -> mload(ptr,mask,0)
// Load inst is intentionally not checked for hasOneUse()
if (match(FalseVal, m_Zero()) &&
- (match(TrueVal, m_MaskedLoad(m_Value(), m_Value(), m_Specific(CondVal),
+ (match(TrueVal, m_MaskedLoad(m_Value(), m_Specific(CondVal),
m_CombineOr(m_Undef(), m_Zero()))) ||
- match(TrueVal, m_MaskedGather(m_Value(), m_Value(), m_Specific(CondVal),
+ match(TrueVal, m_MaskedGather(m_Value(), m_Specific(CondVal),
m_CombineOr(m_Undef(), m_Zero()))))) {
auto *MaskedInst = cast<IntrinsicInst>(TrueVal);
- if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
- MaskedInst->setArgOperand(3, FalseVal /* Zero */);
+ if (isa<UndefValue>(MaskedInst->getArgOperand(2)))
+ MaskedInst->setArgOperand(2, FalseVal /* Zero */);
return replaceInstUsesWith(SI, MaskedInst);
}
Value *Mask;
if (match(TrueVal, m_Zero()) &&
- (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask),
+ (match(FalseVal, m_MaskedLoad(m_Value(), m_Value(Mask),
m_CombineOr(m_Undef(), m_Zero()))) ||
- match(FalseVal, m_MaskedGather(m_Value(), m_Value(), m_Value(Mask),
+ match(FalseVal, m_MaskedGather(m_Value(), m_Value(Mask),
m_CombineOr(m_Undef(), m_Zero())))) &&
(CondVal->getType() == Mask->getType())) {
// We can remove the select by ensuring the load zeros all lanes the
@@ -4489,8 +4531,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (CanMergeSelectIntoLoad) {
auto *MaskedInst = cast<IntrinsicInst>(FalseVal);
- if (isa<UndefValue>(MaskedInst->getArgOperand(3)))
- MaskedInst->setArgOperand(3, TrueVal /* Zero */);
+ if (isa<UndefValue>(MaskedInst->getArgOperand(2)))
+ MaskedInst->setArgOperand(2, TrueVal /* Zero */);
return replaceInstUsesWith(SI, MaskedInst);
}
}
@@ -4629,14 +4671,13 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
Value *MaskedLoadPtr;
- const APInt *MaskedLoadAlignment;
if (match(TrueVal, m_OneUse(m_MaskedLoad(m_Value(MaskedLoadPtr),
- m_APInt(MaskedLoadAlignment),
m_Specific(CondVal), m_Value()))))
return replaceInstUsesWith(
- SI, Builder.CreateMaskedLoad(TrueVal->getType(), MaskedLoadPtr,
- Align(MaskedLoadAlignment->getZExtValue()),
- CondVal, FalseVal));
+ SI, Builder.CreateMaskedLoad(
+ TrueVal->getType(), MaskedLoadPtr,
+ cast<IntrinsicInst>(TrueVal)->getParamAlign(0).valueOrOne(),
+ CondVal, FalseVal));
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index a330bb7..651e305 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1892,7 +1892,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
// segfaults which didn't exist in the original program.
APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
DemandedPassThrough(DemandedElts);
- if (auto *CMask = dyn_cast<Constant>(II->getOperand(2))) {
+ if (auto *CMask = dyn_cast<Constant>(II->getOperand(1))) {
for (unsigned i = 0; i < VWidth; i++) {
if (Constant *CElt = CMask->getAggregateElement(i)) {
if (CElt->isNullValue())
@@ -1905,7 +1905,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
if (II->getIntrinsicID() == Intrinsic::masked_gather)
simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2);
- simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3);
+ simplifyAndSetOp(II, 2, DemandedPassThrough, PoisonElts3);
// Output elements are undefined if the element from both sources are.
// TODO: can strengthen via mask as well.