Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp    |  34
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp       | 245
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp       | 392
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp    |  89
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineInternal.h      |   2
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp         |   8
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp      | 135
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp   |  15
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstructionCombining.cpp   |  26
9 files changed, 781 insertions, 165 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 3ddf182..ba5568b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3997,6 +3997,27 @@ static Value *foldOrUnsignedUMulOverflowICmp(BinaryOperator &I,
return nullptr;
}
+/// Fold select(X >s 0, 0, -X) | smax(X, 0) --> abs(X)
+/// select(X <s 0, -X, 0) | smax(X, 0) --> abs(X)
+static Value *FoldOrOfSelectSmaxToAbs(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *X;
+ Value *Sel;
+ if (match(&I,
+ m_c_Or(m_Value(Sel), m_OneUse(m_SMax(m_Value(X), m_ZeroInt()))))) {
+ auto NegX = m_Neg(m_Specific(X));
+ if (match(Sel, m_Select(m_SpecificICmp(ICmpInst::ICMP_SGT, m_Specific(X),
+ m_ZeroInt()),
+ m_ZeroInt(), NegX)) ||
+ match(Sel, m_Select(m_SpecificICmp(ICmpInst::ICMP_SLT, m_Specific(X),
+ m_ZeroInt()),
+ NegX, m_ZeroInt())))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::abs, X,
+ Builder.getFalse());
+ }
+ return nullptr;
+}
+
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
@@ -4545,6 +4566,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Value *V = SimplifyAddWithRemainder(I))
return replaceInstUsesWith(I, V);
+ if (Value *Res = FoldOrOfSelectSmaxToAbs(I, Builder))
+ return replaceInstUsesWith(I, Res);
+
return nullptr;
}
@@ -5072,9 +5096,17 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
return &I;
}
+ // not (bitcast (cmp A, B)) --> bitcast (!cmp A, B)
+ if (match(NotOp, m_OneUse(m_BitCast(m_Value(X)))) &&
+ match(X, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
+ cast<CmpInst>(X)->setPredicate(CmpInst::getInversePredicate(Pred));
+ return new BitCastInst(X, Ty);
+ }
+
// Move a 'not' ahead of casts of a bool to enable logic reduction:
// not (bitcast (sext i1 X)) --> bitcast (sext (not i1 X))
- if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) && X->getType()->isIntOrIntVectorTy(1)) {
+ if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) &&
+ X->getType()->isIntOrIntVectorTy(1)) {
Type *SextTy = cast<BitCastOperator>(NotOp)->getSrcTy();
Value *NotX = Builder.CreateNot(X);
Value *Sext = Builder.CreateSExt(NotX, SextTy);
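For illustration only (not part of the patch), the FoldOrOfSelectSmaxToAbs fold above corresponds to IR roughly like this, with made-up value names:

    ; before
    %cmp  = icmp sgt i32 %x, 0
    %neg  = sub i32 0, %x
    %sel  = select i1 %cmp, i32 0, i32 %neg
    %smax = call i32 @llvm.smax.i32(i32 %x, i32 0)
    %or   = or i32 %sel, %smax
    ; after
    %or   = call i32 @llvm.abs.i32(i32 %x, i1 false)

When %x is positive the select contributes 0 and the smax contributes %x; when %x is non-positive the select contributes -%x and the smax contributes 0, so the or always produces |%x|.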
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 92fca90..85602a5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -19,6 +19,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
@@ -736,42 +737,119 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
return nullptr;
}
-/// Convert a table lookup to shufflevector if the mask is constant.
-/// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in
-/// which case we could lower the shufflevector with rev64 instructions
-/// as it's actually a byte reverse.
-static Value *simplifyNeonTbl1(const IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
+/// Convert `tbl`/`tbx` intrinsics to shufflevector if the mask is constant, and
+/// at most two source operands are actually referenced.
+static Instruction *simplifyNeonTbl(IntrinsicInst &II, InstCombiner &IC,
+ bool IsExtension) {
// Bail out if the mask is not a constant.
- auto *C = dyn_cast<Constant>(II.getArgOperand(1));
+ auto *C = dyn_cast<Constant>(II.getArgOperand(II.arg_size() - 1));
if (!C)
return nullptr;
- auto *VecTy = cast<FixedVectorType>(II.getType());
- unsigned NumElts = VecTy->getNumElements();
+ auto *RetTy = cast<FixedVectorType>(II.getType());
+ unsigned NumIndexes = RetTy->getNumElements();
- // Only perform this transformation for <8 x i8> vector types.
- if (!VecTy->getElementType()->isIntegerTy(8) || NumElts != 8)
+ // Only perform this transformation for <8 x i8> and <16 x i8> vector types.
+ if (!RetTy->getElementType()->isIntegerTy(8) ||
+ (NumIndexes != 8 && NumIndexes != 16))
return nullptr;
- int Indexes[8];
+ // For tbx instructions, the first argument is the "fallback" vector, which
+ // has the same length as the mask and return type.
+ unsigned int StartIndex = (unsigned)IsExtension;
+ auto *SourceTy =
+ cast<FixedVectorType>(II.getArgOperand(StartIndex)->getType());
+ // Note that the element count of each source vector does *not* need to be the
+ // same as the element count of the return type and mask! All source vectors
+ // must have the same element count as each other, though.
+ unsigned NumElementsPerSource = SourceTy->getNumElements();
+
+ // There are no tbl/tbx intrinsics for which the destination size exceeds the
+ // source size. However, our definitions of the intrinsics, at least in
+ // IntrinsicsAArch64.td, allow for arbitrary destination vector sizes, so it
+ // *could* technically happen.
+ if (NumIndexes > NumElementsPerSource)
+ return nullptr;
+
+ // The tbl/tbx intrinsics take several source operands followed by a mask
+ // operand.
+ unsigned int NumSourceOperands = II.arg_size() - 1 - (unsigned)IsExtension;
+
+ // Map input operands to shuffle indices. This also helpfully deduplicates the
+ // input arguments, in case the same value is passed as an argument multiple
+ // times.
+ SmallDenseMap<Value *, unsigned, 2> ValueToShuffleSlot;
+ Value *ShuffleOperands[2] = {PoisonValue::get(SourceTy),
+ PoisonValue::get(SourceTy)};
- for (unsigned I = 0; I < NumElts; ++I) {
+ int Indexes[16];
+ for (unsigned I = 0; I < NumIndexes; ++I) {
Constant *COp = C->getAggregateElement(I);
- if (!COp || !isa<ConstantInt>(COp))
+ if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp)))
return nullptr;
- Indexes[I] = cast<ConstantInt>(COp)->getLimitedValue();
+ if (isa<UndefValue>(COp)) {
+ Indexes[I] = -1;
+ continue;
+ }
- // Make sure the mask indices are in range.
- if ((unsigned)Indexes[I] >= NumElts)
+ uint64_t Index = cast<ConstantInt>(COp)->getZExtValue();
+ // The index of the input argument that this index references (0 = first
+ // source argument, etc).
+ unsigned SourceOperandIndex = Index / NumElementsPerSource;
+ // The index of the element at that source operand.
+ unsigned SourceOperandElementIndex = Index % NumElementsPerSource;
+
+ Value *SourceOperand;
+ if (SourceOperandIndex >= NumSourceOperands) {
+ // This index is out of bounds. Map it to an index into either the fallback
+ // vector (tbx) or a vector of zeroes (tbl).
+ SourceOperandIndex = NumSourceOperands;
+ if (IsExtension) {
+ // For out-of-bounds indices in tbx, choose the `I`th element of the
+ // fallback.
+ SourceOperand = II.getArgOperand(0);
+ SourceOperandElementIndex = I;
+ } else {
+ // Otherwise, choose some element from the dummy vector of zeroes (we'll
+ // always choose the first).
+ SourceOperand = Constant::getNullValue(SourceTy);
+ SourceOperandElementIndex = 0;
+ }
+ } else {
+ SourceOperand = II.getArgOperand(SourceOperandIndex + StartIndex);
+ }
+
+ // The source operand may be the fallback vector, which may not have the
+ // same number of elements as the source vector. In that case, we *could*
+ // choose to extend its length with another shufflevector, but it's simpler
+ // to just bail instead.
+ if (cast<FixedVectorType>(SourceOperand->getType())->getNumElements() !=
+ NumElementsPerSource)
return nullptr;
+
+ // We now know the source operand referenced by this index. Make it a
+ // shufflevector operand, if it isn't already.
+ unsigned NumSlots = ValueToShuffleSlot.size();
+ // If this shuffle references more than two sources, it cannot be
+ // represented as a single shufflevector.
+ if (NumSlots == 2 && !ValueToShuffleSlot.contains(SourceOperand))
+ return nullptr;
+
+ auto [It, Inserted] =
+ ValueToShuffleSlot.try_emplace(SourceOperand, NumSlots);
+ if (Inserted)
+ ShuffleOperands[It->getSecond()] = SourceOperand;
+
+ unsigned RemappedIndex =
+ (It->getSecond() * NumElementsPerSource) + SourceOperandElementIndex;
+ Indexes[I] = RemappedIndex;
}
- auto *V1 = II.getArgOperand(0);
- auto *V2 = Constant::getNullValue(V1->getType());
- return Builder.CreateShuffleVector(V1, V2, ArrayRef(Indexes));
+ Value *Shuf = IC.Builder.CreateShuffleVector(
+ ShuffleOperands[0], ShuffleOperands[1], ArrayRef(Indexes, NumIndexes));
+ return IC.replaceInstUsesWith(II, Shuf);
}
// Returns true iff the 2 intrinsics have the same operands, limiting the
@@ -3076,6 +3154,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
case Intrinsic::ptrauth_auth:
case Intrinsic::ptrauth_resign: {
+ // We don't support this optimization on intrinsic calls with deactivation
+ // symbols, which are represented using operand bundles.
+ if (II->hasOperandBundles())
+ break;
+
// (sign|resign) + (auth|resign) can be folded by omitting the middle
// sign+auth component if the key and discriminator match.
bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign;
@@ -3087,6 +3170,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// whatever we replace this sequence with.
Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr;
if (const auto *CI = dyn_cast<CallBase>(Ptr)) {
+ // We don't support this optimization on intrinsic calls with deactivation
+ // symbols, which are represented using operand bundles.
+ if (CI->hasOperandBundles())
+ break;
+
BasePtr = CI->getArgOperand(0);
if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) {
if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc)
@@ -3109,9 +3197,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) {
auto *SignKey = cast<ConstantInt>(II->getArgOperand(3));
auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4));
- auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy());
+ auto *Null = ConstantPointerNull::get(Builder.getPtrTy());
auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey,
- SignDisc, SignAddrDisc);
+ SignDisc, /*AddrDisc=*/Null,
+ /*DeactivationSymbol=*/Null);
replaceInstUsesWith(
*II, ConstantExpr::getPointerCast(NewCPA, II->getType()));
return eraseInstFromFunction(*II);
@@ -3155,10 +3244,23 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return CallInst::Create(NewFn, CallArgs);
}
case Intrinsic::arm_neon_vtbl1:
+ case Intrinsic::arm_neon_vtbl2:
+ case Intrinsic::arm_neon_vtbl3:
+ case Intrinsic::arm_neon_vtbl4:
case Intrinsic::aarch64_neon_tbl1:
- if (Value *V = simplifyNeonTbl1(*II, Builder))
- return replaceInstUsesWith(*II, V);
- break;
+ case Intrinsic::aarch64_neon_tbl2:
+ case Intrinsic::aarch64_neon_tbl3:
+ case Intrinsic::aarch64_neon_tbl4:
+ return simplifyNeonTbl(*II, *this, /*IsExtension=*/false);
+ case Intrinsic::arm_neon_vtbx1:
+ case Intrinsic::arm_neon_vtbx2:
+ case Intrinsic::arm_neon_vtbx3:
+ case Intrinsic::arm_neon_vtbx4:
+ case Intrinsic::aarch64_neon_tbx1:
+ case Intrinsic::aarch64_neon_tbx2:
+ case Intrinsic::aarch64_neon_tbx3:
+ case Intrinsic::aarch64_neon_tbx4:
+ return simplifyNeonTbl(*II, *this, /*IsExtension=*/true);
case Intrinsic::arm_neon_vmulls:
case Intrinsic::arm_neon_vmullu:
@@ -3799,7 +3901,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (VecToReduceCount.isFixed()) {
unsigned VectorSize = VecToReduceCount.getFixedValue();
return BinaryOperator::CreateMul(
- Splat, ConstantInt::get(Splat->getType(), VectorSize));
+ Splat,
+ ConstantInt::get(Splat->getType(), VectorSize, /*IsSigned=*/false,
+ /*ImplicitTrunc=*/true));
}
}
}
@@ -4004,6 +4108,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
break;
}
+ case Intrinsic::experimental_get_vector_length: {
+ // get.vector.length(Cnt, MaxLanes) --> Cnt when Cnt <= MaxLanes
+ unsigned BitWidth =
+ std::max(II->getArgOperand(0)->getType()->getScalarSizeInBits(),
+ II->getType()->getScalarSizeInBits());
+ ConstantRange Cnt =
+ computeConstantRangeIncludingKnownBits(II->getArgOperand(0), false,
+ SQ.getWithInstruction(II))
+ .zextOrTrunc(BitWidth);
+ ConstantRange MaxLanes = cast<ConstantInt>(II->getArgOperand(1))
+ ->getValue()
+ .zextOrTrunc(Cnt.getBitWidth());
+ if (cast<ConstantInt>(II->getArgOperand(2))->isOne())
+ MaxLanes = MaxLanes.multiply(
+ getVScaleRange(II->getFunction(), Cnt.getBitWidth()));
+
+ if (Cnt.icmp(CmpInst::ICMP_ULE, MaxLanes))
+ return replaceInstUsesWith(
+ *II, Builder.CreateZExtOrTrunc(II->getArgOperand(0), II->getType()));
+ return nullptr;
+ }
default: {
// Handle target specific intrinsics
std::optional<Instruction *> V = targetInstCombineIntrinsic(*II);
@@ -4091,6 +4216,70 @@ Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) {
return visitCallBase(CBI);
}
+static Value *optimizeModularFormat(CallInst *CI, IRBuilderBase &B) {
+ if (!CI->hasFnAttr("modular-format"))
+ return nullptr;
+
+ SmallVector<StringRef> Args(
+ llvm::split(CI->getFnAttr("modular-format").getValueAsString(), ','));
+ // TODO: Make use of the first two arguments
+ unsigned FirstArgIdx;
+ [[maybe_unused]] bool Error;
+ Error = Args[2].getAsInteger(10, FirstArgIdx);
+ assert(!Error && "invalid first arg index");
+ --FirstArgIdx;
+ StringRef FnName = Args[3];
+ StringRef ImplName = Args[4];
+ ArrayRef<StringRef> AllAspects = ArrayRef<StringRef>(Args).drop_front(5);
+
+ if (AllAspects.empty())
+ return nullptr;
+
+ SmallVector<StringRef> NeededAspects;
+ for (StringRef Aspect : AllAspects) {
+ if (Aspect == "float") {
+ if (llvm::any_of(
+ llvm::make_range(std::next(CI->arg_begin(), FirstArgIdx),
+ CI->arg_end()),
+ [](Value *V) { return V->getType()->isFloatingPointTy(); }))
+ NeededAspects.push_back("float");
+ } else {
+ // Unknown aspects are always considered to be needed.
+ NeededAspects.push_back(Aspect);
+ }
+ }
+
+ if (NeededAspects.size() == AllAspects.size())
+ return nullptr;
+
+ Module *M = CI->getModule();
+ LLVMContext &Ctx = M->getContext();
+ Function *Callee = CI->getCalledFunction();
+ FunctionCallee ModularFn = M->getOrInsertFunction(
+ FnName, Callee->getFunctionType(),
+ Callee->getAttributes().removeFnAttribute(Ctx, "modular-format"));
+ CallInst *New = cast<CallInst>(CI->clone());
+ New->setCalledFunction(ModularFn);
+ New->removeFnAttr("modular-format");
+ B.Insert(New);
+
+ const auto ReferenceAspect = [&](StringRef Aspect) {
+ SmallString<20> Name = ImplName;
+ Name += '_';
+ Name += Aspect;
+ Function *RelocNoneFn =
+ Intrinsic::getOrInsertDeclaration(M, Intrinsic::reloc_none);
+ B.CreateCall(RelocNoneFn,
+ {MetadataAsValue::get(Ctx, MDString::get(Ctx, Name))});
+ };
+
+ llvm::sort(NeededAspects);
+ for (StringRef Request : NeededAspects)
+ ReferenceAspect(Request);
+
+ return New;
+}
+
Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) {
if (!CI->getCalledFunction()) return nullptr;
@@ -4112,6 +4301,10 @@ Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) {
++NumSimplified;
return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With);
}
+ if (Value *With = optimizeModularFormat(CI, Builder)) {
+ ++NumSimplified;
+ return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With);
+ }
return nullptr;
}
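As a rough sketch of the widened NEON table-lookup fold (a hypothetical example, not taken from the patch's tests), a tbl1 call with a constant in-range mask can now be rewritten as a single shufflevector:

    %r = call <8 x i8> @llvm.aarch64.neon.tbl1.v8i8(<16 x i8> %t,
             <8 x i8> <i8 7, i8 6, i8 5, i8 4, i8 3, i8 2, i8 1, i8 0>)
    ; becomes
    %r = shufflevector <16 x i8> %t, <16 x i8> poison,
             <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>

Out-of-range mask entries pick elements from the tbx fallback vector, or from a zero vector for tbl, as described in the comments above.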
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 614c6eb..0cd2c09 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -12,14 +12,21 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Instruction.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <iterator>
#include <optional>
using namespace llvm;
@@ -27,12 +34,19 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
-/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
-/// true for, actually insert the code to evaluate the expression.
-Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
- bool isSigned) {
+using EvaluatedMap = SmallDenseMap<Value *, Value *, 8>;
+
+static Value *EvaluateInDifferentTypeImpl(Value *V, Type *Ty, bool isSigned,
+ InstCombinerImpl &IC,
+ EvaluatedMap &Processed) {
+ // Since we now transform instructions with multiple users, we might reach
+ // the same node via multiple paths. We should not create a separate
+ // replacement for each path, though.
+ if (Value *Result = Processed.lookup(V))
+ return Result;
+
if (Constant *C = dyn_cast<Constant>(V))
- return ConstantFoldIntegerCast(C, Ty, isSigned, DL);
+ return ConstantFoldIntegerCast(C, Ty, isSigned, IC.getDataLayout());
// Otherwise, it must be an instruction.
Instruction *I = cast<Instruction>(V);
@@ -50,8 +64,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
case Instruction::Shl:
case Instruction::UDiv:
case Instruction::URem: {
- Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
- Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
+ Value *LHS = EvaluateInDifferentTypeImpl(I->getOperand(0), Ty, isSigned, IC,
+ Processed);
+ Value *RHS = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned, IC,
+ Processed);
Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
if (Opc == Instruction::LShr || Opc == Instruction::AShr)
Res->setIsExact(I->isExact());
@@ -72,8 +88,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
Opc == Instruction::SExt);
break;
case Instruction::Select: {
- Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
- Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned);
+ Value *True = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned,
+ IC, Processed);
+ Value *False = EvaluateInDifferentTypeImpl(I->getOperand(2), Ty, isSigned,
+ IC, Processed);
Res = SelectInst::Create(I->getOperand(0), True, False);
break;
}
@@ -81,8 +99,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
PHINode *OPN = cast<PHINode>(I);
PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues());
for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) {
- Value *V =
- EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned);
+ Value *V = EvaluateInDifferentTypeImpl(OPN->getIncomingValue(i), Ty,
+ isSigned, IC, Processed);
NPN->addIncoming(V, OPN->getIncomingBlock(i));
}
Res = NPN;
@@ -90,8 +108,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
}
case Instruction::FPToUI:
case Instruction::FPToSI:
- Res = CastInst::Create(
- static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
+ Res = CastInst::Create(static_cast<Instruction::CastOps>(Opc),
+ I->getOperand(0), Ty);
break;
case Instruction::Call:
if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
@@ -111,8 +129,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
auto *ScalarTy = cast<VectorType>(Ty)->getElementType();
auto *VTy = cast<VectorType>(I->getOperand(0)->getType());
auto *FixedTy = VectorType::get(ScalarTy, VTy->getElementCount());
- Value *Op0 = EvaluateInDifferentType(I->getOperand(0), FixedTy, isSigned);
- Value *Op1 = EvaluateInDifferentType(I->getOperand(1), FixedTy, isSigned);
+ Value *Op0 = EvaluateInDifferentTypeImpl(I->getOperand(0), FixedTy,
+ isSigned, IC, Processed);
+ Value *Op1 = EvaluateInDifferentTypeImpl(I->getOperand(1), FixedTy,
+ isSigned, IC, Processed);
Res = new ShuffleVectorInst(Op0, Op1,
cast<ShuffleVectorInst>(I)->getShuffleMask());
break;
@@ -123,7 +143,22 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
}
Res->takeName(I);
- return InsertNewInstWith(Res, I->getIterator());
+ Value *Result = IC.InsertNewInstWith(Res, I->getIterator());
+ // There is no need to keep track of the old value/new value relationship
+ // when we have only one user; we came here from that user and no-one
+ // else cares.
+ if (!V->hasOneUse())
+ Processed[V] = Result;
+
+ return Result;
+}
+
+/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
+/// true for, actually insert the code to evaluate the expression.
+Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
+ bool isSigned) {
+ EvaluatedMap Processed;
+ return EvaluateInDifferentTypeImpl(V, Ty, isSigned, *this, Processed);
}
Instruction::CastOps
@@ -227,9 +262,174 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
return nullptr;
}
+namespace {
+
+/// Helper class for evaluating whether a value can be computed in a different
+/// type without changing its value. Used by cast simplification transforms.
+class TypeEvaluationHelper {
+public:
+ /// Return true if we can evaluate the specified expression tree as type Ty
+ /// instead of its larger type, and arrive with the same value.
+ /// This is used by code that tries to eliminate truncates.
+ [[nodiscard]] static bool canEvaluateTruncated(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+
+ /// Determine if the specified value can be computed in the specified wider
+ /// type and produce the same low bits. If not, return false.
+ [[nodiscard]] static bool canEvaluateZExtd(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+
+ /// Return true if we can take the specified value and return it as type Ty
+ /// without inserting any new casts and without changing the value of the
+ /// common low bits.
+ [[nodiscard]] static bool canEvaluateSExtd(Value *V, Type *Ty);
+
+private:
+ /// Constants and extensions/truncates from the destination type are always
+ /// free to be evaluated in that type.
+ [[nodiscard]] static bool canAlwaysEvaluateInType(Value *V, Type *Ty);
+
+ /// Check if we traversed all the users of the multi-use values we've seen.
+ [[nodiscard]] bool allPendingVisited() const {
+ return llvm::all_of(Pending,
+ [this](Value *V) { return Visited.contains(V); });
+ }
+
+ /// A generic wrapper for canEvaluate* recursions to inject visitation
+ /// tracking and enforce correct multi-use value evaluations.
+ [[nodiscard]] bool
+ canEvaluate(Value *V, Type *Ty,
+ llvm::function_ref<bool(Value *, Type *Type)> Pred) {
+ if (canAlwaysEvaluateInType(V, Ty))
+ return true;
+
+ auto *I = dyn_cast<Instruction>(V);
+
+ if (I == nullptr)
+ return false;
+
+ // We insert false by default to return false when we encounter user loops.
+ const auto [It, Inserted] = Visited.insert({V, false});
+
+ // There are three possible cases for us having information on this value
+ // in the Visited map:
+ // 1. We properly checked it and concluded that we can evaluate it (true)
+ // 2. We properly checked it and concluded that we can't (false)
+ // 3. We started to check it, but during the recursive traversal we came
+ // back to it.
+ //
+ // For cases 1 and 2, we can safely return the stored result. For case 3, we
+ // can potentially have a situation where we can evaluate recursive user
+ // chains, but that can be quite tricky to do properly, so instead we
+ // return false.
+ //
+ // In any case, we should return whatever was there in the map to begin
+ // with.
+ if (!Inserted)
+ return It->getSecond();
+
+ // For single-user values we can easily decide whether they can be
+ // evaluated in a different type, since we came from that user. This is
+ // not as simple for multi-user values.
+ //
+ // In general, we have the following case (inverted control-flow, users are
+ // at the top):
+ //
+ // Cast %A
+ // ____|
+ // /
+ // %A = Use %B, %C
+ // ________| |
+ // / |
+ // %B = Use %D |
+ // ________| |
+ // / |
+ // %D = Use %C |
+ // ________|___|
+ // /
+ // %C = ...
+ //
+ // In this case, when we check %A, %B and %D, we are confident that we can
+ // make the decision here and now, since we came from their only users.
+ //
+ // For %C, it is harder. We come there twice, and when we come the first
+ // time, it's hard to tell if we will visit the second user (technically
+ // it's not hard, but we might need a lot of repetitive checks with non-zero
+ // cost).
+ //
+ // In the case above, we are allowed to evaluate %C in a different type
+ // because all of its users were part of the traversal.
+ //
+ // In the following case, however, we can't make this conclusion:
+ //
+ // Cast %A
+ // ____|
+ // /
+ // %A = Use %B, %C
+ // ________| |
+ // / |
+ // %B = Use %D |
+ // ________| |
+ // / |
+ // %D = Use %C |
+ // | |
+ // foo(%C) | | <- never traversing foo(%C)
+ // ________|___|
+ // /
+ // %C = ...
+ //
+ // In this case, we still can evaluate %C in a different type, but we'd need
+ // to create a copy of the original %C to be used in foo(%C). Such
+ // duplication might not be profitable.
+ //
+ // For this reason, we collect all users of the multi-user values, mark
+ // them as "pending", and defer this decision to the very end. When we are
+ // done and ready to give a positive verdict, we double-check all
+ // of the pending users and ensure that we visited them. The
+ // allPendingVisited predicate checks exactly that.
+ if (!I->hasOneUse())
+ llvm::append_range(Pending, I->users());
+
+ const bool Result = Pred(V, Ty);
+ // We have to set the result this way and not via It because Pred is
+ // recursive and it is very likely that we grew Visited and invalidated It.
+ Visited[V] = Result;
+ return Result;
+ }
+
+ /// Filter out values that we can not evaluate in the destination type for
+ /// free.
+ [[nodiscard]] bool canNotEvaluateInType(Value *V, Type *Ty);
+
+ [[nodiscard]] bool canEvaluateTruncatedImpl(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateTruncatedPred(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateZExtdImpl(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI);
+ [[nodiscard]] bool canEvaluateSExtdImpl(Value *V, Type *Ty);
+ [[nodiscard]] bool canEvaluateSExtdPred(Value *V, Type *Ty);
+
+ /// A bookkeeping map to memoize decisions already made for traversed
+ /// values.
+ SmallDenseMap<Value *, bool, 8> Visited;
+
+ /// A list of pending values to check in the end.
+ SmallVector<Value *, 8> Pending;
+};
+
+} // anonymous namespace
+
/// Constants and extensions/truncates from the destination type are always
/// free to be evaluated in that type. This is a helper for canEvaluate*.
-static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canAlwaysEvaluateInType(Value *V, Type *Ty) {
if (isa<Constant>(V))
return match(V, m_ImmConstant());
@@ -243,7 +443,7 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
/// Filter out values that we can not evaluate in the destination type for free.
/// This is a helper for canEvaluate*.
-static bool canNotEvaluateInType(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canNotEvaluateInType(Value *V, Type *Ty) {
if (!isa<Instruction>(V))
return true;
// We don't extend or shrink something that has multiple uses -- doing so
@@ -265,13 +465,27 @@ static bool canNotEvaluateInType(Value *V, Type *Ty) {
///
/// This function works on both vectors and scalars.
///
-static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
- Instruction *CxtI) {
- if (canAlwaysEvaluateInType(V, Ty))
- return true;
- if (canNotEvaluateInType(V, Ty))
- return false;
+bool TypeEvaluationHelper::canEvaluateTruncated(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
+ TypeEvaluationHelper TYH;
+ return TYH.canEvaluateTruncatedImpl(V, Ty, IC, CxtI) &&
+ // We need to check whether we visited all users of multi-user values,
+ // and we have to do it at the very end, outside of the recursion.
+ TYH.allPendingVisited();
+}
+bool TypeEvaluationHelper::canEvaluateTruncatedImpl(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
+ return canEvaluate(V, Ty, [this, &IC, CxtI](Value *V, Type *Ty) {
+ return canEvaluateTruncatedPred(V, Ty, IC, CxtI);
+ });
+}
+
+bool TypeEvaluationHelper::canEvaluateTruncatedPred(Value *V, Type *Ty,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
auto *I = cast<Instruction>(V);
Type *OrigTy = V->getType();
switch (I->getOpcode()) {
@@ -282,8 +496,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
case Instruction::Or:
case Instruction::Xor:
// These operators can all arbitrarily be extended or truncated.
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
case Instruction::UDiv:
case Instruction::URem: {
@@ -296,8 +510,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// based on later context may introduce a trap.
if (IC.MaskedValueIsZero(I->getOperand(0), Mask, I) &&
IC.MaskedValueIsZero(I->getOperand(1), Mask, I)) {
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, I) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, I);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
break;
}
@@ -308,8 +522,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
KnownBits AmtKnownBits =
llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
if (AmtKnownBits.getMaxValue().ult(BitWidth))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::LShr: {
@@ -329,12 +543,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) {
auto DemandedBits = Trunc->getType()->getScalarSizeInBits();
if ((MaxShiftAmt + DemandedBits).ule(BitWidth))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
if (IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, CxtI))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
}
break;
}
@@ -351,8 +565,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
unsigned ShiftedBits = OrigBitWidth - BitWidth;
if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), CxtI))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::Trunc:
@@ -365,18 +579,18 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
return true;
case Instruction::Select: {
SelectInst *SI = cast<SelectInst>(I);
- return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) &&
- canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(SI->getTrueValue(), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(SI->getFalseValue(), Ty, IC, CxtI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
- // get into trouble with cyclic PHIs here because we only consider
- // instructions with a single use.
+ // get into trouble with cyclic PHIs here because canEvaluate handles use
+ // chain loops.
PHINode *PN = cast<PHINode>(I);
- for (Value *IncValue : PN->incoming_values())
- if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI))
- return false;
- return true;
+ return llvm::all_of(
+ PN->incoming_values(), [this, Ty, &IC, CxtI](Value *IncValue) {
+ return canEvaluateTruncatedImpl(IncValue, Ty, IC, CxtI);
+ });
}
case Instruction::FPToUI:
case Instruction::FPToSI: {
@@ -385,14 +599,14 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
// that did not exist in the original code.
Type *InputTy = I->getOperand(0)->getType()->getScalarType();
const fltSemantics &Semantics = InputTy->getFltSemantics();
- uint32_t MinBitWidth =
- APFloatBase::semanticsIntSizeInBits(Semantics,
- I->getOpcode() == Instruction::FPToSI);
+ uint32_t MinBitWidth = APFloatBase::semanticsIntSizeInBits(
+ Semantics, I->getOpcode() == Instruction::FPToSI);
return Ty->getScalarSizeInBits() >= MinBitWidth;
}
case Instruction::ShuffleVector:
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
- canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
+ return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI);
+
default:
// TODO: Can handle more cases here.
break;
@@ -767,7 +981,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
// expression tree to something weird like i93 unless the source is also
// strange.
if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
- canEvaluateTruncated(Src, DestTy, *this, &Trunc)) {
+ TypeEvaluationHelper::canEvaluateTruncated(Src, DestTy, *this, &Trunc)) {
// If this cast is a truncate, evaluting in a different type always
// eliminates the cast, so it is always a win.
@@ -788,7 +1002,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
if (DestWidth * 2 < SrcWidth) {
auto *NewDestTy = DestITy->getExtendedType();
if (shouldChangeType(SrcTy, NewDestTy) &&
- canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) {
+ TypeEvaluationHelper::canEvaluateTruncated(Src, NewDestTy, *this,
+ &Trunc)) {
LLVM_DEBUG(
dbgs() << "ICE: EvaluateInDifferentType converting expression type"
" to reduce the width of operand of"
@@ -1104,11 +1319,22 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp,
/// clear the top bits anyway, doing this has no extra cost.
///
/// This function works on both vectors and scalars.
-static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
- InstCombinerImpl &IC, Instruction *CxtI) {
+bool TypeEvaluationHelper::canEvaluateZExtd(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
+ TypeEvaluationHelper TYH;
+ return TYH.canEvaluateZExtdImpl(V, Ty, BitsToClear, IC, CxtI);
+}
+bool TypeEvaluationHelper::canEvaluateZExtdImpl(Value *V, Type *Ty,
+ unsigned &BitsToClear,
+ InstCombinerImpl &IC,
+ Instruction *CxtI) {
BitsToClear = 0;
if (canAlwaysEvaluateInType(V, Ty))
return true;
+ // We keep the one-user limit for the ZExt transform because this
+ // predicate returns two values: the predicate result and BitsToClear.
if (canNotEvaluateInType(V, Ty))
return false;
@@ -1125,8 +1351,8 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
case Instruction::Add:
case Instruction::Sub:
case Instruction::Mul:
- if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) ||
- !canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI))
+ if (!canEvaluateZExtdImpl(I->getOperand(0), Ty, BitsToClear, IC, CxtI) ||
+ !canEvaluateZExtdImpl(I->getOperand(1), Ty, Tmp, IC, CxtI))
return false;
// These can all be promoted if neither operand has 'bits to clear'.
if (BitsToClear == 0 && Tmp == 0)
@@ -1157,7 +1383,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
// upper bits we can reduce BitsToClear by the shift amount.
uint64_t ShiftAmt;
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
- if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
+ if (!canEvaluateZExtdImpl(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
return false;
BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
return true;
@@ -1169,7 +1395,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
// ultimate 'and' to clear out the high zero bits we're clearing out though.
uint64_t ShiftAmt;
if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) {
- if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
+ if (!canEvaluateZExtdImpl(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
return false;
BitsToClear += ShiftAmt;
if (BitsToClear > V->getType()->getScalarSizeInBits())
@@ -1180,8 +1406,8 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
return false;
}
case Instruction::Select:
- if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) ||
- !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) ||
+ if (!canEvaluateZExtdImpl(I->getOperand(1), Ty, Tmp, IC, CxtI) ||
+ !canEvaluateZExtdImpl(I->getOperand(2), Ty, BitsToClear, IC, CxtI) ||
// TODO: If important, we could handle the case when the BitsToClear are
// known zero in the disagreeing side.
Tmp != BitsToClear)
@@ -1193,10 +1419,11 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
// get into trouble with cyclic PHIs here because we only consider
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
- if (!canEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI))
+ if (!canEvaluateZExtdImpl(PN->getIncomingValue(0), Ty, BitsToClear, IC,
+ CxtI))
return false;
for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
- if (!canEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) ||
+ if (!canEvaluateZExtdImpl(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) ||
// TODO: If important, we could handle the case when the BitsToClear
// are known zero in the disagreeing input.
Tmp != BitsToClear)
@@ -1237,7 +1464,8 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
// Try to extend the entire expression tree to the wide destination type.
unsigned BitsToClear;
if (shouldChangeType(SrcTy, DestTy) &&
- canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &Zext)) {
+ TypeEvaluationHelper::canEvaluateZExtd(Src, DestTy, BitsToClear, *this,
+ &Zext)) {
assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
"Can't clear more bits than in SrcTy");
@@ -1455,13 +1683,20 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp,
///
/// This function works on both vectors and scalars.
///
-static bool canEvaluateSExtd(Value *V, Type *Ty) {
+bool TypeEvaluationHelper::canEvaluateSExtd(Value *V, Type *Ty) {
+ TypeEvaluationHelper TYH;
+ return TYH.canEvaluateSExtdImpl(V, Ty) && TYH.allPendingVisited();
+}
+
+bool TypeEvaluationHelper::canEvaluateSExtdImpl(Value *V, Type *Ty) {
+ return canEvaluate(V, Ty, [this](Value *V, Type *Ty) {
+ return canEvaluateSExtdPred(V, Ty);
+ });
+}
+
+bool TypeEvaluationHelper::canEvaluateSExtdPred(Value *V, Type *Ty) {
assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() &&
"Can't sign extend type to a smaller type");
- if (canAlwaysEvaluateInType(V, Ty))
- return true;
- if (canNotEvaluateInType(V, Ty))
- return false;
auto *I = cast<Instruction>(V);
switch (I->getOpcode()) {
@@ -1476,23 +1711,24 @@ static bool canEvaluateSExtd(Value *V, Type *Ty) {
case Instruction::Sub:
case Instruction::Mul:
// These operators can all arbitrarily be extended if their inputs can.
- return canEvaluateSExtd(I->getOperand(0), Ty) &&
- canEvaluateSExtd(I->getOperand(1), Ty);
+ return canEvaluateSExtdImpl(I->getOperand(0), Ty) &&
+ canEvaluateSExtdImpl(I->getOperand(1), Ty);
- //case Instruction::Shl: TODO
- //case Instruction::LShr: TODO
+ // case Instruction::Shl: TODO
+ // case Instruction::LShr: TODO
case Instruction::Select:
- return canEvaluateSExtd(I->getOperand(1), Ty) &&
- canEvaluateSExtd(I->getOperand(2), Ty);
+ return canEvaluateSExtdImpl(I->getOperand(1), Ty) &&
+ canEvaluateSExtdImpl(I->getOperand(2), Ty);
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
- // get into trouble with cyclic PHIs here because we only consider
- // instructions with a single use.
+ // get into trouble with cyclic PHIs here because canEvaluate handles use
+ // chain loops.
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
- if (!canEvaluateSExtd(IncValue, Ty)) return false;
+ if (!canEvaluateSExtdImpl(IncValue, Ty))
+ return false;
return true;
}
default:
@@ -1533,7 +1769,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
if (TruncSrc->getType()->getScalarSizeInBits() > DestBitSize)
ShouldExtendExpression = false;
if (ShouldExtendExpression && shouldChangeType(SrcTy, DestTy) &&
- canEvaluateSExtd(Src, DestTy)) {
+ TypeEvaluationHelper::canEvaluateSExtd(Src, DestTy)) {
// Okay, we can transform this! Insert the new expression now.
LLVM_DEBUG(
dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -1548,7 +1784,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
return replaceInstUsesWith(Sext, Res);
// We need to emit a shl + ashr to do the sign extend.
- Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize);
+ Value *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize);
return BinaryOperator::CreateAShr(Builder.CreateShl(Res, ShAmt, "sext"),
ShAmt);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fba1ccf..abf4381 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1465,20 +1465,24 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
ConstantInt::get(V->getType(), 1));
}
- // TODO: Handle any shifted constant by subtracting trailing zeros.
// TODO: Handle non-equality predicates.
Value *Y;
- if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) {
- // (trunc (1 << Y) to iN) == 0 --> Y u>= N
- // (trunc (1 << Y) to iN) != 0 --> Y u< N
+ const APInt *Pow2;
+ if (Cmp.isEquality() && match(X, m_Shl(m_Power2(Pow2), m_Value(Y))) &&
+ DstBits > Pow2->logBase2()) {
+ // (trunc (Pow2 << Y) to iN) == 0 --> Y u>= N - log2(Pow2)
+ // (trunc (Pow2 << Y) to iN) != 0 --> Y u< N - log2(Pow2)
+ // iff N > log2(Pow2)
if (C.isZero()) {
auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT;
- return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits));
+ return new ICmpInst(NewPred, Y,
+ ConstantInt::get(SrcTy, DstBits - Pow2->logBase2()));
}
- // (trunc (1 << Y) to iN) == 2**C --> Y == C
- // (trunc (1 << Y) to iN) != 2**C --> Y != C
+ // (trunc (Pow2 << Y) to iN) == 2**C --> Y == C - log2(Pow2)
+ // (trunc (Pow2 << Y) to iN) != 2**C --> Y != C - log2(Pow2)
if (C.isPowerOf2())
- return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2()));
+ return new ICmpInst(
+ Pred, Y, ConstantInt::get(SrcTy, C.logBase2() - Pow2->logBase2()));
}
if (Cmp.isEquality() && (Trunc->hasOneUse() || Trunc->hasNoUnsignedWrap())) {
@@ -2638,16 +2642,6 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
if (Shr->isExact())
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal));
- if (C.isZero()) {
- // == 0 is u< 1.
- if (Pred == CmpInst::ICMP_EQ)
- return new ICmpInst(CmpInst::ICMP_ULT, X,
- ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal)));
- else
- return new ICmpInst(CmpInst::ICMP_UGT, X,
- ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal) - 1));
- }
-
if (Shr->hasOneUse()) {
// Canonicalize the shift into an 'and':
// icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt)
@@ -3138,7 +3132,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
Value *Op0, *Op1;
Instruction *Ext0, *Ext1;
- const CmpInst::Predicate Pred = Cmp.getPredicate();
+ const CmpPredicate Pred = Cmp.getCmpPredicate();
if (match(Add,
m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))),
m_CombineAnd(m_Instruction(Ext1),
@@ -3173,22 +3167,29 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
- // are canonicalized to SGT/SLT/UGT/ULT.
- if ((Add->hasNoSignedWrap() &&
- (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) ||
- (Add->hasNoUnsignedWrap() &&
- (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) {
+ // have been canonicalized to SGT/SLT/UGT/ULT.
+ if (Add->hasNoUnsignedWrap() &&
+ (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) {
bool Overflow;
- APInt NewC =
- Cmp.isSigned() ? C.ssub_ov(*C2, Overflow) : C.usub_ov(*C2, Overflow);
+ APInt NewC = C.usub_ov(*C2, Overflow);
// If there is overflow, the result must be true or false.
- // TODO: Can we assert there is no overflow because InstSimplify always
- // handles those cases?
if (!Overflow)
// icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2)
return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
}
+ CmpInst::Predicate ChosenPred = Pred.getPreferredSignedPredicate();
+
+ if (Add->hasNoSignedWrap() &&
+ (ChosenPred == ICmpInst::ICMP_SGT || ChosenPred == ICmpInst::ICMP_SLT)) {
+ bool Overflow;
+ APInt NewC = C.ssub_ov(*C2, Overflow);
+ if (!Overflow)
+ // icmp samesign ugt/ult (add nsw X, C2), C
+ // -> icmp sgt/slt X, (C - C2)
+ return new ICmpInst(ChosenPred, X, ConstantInt::get(Ty, NewC));
+ }
+
if (ICmpInst::isUnsigned(Pred) && Add->hasNoSignedWrap() &&
C.isNonNegative() && (C - *C2).isNonNegative() &&
computeConstantRange(X, /*ForSigned=*/true).add(*C2).isAllNonNegative())
@@ -5892,6 +5893,12 @@ static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
break;
+ case Instruction::Shl:
+ if (Inst->hasNoSignedWrap())
+ Offsets.emplace_back(Instruction::AShr, Inst->getOperand(1));
+ if (Inst->hasNoUnsignedWrap())
+ Offsets.emplace_back(Instruction::LShr, Inst->getOperand(1));
+ break;
case Instruction::Select:
if (AllowRecursion) {
collectOffsetOp(Inst->getOperand(1), Offsets, /*AllowRecursion=*/false);
@@ -5948,9 +5955,31 @@ static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);
auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
+ switch (BinOpc) {
+ // V = shl nsw X, RHS => X = ashr V, RHS
+ case Instruction::AShr: {
+ const APInt *CV, *CRHS;
+ if (!(match(V, m_APInt(CV)) && match(RHS, m_APInt(CRHS)) &&
+ CV->ashr(*CRHS).shl(*CRHS) == *CV) &&
+ !match(V, m_NSWShl(m_Value(), m_Specific(RHS))))
+ return nullptr;
+ break;
+ }
+ // V = shl nuw X, RHS => X = lshr V, RHS
+ case Instruction::LShr: {
+ const APInt *CV, *CRHS;
+ if (!(match(V, m_APInt(CV)) && match(RHS, m_APInt(CRHS)) &&
+ CV->lshr(*CRHS).shl(*CRHS) == *CV) &&
+ !match(V, m_NUWShl(m_Value(), m_Specific(RHS))))
+ return nullptr;
+ break;
+ }
+ default:
+ break;
+ }
+
Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
- // Avoid infinite loops by checking if RHS is an identity for the BinOp.
- if (!Simplified || Simplified == V)
+ if (!Simplified)
return nullptr;
// Reject constant expressions as they don't simplify things.
if (isa<Constant>(Simplified) && !match(Simplified, m_ImmConstant()))
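For reference, the generalized trunc-of-shift compare fold above covers cases like the following (illustrative constants, not taken from the patch):

    %shl = shl i32 4, %y           ; Pow2 = 4, log2(Pow2) = 2
    %t   = trunc i32 %shl to i8    ; N = 8
    %c   = icmp eq i8 %t, 0
    ; becomes
    %c   = icmp uge i32 %y, 6      ; N - log2(Pow2)

The low 8 bits of 4 << %y are all zero exactly when the set bit is shifted past bit 7, i.e. when %y u>= 6.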
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index d85e4f7..9bdd8cb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -479,7 +479,7 @@ private:
const Twine &NameStr = "",
InsertPosition InsertBefore = nullptr) {
auto *Sel = SelectInst::Create(C, S1, S2, NameStr, InsertBefore, nullptr);
- setExplicitlyUnknownBranchWeightsIfProfiled(*Sel, F, DEBUG_TYPE);
+ setExplicitlyUnknownBranchWeightsIfProfiled(*Sel, DEBUG_TYPE, &F);
return Sel;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 9815644..ba1865a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -698,8 +698,7 @@ static bool isSafeAndProfitableToSinkLoad(LoadInst *L) {
Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) {
LoadInst *FirstLI = cast<LoadInst>(PN.getIncomingValue(0));
- // Can't forward swifterror through a phi.
- if (FirstLI->getOperand(0)->isSwiftError())
+ if (!canReplaceOperandWithVariable(FirstLI, 0))
return nullptr;
// FIXME: This is overconservative; this transform is allowed in some cases
@@ -738,8 +737,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) {
LI->getPointerAddressSpace() != LoadAddrSpace)
return nullptr;
- // Can't forward swifterror through a phi.
- if (LI->getOperand(0)->isSwiftError())
+ if (!canReplaceOperandWithVariable(LI, 0))
return nullptr;
// We can't sink the load if the loaded value could be modified between
@@ -1007,7 +1005,7 @@ static bool PHIsEqualValue(PHINode *PN, Value *&NonPhiInVal,
return true;
// Don't scan crazily complex things.
- if (ValueEqualPHIs.size() == 16)
+ if (ValueEqualPHIs.size() >= 16)
return false;
// Scan the operands to see if they are either phi nodes or are equal to
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index f5130da..c00551b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1027,10 +1027,9 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
return Result;
}
-static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
- InstCombiner::BuilderTy &Builder) {
- if (!Cmp->hasOneUse())
- return nullptr;
+static Value *
+canonicalizeSaturatedAddUnsigned(ICmpInst *Cmp, Value *TVal, Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
// Match unsigned saturated add with constant.
Value *Cmp0 = Cmp->getOperand(0);
@@ -1130,6 +1129,95 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
return nullptr;
}
+static Value *canonicalizeSaturatedAddSigned(ICmpInst *Cmp, Value *TVal,
+ Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ // Match saturated add with constant.
+ Value *Cmp0 = Cmp->getOperand(0);
+ Value *Cmp1 = Cmp->getOperand(1);
+ ICmpInst::Predicate Pred = Cmp->getPredicate();
+ Value *X;
+ const APInt *C;
+
+ // Canonicalize INT_MAX to the true value of the select.
+ if (match(FVal, m_MaxSignedValue())) {
+ std::swap(TVal, FVal);
+ Pred = CmpInst::getInversePredicate(Pred);
+ }
+
+ if (!match(TVal, m_MaxSignedValue()))
+ return nullptr;
+
+ // sge of the maximum signed value is canonicalized to eq of the maximum
+ // signed value and requires special handling:
+ //   (a == INT_MAX) ? INT_MAX : a + 1 --> sadd.sat(a, 1)
+ if (Pred == ICmpInst::ICMP_EQ) {
+ if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) && Cmp1 == TVal) {
+ return Builder.CreateBinaryIntrinsic(
+ Intrinsic::sadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), 1));
+ }
+ return nullptr;
+ }
+
+ // (X > Y) ? INT_MAX : (X + C) --> sadd.sat(X, C)
+ // (X >= Y) ? INT_MAX : (X + C) --> sadd.sat(X, C)
+ // where Y is INT_MAX - C or INT_MAX - C - 1, and C > 0
+ if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) &&
+ isa<Constant>(Cmp1) &&
+ match(FVal, m_Add(m_Specific(Cmp0), m_StrictlyPositive(C)))) {
+ APInt IntMax =
+ APInt::getSignedMaxValue(Cmp1->getType()->getScalarSizeInBits());
+
+ // For SGE, try to flip to SGT to normalize the comparison constant.
+ if (Pred == ICmpInst::ICMP_SGE) {
+ if (auto Flipped = getFlippedStrictnessPredicateAndConstant(
+ Pred, cast<Constant>(Cmp1))) {
+ Pred = Flipped->first;
+ Cmp1 = Flipped->second;
+ }
+ }
+
+ // Check the pattern: X > INT_MAX - C or X > INT_MAX - C - 1
+ if (Pred == ICmpInst::ICMP_SGT &&
+ (match(Cmp1, m_SpecificIntAllowPoison(IntMax - *C)) ||
+ match(Cmp1, m_SpecificIntAllowPoison(IntMax - *C - 1))))
+ return Builder.CreateBinaryIntrinsic(
+ Intrinsic::sadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), *C));
+ }
+
+ // Canonicalize the predicate to less-than or less-than-or-equal.
+ if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
+ std::swap(Cmp0, Cmp1);
+ Pred = CmpInst::getSwappedPredicate(Pred);
+ }
+
+ if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SLE)
+ return nullptr;
+
+ if (match(Cmp0, m_NSWSub(m_MaxSignedValue(), m_Value(X))) &&
+ match(FVal, m_c_Add(m_Specific(X), m_Specific(Cmp1)))) {
+ // (INT_MAX - X s< Y) ? INT_MAX : (X + Y) --> sadd.sat(X, Y)
+ // (INT_MAX - X s< Y) ? INT_MAX : (Y + X) --> sadd.sat(X, Y)
+ return Builder.CreateBinaryIntrinsic(Intrinsic::sadd_sat, X, Cmp1);
+ }
+
+ return nullptr;
+}
+
+static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
+ InstCombiner::BuilderTy &Builder) {
+ if (!Cmp->hasOneUse())
+ return nullptr;
+
+ if (Value *V = canonicalizeSaturatedAddUnsigned(Cmp, TVal, FVal, Builder))
+ return V;
+
+ if (Value *V = canonicalizeSaturatedAddSigned(Cmp, TVal, FVal, Builder))
+ return V;
+
+ return nullptr;
+}
+
/// Try to match patterns with select and subtract as absolute difference.
static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
InstCombiner::BuilderTy &Builder) {
@@ -2979,14 +3067,10 @@ Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
"Op must be either i1 or vector of i1.");
if (SI.getCondition()->getType() != Op->getType())
return nullptr;
- if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL)) {
- Instruction *MDFrom = nullptr;
- if (!ProfcheckDisableMetadataFixes)
- MDFrom = &SI;
- return SelectInst::Create(
+ if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL))
+ return createSelectInstWithUnknownProfile(
Op, IsAnd ? V : ConstantInt::getTrue(Op->getType()),
- IsAnd ? ConstantInt::getFalse(Op->getType()) : V, "", nullptr, MDFrom);
- }
+ IsAnd ? ConstantInt::getFalse(Op->getType()) : V);
return nullptr;
}
@@ -3599,6 +3683,21 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
m_Not(m_Specific(SelCond->getTrueValue())));
if (MayNeedFreeze)
C = Builder.CreateFreeze(C);
+ if (!ProfcheckDisableMetadataFixes) {
+ Value *C2 = nullptr, *A2 = nullptr, *B2 = nullptr;
+ if (match(CondVal, m_LogicalAnd(m_Specific(C), m_Value(A2))) &&
+ SelCond) {
+ return SelectInst::Create(C, A, B, "", nullptr, SelCond);
+ } else if (match(FalseVal,
+ m_LogicalAnd(m_Not(m_Value(C2)), m_Value(B2))) &&
+ SelFVal) {
+ SelectInst *NewSI = SelectInst::Create(C, A, B, "", nullptr, SelFVal);
+ NewSI->swapProfMetadata();
+ return NewSI;
+ } else {
+ return createSelectInstWithUnknownProfile(C, A, B);
+ }
+ }
return SelectInst::Create(C, A, B);
}
@@ -3615,6 +3714,20 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
m_Not(m_Specific(SelFVal->getTrueValue())));
if (MayNeedFreeze)
C = Builder.CreateFreeze(C);
+ if (!ProfcheckDisableMetadataFixes) {
+ Value *C2 = nullptr, *A2 = nullptr, *B2 = nullptr;
+ if (match(CondVal, m_LogicalAnd(m_Not(m_Value(C2)), m_Value(A2))) &&
+ SelCond) {
+ SelectInst *NewSI = SelectInst::Create(C, B, A, "", nullptr, SelCond);
+ NewSI->swapProfMetadata();
+ return NewSI;
+ } else if (match(FalseVal, m_LogicalAnd(m_Specific(C), m_Value(B2))) &&
+ SelFVal) {
+ return SelectInst::Create(C, B, A, "", nullptr, SelFVal);
+ } else {
+ return createSelectInstWithUnknownProfile(C, B, A);
+ }
+ }
return SelectInst::Create(C, B, A);
}
}
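One concrete instance of the new signed saturated-add canonicalization (illustrative values, not from the patch's tests):

    %cmp = icmp sgt i32 %x, 2147483637      ; INT_MAX - 10
    %add = add i32 %x, 10
    %sel = select i1 %cmp, i32 2147483647, i32 %add
    ; becomes
    %sel = call i32 @llvm.sadd.sat.i32(i32 %x, i32 10)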
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 18a45c6..98e2d9e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -140,8 +140,8 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
Value *Elt = EI.getIndexOperand();
// If the operand is the PHI induction variable:
if (PHIInVal == PHIUser) {
- // Scalarize the binary operation. Its first operand is the
- // scalar PHI, and the second operand is extracted from the other
+ // Scalarize the binary operation. One operand is the
+ // scalar PHI, and the other is extracted from the other
// vector operand.
BinaryOperator *B0 = cast<BinaryOperator>(PHIUser);
unsigned opId = (B0->getOperand(0) == PN) ? 1 : 0;
@@ -149,9 +149,14 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
ExtractElementInst::Create(B0->getOperand(opId), Elt,
B0->getOperand(opId)->getName() + ".Elt"),
B0->getIterator());
- Value *newPHIUser = InsertNewInstWith(
- BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(),
- scalarPHI, Op, B0), B0->getIterator());
+ // Preserve the operand order of the binary operation to keep the
+ // semantics of non-commutative operations.
+ Value *FirstOp = (B0->getOperand(0) == PN) ? scalarPHI : Op;
+ Value *SecondOp = (B0->getOperand(0) == PN) ? Op : scalarPHI;
+ Value *newPHIUser =
+ InsertNewInstWith(BinaryOperator::CreateWithCopiedFlags(
+ B0->getOpcode(), FirstOp, SecondOp, B0),
+ B0->getIterator());
scalarPHI->addIncoming(newPHIUser, inBB);
} else {
// Scalarize PHI input:
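The operand-order fix in scalarizePHI matters for non-commutative operations; a hypothetical loop (not from the patch) illustrates it:

    %phi  = phi <4 x i32> [ %init, %entry ], [ %next, %loop ]
    %next = sub <4 x i32> %other, %phi      ; %phi is the second operand
    %e    = extractelement <4 x i32> %phi, i64 0
    ; the scalarized form must keep that order:
    ;   %next.e = sub i32 %other.e, %phi.e

Before this change the scalar PHI was always placed as the first operand, which would have produced sub i32 %phi.e, %other.e here.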
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 67f837c..c6de57c 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1758,6 +1758,9 @@ static Value *simplifyOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
m_Specific(Op), m_Value(V))) &&
isGuaranteedNotToBeUndefOrPoison(V)) {
// Pass
+ } else if (match(Op, m_ZExt(m_Specific(SI->getCondition())))) {
+ V = IsTrueArm ? ConstantInt::get(Op->getType(), 1)
+ : ConstantInt::getNullValue(Op->getType());
} else {
V = Op;
}
@@ -2261,11 +2264,11 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
}
Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
- if (!isa<Constant>(I.getOperand(1)))
- return nullptr;
+ bool IsOtherParamConst = isa<Constant>(I.getOperand(1));
if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {
- if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))
+ if (Instruction *NewSel =
+ FoldOpIntoSelect(I, Sel, false, !IsOtherParamConst))
return NewSel;
} else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
if (Instruction *NewPhi = foldOpIntoPhi(I, PN))
@@ -3370,9 +3373,9 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
DL.getAddressSizeInBits(AS) != DL.getPointerSizeInBits(AS);
bool Changed = false;
GEP.replaceUsesWithIf(Y, [&](Use &U) {
- bool ShouldReplace = isa<PtrToAddrInst>(U.getUser()) ||
- (!HasNonAddressBits &&
- isa<ICmpInst, PtrToIntInst>(U.getUser()));
+ bool ShouldReplace =
+ isa<PtrToAddrInst, ICmpInst>(U.getUser()) ||
+ (!HasNonAddressBits && isa<PtrToIntInst>(U.getUser()));
Changed |= ShouldReplace;
return ShouldReplace;
});
@@ -5624,8 +5627,15 @@ bool InstCombinerImpl::run() {
for (Use &U : I->uses()) {
User *User = U.getUser();
- if (User->isDroppable())
- continue;
+ if (User->isDroppable()) {
+ // Do not sink if there are dereferenceable assumes that would be
+ // removed.
+ auto II = dyn_cast<IntrinsicInst>(User);
+ if (II->getIntrinsicID() != Intrinsic::assume ||
+ !II->getOperandBundle("dereferenceable"))
+ continue;
+ }
+
if (NumUsers > MaxSinkNumUsers)
return std::nullopt;