Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
9 files changed, 781 insertions, 165 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 3ddf182..ba5568b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3997,6 +3997,27 @@ static Value *foldOrUnsignedUMulOverflowICmp(BinaryOperator &I, return nullptr; } +/// Fold select(X >s 0, 0, -X) | smax(X, 0) --> abs(X) +/// select(X <s 0, -X, 0) | smax(X, 0) --> abs(X) +static Value *FoldOrOfSelectSmaxToAbs(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X; + Value *Sel; + if (match(&I, + m_c_Or(m_Value(Sel), m_OneUse(m_SMax(m_Value(X), m_ZeroInt()))))) { + auto NegX = m_Neg(m_Specific(X)); + if (match(Sel, m_Select(m_SpecificICmp(ICmpInst::ICMP_SGT, m_Specific(X), + m_ZeroInt()), + m_ZeroInt(), NegX)) || + match(Sel, m_Select(m_SpecificICmp(ICmpInst::ICMP_SLT, m_Specific(X), + m_ZeroInt()), + NegX, m_ZeroInt()))) + return Builder.CreateBinaryIntrinsic(Intrinsic::abs, X, + Builder.getFalse()); + } + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -4545,6 +4566,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); + if (Value *Res = FoldOrOfSelectSmaxToAbs(I, Builder)) + return replaceInstUsesWith(I, Res); + return nullptr; } @@ -5072,9 +5096,17 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { return &I; } + // not (bitcast (cmp A, B) --> bitcast (!cmp A, B) + if (match(NotOp, m_OneUse(m_BitCast(m_Value(X)))) && + match(X, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) { + cast<CmpInst>(X)->setPredicate(CmpInst::getInversePredicate(Pred)); + return new BitCastInst(X, Ty); + } + // Move a 'not' ahead of casts of a bool to enable logic reduction: // not (bitcast (sext i1 X)) --> bitcast (sext (not i1 X)) - if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) && X->getType()->isIntOrIntVectorTy(1)) { + if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) && + X->getType()->isIntOrIntVectorTy(1)) { Type *SextTy = cast<BitCastOperator>(NotOp)->getSrcTy(); Value *NotX = Builder.CreateNot(X); Value *Sext = Builder.CreateSExt(NotX, SextTy); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 92fca90..85602a5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AssumptionCache.h" @@ -736,42 +737,119 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { return nullptr; } -/// Convert a table lookup to shufflevector if the mask is constant. -/// This could benefit tbl1 if the mask is { 7,6,5,4,3,2,1,0 }, in -/// which case we could lower the shufflevector with rev64 instructions -/// as it's actually a byte reverse. 
-static Value *simplifyNeonTbl1(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { +/// Convert `tbl`/`tbx` intrinsics to shufflevector if the mask is constant, and +/// at most two source operands are actually referenced. +static Instruction *simplifyNeonTbl(IntrinsicInst &II, InstCombiner &IC, + bool IsExtension) { // Bail out if the mask is not a constant. - auto *C = dyn_cast<Constant>(II.getArgOperand(1)); + auto *C = dyn_cast<Constant>(II.getArgOperand(II.arg_size() - 1)); if (!C) return nullptr; - auto *VecTy = cast<FixedVectorType>(II.getType()); - unsigned NumElts = VecTy->getNumElements(); + auto *RetTy = cast<FixedVectorType>(II.getType()); + unsigned NumIndexes = RetTy->getNumElements(); - // Only perform this transformation for <8 x i8> vector types. - if (!VecTy->getElementType()->isIntegerTy(8) || NumElts != 8) + // Only perform this transformation for <8 x i8> and <16 x i8> vector types. + if (!RetTy->getElementType()->isIntegerTy(8) || + (NumIndexes != 8 && NumIndexes != 16)) return nullptr; - int Indexes[8]; + // For tbx instructions, the first argument is the "fallback" vector, which + // has the same length as the mask and return type. + unsigned int StartIndex = (unsigned)IsExtension; + auto *SourceTy = + cast<FixedVectorType>(II.getArgOperand(StartIndex)->getType()); + // Note that the element count of each source vector does *not* need to be the + // same as the element count of the return type and mask! All source vectors + // must have the same element count as each other, though. + unsigned NumElementsPerSource = SourceTy->getNumElements(); + + // There are no tbl/tbx intrinsics for which the destination size exceeds the + // source size. However, our definitions of the intrinsics, at least in + // IntrinsicsAArch64.td, allow for arbitrary destination vector sizes, so it + // *could* technically happen. + if (NumIndexes > NumElementsPerSource) + return nullptr; + + // The tbl/tbx intrinsics take several source operands followed by a mask + // operand. + unsigned int NumSourceOperands = II.arg_size() - 1 - (unsigned)IsExtension; + + // Map input operands to shuffle indices. This also helpfully deduplicates the + // input arguments, in case the same value is passed as an argument multiple + // times. + SmallDenseMap<Value *, unsigned, 2> ValueToShuffleSlot; + Value *ShuffleOperands[2] = {PoisonValue::get(SourceTy), + PoisonValue::get(SourceTy)}; - for (unsigned I = 0; I < NumElts; ++I) { + int Indexes[16]; + for (unsigned I = 0; I < NumIndexes; ++I) { Constant *COp = C->getAggregateElement(I); - if (!COp || !isa<ConstantInt>(COp)) + if (!COp || (!isa<UndefValue>(COp) && !isa<ConstantInt>(COp))) return nullptr; - Indexes[I] = cast<ConstantInt>(COp)->getLimitedValue(); + if (isa<UndefValue>(COp)) { + Indexes[I] = -1; + continue; + } - // Make sure the mask indices are in range. - if ((unsigned)Indexes[I] >= NumElts) + uint64_t Index = cast<ConstantInt>(COp)->getZExtValue(); + // The index of the input argument that this index references (0 = first + // source argument, etc). + unsigned SourceOperandIndex = Index / NumElementsPerSource; + // The index of the element at that source operand. + unsigned SourceOperandElementIndex = Index % NumElementsPerSource; + + Value *SourceOperand; + if (SourceOperandIndex >= NumSourceOperands) { + // This index is out of bounds. Map it to index into either the fallback + // vector (tbx) or vector of zeroes (tbl). 
+ SourceOperandIndex = NumSourceOperands; + if (IsExtension) { + // For out-of-bounds indices in tbx, choose the `I`th element of the + // fallback. + SourceOperand = II.getArgOperand(0); + SourceOperandElementIndex = I; + } else { + // Otherwise, choose some element from the dummy vector of zeroes (we'll + // always choose the first). + SourceOperand = Constant::getNullValue(SourceTy); + SourceOperandElementIndex = 0; + } + } else { + SourceOperand = II.getArgOperand(SourceOperandIndex + StartIndex); + } + + // The source operand may be the fallback vector, which may not have the + // same number of elements as the source vector. In that case, we *could* + // choose to extend its length with another shufflevector, but it's simpler + // to just bail instead. + if (cast<FixedVectorType>(SourceOperand->getType())->getNumElements() != + NumElementsPerSource) return nullptr; + + // We now know the source operand referenced by this index. Make it a + // shufflevector operand, if it isn't already. + unsigned NumSlots = ValueToShuffleSlot.size(); + // This shuffle references more than two sources, and hence cannot be + // represented as a shufflevector. + if (NumSlots == 2 && !ValueToShuffleSlot.contains(SourceOperand)) + return nullptr; + + auto [It, Inserted] = + ValueToShuffleSlot.try_emplace(SourceOperand, NumSlots); + if (Inserted) + ShuffleOperands[It->getSecond()] = SourceOperand; + + unsigned RemappedIndex = + (It->getSecond() * NumElementsPerSource) + SourceOperandElementIndex; + Indexes[I] = RemappedIndex; } - auto *V1 = II.getArgOperand(0); - auto *V2 = Constant::getNullValue(V1->getType()); - return Builder.CreateShuffleVector(V1, V2, ArrayRef(Indexes)); + Value *Shuf = IC.Builder.CreateShuffleVector( + ShuffleOperands[0], ShuffleOperands[1], ArrayRef(Indexes, NumIndexes)); + return IC.replaceInstUsesWith(II, Shuf); } // Returns true iff the 2 intrinsics have the same operands, limiting the @@ -3076,6 +3154,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::ptrauth_auth: case Intrinsic::ptrauth_resign: { + // We don't support this optimization on intrinsic calls with deactivation + // symbols, which are represented using operand bundles. + if (II->hasOperandBundles()) + break; + // (sign|resign) + (auth|resign) can be folded by omitting the middle // sign+auth component if the key and discriminator match. bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign; @@ -3087,6 +3170,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // whatever we replace this sequence with. Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr; if (const auto *CI = dyn_cast<CallBase>(Ptr)) { + // We don't support this optimization on intrinsic calls with deactivation + // symbols, which are represented using operand bundles. 
+ if (CI->hasOperandBundles()) + break; + BasePtr = CI->getArgOperand(0); if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) { if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc) @@ -3109,9 +3197,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (NeedSign && isa<ConstantInt>(II->getArgOperand(4))) { auto *SignKey = cast<ConstantInt>(II->getArgOperand(3)); auto *SignDisc = cast<ConstantInt>(II->getArgOperand(4)); - auto *SignAddrDisc = ConstantPointerNull::get(Builder.getPtrTy()); + auto *Null = ConstantPointerNull::get(Builder.getPtrTy()); auto *NewCPA = ConstantPtrAuth::get(CPA->getPointer(), SignKey, - SignDisc, SignAddrDisc); + SignDisc, /*AddrDisc=*/Null, + /*DeactivationSymbol=*/Null); replaceInstUsesWith( *II, ConstantExpr::getPointerCast(NewCPA, II->getType())); return eraseInstFromFunction(*II); @@ -3155,10 +3244,23 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return CallInst::Create(NewFn, CallArgs); } case Intrinsic::arm_neon_vtbl1: + case Intrinsic::arm_neon_vtbl2: + case Intrinsic::arm_neon_vtbl3: + case Intrinsic::arm_neon_vtbl4: case Intrinsic::aarch64_neon_tbl1: - if (Value *V = simplifyNeonTbl1(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; + case Intrinsic::aarch64_neon_tbl2: + case Intrinsic::aarch64_neon_tbl3: + case Intrinsic::aarch64_neon_tbl4: + return simplifyNeonTbl(*II, *this, /*IsExtension=*/false); + case Intrinsic::arm_neon_vtbx1: + case Intrinsic::arm_neon_vtbx2: + case Intrinsic::arm_neon_vtbx3: + case Intrinsic::arm_neon_vtbx4: + case Intrinsic::aarch64_neon_tbx1: + case Intrinsic::aarch64_neon_tbx2: + case Intrinsic::aarch64_neon_tbx3: + case Intrinsic::aarch64_neon_tbx4: + return simplifyNeonTbl(*II, *this, /*IsExtension=*/true); case Intrinsic::arm_neon_vmulls: case Intrinsic::arm_neon_vmullu: @@ -3799,7 +3901,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (VecToReduceCount.isFixed()) { unsigned VectorSize = VecToReduceCount.getFixedValue(); return BinaryOperator::CreateMul( - Splat, ConstantInt::get(Splat->getType(), VectorSize)); + Splat, + ConstantInt::get(Splat->getType(), VectorSize, /*IsSigned=*/false, + /*ImplicitTrunc=*/true)); } } } @@ -4004,6 +4108,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::experimental_get_vector_length: { + // get.vector.length(Cnt, MaxLanes) --> Cnt when Cnt <= MaxLanes + unsigned BitWidth = + std::max(II->getArgOperand(0)->getType()->getScalarSizeInBits(), + II->getType()->getScalarSizeInBits()); + ConstantRange Cnt = + computeConstantRangeIncludingKnownBits(II->getArgOperand(0), false, + SQ.getWithInstruction(II)) + .zextOrTrunc(BitWidth); + ConstantRange MaxLanes = cast<ConstantInt>(II->getArgOperand(1)) + ->getValue() + .zextOrTrunc(Cnt.getBitWidth()); + if (cast<ConstantInt>(II->getArgOperand(2))->isOne()) + MaxLanes = MaxLanes.multiply( + getVScaleRange(II->getFunction(), Cnt.getBitWidth())); + + if (Cnt.icmp(CmpInst::ICMP_ULE, MaxLanes)) + return replaceInstUsesWith( + *II, Builder.CreateZExtOrTrunc(II->getArgOperand(0), II->getType())); + return nullptr; + } default: { // Handle target specific intrinsics std::optional<Instruction *> V = targetInstCombineIntrinsic(*II); @@ -4091,6 +4216,70 @@ Instruction *InstCombinerImpl::visitCallBrInst(CallBrInst &CBI) { return visitCallBase(CBI); } +static Value *optimizeModularFormat(CallInst *CI, IRBuilderBase &B) { + if (!CI->hasFnAttr("modular-format")) + return nullptr; + + SmallVector<StringRef> Args( + 
llvm::split(CI->getFnAttr("modular-format").getValueAsString(), ',')); + // TODO: Make use of the first two arguments + unsigned FirstArgIdx; + [[maybe_unused]] bool Error; + Error = Args[2].getAsInteger(10, FirstArgIdx); + assert(!Error && "invalid first arg index"); + --FirstArgIdx; + StringRef FnName = Args[3]; + StringRef ImplName = Args[4]; + ArrayRef<StringRef> AllAspects = ArrayRef<StringRef>(Args).drop_front(5); + + if (AllAspects.empty()) + return nullptr; + + SmallVector<StringRef> NeededAspects; + for (StringRef Aspect : AllAspects) { + if (Aspect == "float") { + if (llvm::any_of( + llvm::make_range(std::next(CI->arg_begin(), FirstArgIdx), + CI->arg_end()), + [](Value *V) { return V->getType()->isFloatingPointTy(); })) + NeededAspects.push_back("float"); + } else { + // Unknown aspects are always considered to be needed. + NeededAspects.push_back(Aspect); + } + } + + if (NeededAspects.size() == AllAspects.size()) + return nullptr; + + Module *M = CI->getModule(); + LLVMContext &Ctx = M->getContext(); + Function *Callee = CI->getCalledFunction(); + FunctionCallee ModularFn = M->getOrInsertFunction( + FnName, Callee->getFunctionType(), + Callee->getAttributes().removeFnAttribute(Ctx, "modular-format")); + CallInst *New = cast<CallInst>(CI->clone()); + New->setCalledFunction(ModularFn); + New->removeFnAttr("modular-format"); + B.Insert(New); + + const auto ReferenceAspect = [&](StringRef Aspect) { + SmallString<20> Name = ImplName; + Name += '_'; + Name += Aspect; + Function *RelocNoneFn = + Intrinsic::getOrInsertDeclaration(M, Intrinsic::reloc_none); + B.CreateCall(RelocNoneFn, + {MetadataAsValue::get(Ctx, MDString::get(Ctx, Name))}); + }; + + llvm::sort(NeededAspects); + for (StringRef Request : NeededAspects) + ReferenceAspect(Request); + + return New; +} + Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { if (!CI->getCalledFunction()) return nullptr; @@ -4112,6 +4301,10 @@ Instruction *InstCombinerImpl::tryOptimizeCall(CallInst *CI) { ++NumSimplified; return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); } + if (Value *With = optimizeModularFormat(CI, Builder)) { + ++NumSimplified; + return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); + } return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 614c6eb..0cd2c09 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -12,14 +12,21 @@ #include "InstCombineInternal.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <iterator> #include <optional> using namespace llvm; @@ -27,12 +34,19 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns -/// true for, actually insert the code to evaluate the expression. 
-Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, - bool isSigned) { +using EvaluatedMap = SmallDenseMap<Value *, Value *, 8>; + +static Value *EvaluateInDifferentTypeImpl(Value *V, Type *Ty, bool isSigned, + InstCombinerImpl &IC, + EvaluatedMap &Processed) { + // Since we cover transformation of instructions with multiple users, we might + // come to the same node via multiple paths. We should not create a + // replacement for every single one of them though. + if (Value *Result = Processed.lookup(V)) + return Result; + if (Constant *C = dyn_cast<Constant>(V)) - return ConstantFoldIntegerCast(C, Ty, isSigned, DL); + return ConstantFoldIntegerCast(C, Ty, isSigned, IC.getDataLayout()); // Otherwise, it must be an instruction. Instruction *I = cast<Instruction>(V); @@ -50,8 +64,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, case Instruction::Shl: case Instruction::UDiv: case Instruction::URem: { - Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned); - Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned); + Value *LHS = EvaluateInDifferentTypeImpl(I->getOperand(0), Ty, isSigned, IC, + Processed); + Value *RHS = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned, IC, + Processed); Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS); if (Opc == Instruction::LShr || Opc == Instruction::AShr) Res->setIsExact(I->isExact()); @@ -72,8 +88,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, Opc == Instruction::SExt); break; case Instruction::Select: { - Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned); - Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned); + Value *True = EvaluateInDifferentTypeImpl(I->getOperand(1), Ty, isSigned, + IC, Processed); + Value *False = EvaluateInDifferentTypeImpl(I->getOperand(2), Ty, isSigned, + IC, Processed); Res = SelectInst::Create(I->getOperand(0), True, False); break; } @@ -81,8 +99,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, PHINode *OPN = cast<PHINode>(I); PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues()); for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) { - Value *V = - EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned); + Value *V = EvaluateInDifferentTypeImpl(OPN->getIncomingValue(i), Ty, + isSigned, IC, Processed); NPN->addIncoming(V, OPN->getIncomingBlock(i)); } Res = NPN; @@ -90,8 +108,8 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, } case Instruction::FPToUI: case Instruction::FPToSI: - Res = CastInst::Create( - static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty); + Res = CastInst::Create(static_cast<Instruction::CastOps>(Opc), + I->getOperand(0), Ty); break; case Instruction::Call: if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { @@ -111,8 +129,10 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, auto *ScalarTy = cast<VectorType>(Ty)->getElementType(); auto *VTy = cast<VectorType>(I->getOperand(0)->getType()); auto *FixedTy = VectorType::get(ScalarTy, VTy->getElementCount()); - Value *Op0 = EvaluateInDifferentType(I->getOperand(0), FixedTy, isSigned); - Value *Op1 = EvaluateInDifferentType(I->getOperand(1), FixedTy, isSigned); + Value *Op0 = EvaluateInDifferentTypeImpl(I->getOperand(0), FixedTy, + isSigned, IC, Processed); + Value *Op1 = EvaluateInDifferentTypeImpl(I->getOperand(1), FixedTy, + isSigned, IC, Processed); Res = new 
ShuffleVectorInst(Op0, Op1, cast<ShuffleVectorInst>(I)->getShuffleMask()); break; @@ -123,7 +143,22 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, } Res->takeName(I); - return InsertNewInstWith(Res, I->getIterator()); + Value *Result = IC.InsertNewInstWith(Res, I->getIterator()); + // There is no need in keeping track of the old value/new value relationship + // when we have only one user, we came have here from that user and no-one + // else cares. + if (!V->hasOneUse()) + Processed[V] = Result; + + return Result; +} + +/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns +/// true for, actually insert the code to evaluate the expression. +Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, + bool isSigned) { + EvaluatedMap Processed; + return EvaluateInDifferentTypeImpl(V, Ty, isSigned, *this, Processed); } Instruction::CastOps @@ -227,9 +262,174 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { return nullptr; } +namespace { + +/// Helper class for evaluating whether a value can be computed in a different +/// type without changing its value. Used by cast simplification transforms. +class TypeEvaluationHelper { +public: + /// Return true if we can evaluate the specified expression tree as type Ty + /// instead of its larger type, and arrive with the same value. + /// This is used by code that tries to eliminate truncates. + [[nodiscard]] static bool canEvaluateTruncated(Value *V, Type *Ty, + InstCombinerImpl &IC, + Instruction *CxtI); + + /// Determine if the specified value can be computed in the specified wider + /// type and produce the same low bits. If not, return false. + [[nodiscard]] static bool canEvaluateZExtd(Value *V, Type *Ty, + unsigned &BitsToClear, + InstCombinerImpl &IC, + Instruction *CxtI); + + /// Return true if we can take the specified value and return it as type Ty + /// without inserting any new casts and without changing the value of the + /// common low bits. + [[nodiscard]] static bool canEvaluateSExtd(Value *V, Type *Ty); + +private: + /// Constants and extensions/truncates from the destination type are always + /// free to be evaluated in that type. + [[nodiscard]] static bool canAlwaysEvaluateInType(Value *V, Type *Ty); + + /// Check if we traversed all the users of the multi-use values we've seen. + [[nodiscard]] bool allPendingVisited() const { + return llvm::all_of(Pending, + [this](Value *V) { return Visited.contains(V); }); + } + + /// A generic wrapper for canEvaluate* recursions to inject visitation + /// tracking and enforce correct multi-use value evaluations. + [[nodiscard]] bool + canEvaluate(Value *V, Type *Ty, + llvm::function_ref<bool(Value *, Type *Type)> Pred) { + if (canAlwaysEvaluateInType(V, Ty)) + return true; + + auto *I = dyn_cast<Instruction>(V); + + if (I == nullptr) + return false; + + // We insert false by default to return false when we encounter user loops. + const auto [It, Inserted] = Visited.insert({V, false}); + + // There are three possible cases for us having information on this value + // in the Visited map: + // 1. We properly checked it and concluded that we can evaluate it (true) + // 2. We properly checked it and concluded that we can't (false) + // 3. We started to check it, but during the recursive traversal we came + // back to it. + // + // For cases 1 and 2, we can safely return the stored result. 
For case 3, we + // can potentially have a situation where we can evaluate recursive user + // chains, but that can be quite tricky to do properly and isntead, we + // return false. + // + // In any case, we should return whatever was there in the map to begin + // with. + if (!Inserted) + return It->getSecond(); + + // We can easily make a decision about single-user values whether they can + // be evaluated in a different type or not, we came from that user. This is + // not as simple for multi-user values. + // + // In general, we have the following case (inverted control-flow, users are + // at the top): + // + // Cast %A + // ____| + // / + // %A = Use %B, %C + // ________| | + // / | + // %B = Use %D | + // ________| | + // / | + // %D = Use %C | + // ________|___| + // / + // %C = ... + // + // In this case, when we check %A, %B and %D, we are confident that we can + // make the decision here and now, since we came from their only users. + // + // For %C, it is harder. We come there twice, and when we come the first + // time, it's hard to tell if we will visit the second user (technically + // it's not hard, but we might need a lot of repetitive checks with non-zero + // cost). + // + // In the case above, we are allowed to evaluate %C in different type + // because all of it users were part of the traversal. + // + // In the following case, however, we can't make this conclusion: + // + // Cast %A + // ____| + // / + // %A = Use %B, %C + // ________| | + // / | + // %B = Use %D | + // ________| | + // / | + // %D = Use %C | + // | | + // foo(%C) | | <- never traversing foo(%C) + // ________|___| + // / + // %C = ... + // + // In this case, we still can evaluate %C in a different type, but we'd need + // to create a copy of the original %C to be used in foo(%C). Such + // duplication might be not profitable. + // + // For this reason, we collect all users of the mult-user values and mark + // them as "pending" and defer this decision to the very end. When we are + // done and and ready to have a positive verdict, we should double-check all + // of the pending users and ensure that we visited them. allPendingVisited + // predicate checks exactly that. + if (!I->hasOneUse()) + llvm::append_range(Pending, I->users()); + + const bool Result = Pred(V, Ty); + // We have to set result this way and not via It because Pred is recursive + // and it is very likely that we grew Visited and invalidated It. + Visited[V] = Result; + return Result; + } + + /// Filter out values that we can not evaluate in the destination type for + /// free. + [[nodiscard]] bool canNotEvaluateInType(Value *V, Type *Ty); + + [[nodiscard]] bool canEvaluateTruncatedImpl(Value *V, Type *Ty, + InstCombinerImpl &IC, + Instruction *CxtI); + [[nodiscard]] bool canEvaluateTruncatedPred(Value *V, Type *Ty, + InstCombinerImpl &IC, + Instruction *CxtI); + [[nodiscard]] bool canEvaluateZExtdImpl(Value *V, Type *Ty, + unsigned &BitsToClear, + InstCombinerImpl &IC, + Instruction *CxtI); + [[nodiscard]] bool canEvaluateSExtdImpl(Value *V, Type *Ty); + [[nodiscard]] bool canEvaluateSExtdPred(Value *V, Type *Ty); + + /// A bookkeeping map to memorize an already made decision for a traversed + /// value. + SmallDenseMap<Value *, bool, 8> Visited; + + /// A list of pending values to check in the end. + SmallVector<Value *, 8> Pending; +}; + +} // anonymous namespace + /// Constants and extensions/truncates from the destination type are always /// free to be evaluated in that type. This is a helper for canEvaluate*. 
-static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { +bool TypeEvaluationHelper::canAlwaysEvaluateInType(Value *V, Type *Ty) { if (isa<Constant>(V)) return match(V, m_ImmConstant()); @@ -243,7 +443,7 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { /// Filter out values that we can not evaluate in the destination type for free. /// This is a helper for canEvaluate*. -static bool canNotEvaluateInType(Value *V, Type *Ty) { +bool TypeEvaluationHelper::canNotEvaluateInType(Value *V, Type *Ty) { if (!isa<Instruction>(V)) return true; // We don't extend or shrink something that has multiple uses -- doing so @@ -265,13 +465,27 @@ static bool canNotEvaluateInType(Value *V, Type *Ty) { /// /// This function works on both vectors and scalars. /// -static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, - Instruction *CxtI) { - if (canAlwaysEvaluateInType(V, Ty)) - return true; - if (canNotEvaluateInType(V, Ty)) - return false; +bool TypeEvaluationHelper::canEvaluateTruncated(Value *V, Type *Ty, + InstCombinerImpl &IC, + Instruction *CxtI) { + TypeEvaluationHelper TYH; + return TYH.canEvaluateTruncatedImpl(V, Ty, IC, CxtI) && + // We need to check whether we visited all users of multi-user values, + // and we have to do it at the very end, outside of the recursion. + TYH.allPendingVisited(); +} +bool TypeEvaluationHelper::canEvaluateTruncatedImpl(Value *V, Type *Ty, + InstCombinerImpl &IC, + Instruction *CxtI) { + return canEvaluate(V, Ty, [this, &IC, CxtI](Value *V, Type *Ty) { + return canEvaluateTruncatedPred(V, Ty, IC, CxtI); + }); +} + +bool TypeEvaluationHelper::canEvaluateTruncatedPred(Value *V, Type *Ty, + InstCombinerImpl &IC, + Instruction *CxtI) { auto *I = cast<Instruction>(V); Type *OrigTy = V->getType(); switch (I->getOpcode()) { @@ -282,8 +496,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, case Instruction::Or: case Instruction::Xor: // These operators can all arbitrarily be extended or truncated. - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI); case Instruction::UDiv: case Instruction::URem: { @@ -296,8 +510,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, // based on later context may introduce a trap. 
if (IC.MaskedValueIsZero(I->getOperand(0), Mask, I) && IC.MaskedValueIsZero(I->getOperand(1), Mask, I)) { - return canEvaluateTruncated(I->getOperand(0), Ty, IC, I) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, I); + return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI); } break; } @@ -308,8 +522,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, KnownBits AmtKnownBits = llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); if (AmtKnownBits.getMaxValue().ult(BitWidth)) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI); break; } case Instruction::LShr: { @@ -329,12 +543,12 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, if (auto *Trunc = dyn_cast<TruncInst>(V->user_back())) { auto DemandedBits = Trunc->getType()->getScalarSizeInBits(); if ((MaxShiftAmt + DemandedBits).ule(BitWidth)) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI); } if (IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, CxtI)) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI); } break; } @@ -351,8 +565,8 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, unsigned ShiftedBits = OrigBitWidth - BitWidth; if (AmtKnownBits.getMaxValue().ult(BitWidth) && ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), CxtI)) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI); break; } case Instruction::Trunc: @@ -365,18 +579,18 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, return true; case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); - return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) && - canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI); + return canEvaluateTruncatedImpl(SI->getTrueValue(), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(SI->getFalseValue(), Ty, IC, CxtI); } case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never - // get into trouble with cyclic PHIs here because we only consider - // instructions with a single use. + // get into trouble with cyclic PHIs here because canEvaluate handles use + // chain loops. PHINode *PN = cast<PHINode>(I); - for (Value *IncValue : PN->incoming_values()) - if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI)) - return false; - return true; + return llvm::all_of( + PN->incoming_values(), [this, Ty, &IC, CxtI](Value *IncValue) { + return canEvaluateTruncatedImpl(IncValue, Ty, IC, CxtI); + }); } case Instruction::FPToUI: case Instruction::FPToSI: { @@ -385,14 +599,14 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, // that did not exist in the original code. 
Type *InputTy = I->getOperand(0)->getType()->getScalarType(); const fltSemantics &Semantics = InputTy->getFltSemantics(); - uint32_t MinBitWidth = - APFloatBase::semanticsIntSizeInBits(Semantics, - I->getOpcode() == Instruction::FPToSI); + uint32_t MinBitWidth = APFloatBase::semanticsIntSizeInBits( + Semantics, I->getOpcode() == Instruction::FPToSI); return Ty->getScalarSizeInBits() >= MinBitWidth; } case Instruction::ShuffleVector: - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && - canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); + return canEvaluateTruncatedImpl(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncatedImpl(I->getOperand(1), Ty, IC, CxtI); + default: // TODO: Can handle more cases here. break; @@ -767,7 +981,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { // expression tree to something weird like i93 unless the source is also // strange. if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && - canEvaluateTruncated(Src, DestTy, *this, &Trunc)) { + TypeEvaluationHelper::canEvaluateTruncated(Src, DestTy, *this, &Trunc)) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. @@ -788,7 +1002,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (DestWidth * 2 < SrcWidth) { auto *NewDestTy = DestITy->getExtendedType(); if (shouldChangeType(SrcTy, NewDestTy) && - canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) { + TypeEvaluationHelper::canEvaluateTruncated(Src, NewDestTy, *this, + &Trunc)) { LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" " to reduce the width of operand of" @@ -1104,11 +1319,22 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, /// clear the top bits anyway, doing this has no extra cost. /// /// This function works on both vectors and scalars. -static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, - InstCombinerImpl &IC, Instruction *CxtI) { +bool TypeEvaluationHelper::canEvaluateZExtd(Value *V, Type *Ty, + unsigned &BitsToClear, + InstCombinerImpl &IC, + Instruction *CxtI) { + TypeEvaluationHelper TYH; + return TYH.canEvaluateZExtdImpl(V, Ty, BitsToClear, IC, CxtI); +} +bool TypeEvaluationHelper::canEvaluateZExtdImpl(Value *V, Type *Ty, + unsigned &BitsToClear, + InstCombinerImpl &IC, + Instruction *CxtI) { BitsToClear = 0; if (canAlwaysEvaluateInType(V, Ty)) return true; + // We stick to the one-user limit for the ZExt transform due to the fact + // that this predicate returns two values: predicate result and BitsToClear. if (canNotEvaluateInType(V, Ty)) return false; @@ -1125,8 +1351,8 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, case Instruction::Add: case Instruction::Sub: case Instruction::Mul: - if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) || - !canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI)) + if (!canEvaluateZExtdImpl(I->getOperand(0), Ty, BitsToClear, IC, CxtI) || + !canEvaluateZExtdImpl(I->getOperand(1), Ty, Tmp, IC, CxtI)) return false; // These can all be promoted if neither operand has 'bits to clear'. if (BitsToClear == 0 && Tmp == 0) @@ -1157,7 +1383,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, // upper bits we can reduce BitsToClear by the shift amount. 
uint64_t ShiftAmt; if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) { - if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) + if (!canEvaluateZExtdImpl(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0; return true; @@ -1169,7 +1395,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, // ultimate 'and' to clear out the high zero bits we're clearing out though. uint64_t ShiftAmt; if (match(I->getOperand(1), m_ConstantInt(ShiftAmt))) { - if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) + if (!canEvaluateZExtdImpl(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear += ShiftAmt; if (BitsToClear > V->getType()->getScalarSizeInBits()) @@ -1180,8 +1406,8 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, return false; } case Instruction::Select: - if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || - !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || + if (!canEvaluateZExtdImpl(I->getOperand(1), Ty, Tmp, IC, CxtI) || + !canEvaluateZExtdImpl(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear are // known zero in the disagreeing side. Tmp != BitsToClear) @@ -1193,10 +1419,11 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, // get into trouble with cyclic PHIs here because we only consider // instructions with a single use. PHINode *PN = cast<PHINode>(I); - if (!canEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI)) + if (!canEvaluateZExtdImpl(PN->getIncomingValue(0), Ty, BitsToClear, IC, + CxtI)) return false; for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) - if (!canEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) || + if (!canEvaluateZExtdImpl(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) || // TODO: If important, we could handle the case when the BitsToClear // are known zero in the disagreeing input. Tmp != BitsToClear) @@ -1237,7 +1464,8 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { // Try to extend the entire expression tree to the wide destination type. unsigned BitsToClear; if (shouldChangeType(SrcTy, DestTy) && - canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &Zext)) { + TypeEvaluationHelper::canEvaluateZExtd(Src, DestTy, BitsToClear, *this, + &Zext)) { assert(BitsToClear <= SrcTy->getScalarSizeInBits() && "Can't clear more bits than in SrcTy"); @@ -1455,13 +1683,20 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, /// /// This function works on both vectors and scalars. 
/// -static bool canEvaluateSExtd(Value *V, Type *Ty) { +bool TypeEvaluationHelper::canEvaluateSExtd(Value *V, Type *Ty) { + TypeEvaluationHelper TYH; + return TYH.canEvaluateSExtdImpl(V, Ty) && TYH.allPendingVisited(); +} + +bool TypeEvaluationHelper::canEvaluateSExtdImpl(Value *V, Type *Ty) { + return canEvaluate(V, Ty, [this](Value *V, Type *Ty) { + return canEvaluateSExtdPred(V, Ty); + }); +} + +bool TypeEvaluationHelper::canEvaluateSExtdPred(Value *V, Type *Ty) { assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() && "Can't sign extend type to a smaller type"); - if (canAlwaysEvaluateInType(V, Ty)) - return true; - if (canNotEvaluateInType(V, Ty)) - return false; auto *I = cast<Instruction>(V); switch (I->getOpcode()) { @@ -1476,23 +1711,24 @@ static bool canEvaluateSExtd(Value *V, Type *Ty) { case Instruction::Sub: case Instruction::Mul: // These operators can all arbitrarily be extended if their inputs can. - return canEvaluateSExtd(I->getOperand(0), Ty) && - canEvaluateSExtd(I->getOperand(1), Ty); + return canEvaluateSExtdImpl(I->getOperand(0), Ty) && + canEvaluateSExtdImpl(I->getOperand(1), Ty); - //case Instruction::Shl: TODO - //case Instruction::LShr: TODO + // case Instruction::Shl: TODO + // case Instruction::LShr: TODO case Instruction::Select: - return canEvaluateSExtd(I->getOperand(1), Ty) && - canEvaluateSExtd(I->getOperand(2), Ty); + return canEvaluateSExtdImpl(I->getOperand(1), Ty) && + canEvaluateSExtdImpl(I->getOperand(2), Ty); case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never - // get into trouble with cyclic PHIs here because we only consider - // instructions with a single use. + // get into trouble with cyclic PHIs here because canEvaluate handles use + // chain loops. PHINode *PN = cast<PHINode>(I); for (Value *IncValue : PN->incoming_values()) - if (!canEvaluateSExtd(IncValue, Ty)) return false; + if (!canEvaluateSExtdImpl(IncValue, Ty)) + return false; return true; } default: @@ -1533,7 +1769,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { if (TruncSrc->getType()->getScalarSizeInBits() > DestBitSize) ShouldExtendExpression = false; if (ShouldExtendExpression && shouldChangeType(SrcTy, DestTy) && - canEvaluateSExtd(Src, DestTy)) { + TypeEvaluationHelper::canEvaluateSExtd(Src, DestTy)) { // Okay, we can transform this! Insert the new expression now. LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" @@ -1548,7 +1784,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { return replaceInstUsesWith(Sext, Res); // We need to emit a shl + ashr to do the sign extend. - Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); + Value *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); return BinaryOperator::CreateAShr(Builder.CreateShl(Res, ShAmt, "sext"), ShAmt); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index fba1ccf..abf4381 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -1465,20 +1465,24 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, ConstantInt::get(V->getType(), 1)); } - // TODO: Handle any shifted constant by subtracting trailing zeros. // TODO: Handle non-equality predicates. 
Value *Y; - if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) { - // (trunc (1 << Y) to iN) == 0 --> Y u>= N - // (trunc (1 << Y) to iN) != 0 --> Y u< N + const APInt *Pow2; + if (Cmp.isEquality() && match(X, m_Shl(m_Power2(Pow2), m_Value(Y))) && + DstBits > Pow2->logBase2()) { + // (trunc (Pow2 << Y) to iN) == 0 --> Y u>= N - log2(Pow2) + // (trunc (Pow2 << Y) to iN) != 0 --> Y u< N - log2(Pow2) + // iff N > log2(Pow2) if (C.isZero()) { auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT; - return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits)); + return new ICmpInst(NewPred, Y, + ConstantInt::get(SrcTy, DstBits - Pow2->logBase2())); } - // (trunc (1 << Y) to iN) == 2**C --> Y == C - // (trunc (1 << Y) to iN) != 2**C --> Y != C + // (trunc (Pow2 << Y) to iN) == 2**C --> Y == C - log2(Pow2) + // (trunc (Pow2 << Y) to iN) != 2**C --> Y != C - log2(Pow2) if (C.isPowerOf2()) - return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2())); + return new ICmpInst( + Pred, Y, ConstantInt::get(SrcTy, C.logBase2() - Pow2->logBase2())); } if (Cmp.isEquality() && (Trunc->hasOneUse() || Trunc->hasNoUnsignedWrap())) { @@ -2638,16 +2642,6 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, if (Shr->isExact()) return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); - if (C.isZero()) { - // == 0 is u< 1. - if (Pred == CmpInst::ICMP_EQ) - return new ICmpInst(CmpInst::ICMP_ULT, X, - ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal))); - else - return new ICmpInst(CmpInst::ICMP_UGT, X, - ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal) - 1)); - } - if (Shr->hasOneUse()) { // Canonicalize the shift into an 'and': // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt) @@ -3138,7 +3132,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, Value *Op0, *Op1; Instruction *Ext0, *Ext1; - const CmpInst::Predicate Pred = Cmp.getPredicate(); + const CmpPredicate Pred = Cmp.getCmpPredicate(); if (match(Add, m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))), m_CombineAnd(m_Instruction(Ext1), @@ -3173,22 +3167,29 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE - // are canonicalized to SGT/SLT/UGT/ULT. - if ((Add->hasNoSignedWrap() && - (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) || - (Add->hasNoUnsignedWrap() && - (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) { + // have been canonicalized to SGT/SLT/UGT/ULT. + if (Add->hasNoUnsignedWrap() && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) { bool Overflow; - APInt NewC = - Cmp.isSigned() ? C.ssub_ov(*C2, Overflow) : C.usub_ov(*C2, Overflow); + APInt NewC = C.usub_ov(*C2, Overflow); // If there is overflow, the result must be true or false. - // TODO: Can we assert there is no overflow because InstSimplify always - // handles those cases? 
if (!Overflow) // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2) return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); } + CmpInst::Predicate ChosenPred = Pred.getPreferredSignedPredicate(); + + if (Add->hasNoSignedWrap() && + (ChosenPred == ICmpInst::ICMP_SGT || ChosenPred == ICmpInst::ICMP_SLT)) { + bool Overflow; + APInt NewC = C.ssub_ov(*C2, Overflow); + if (!Overflow) + // icmp samesign ugt/ult (add nsw X, C2), C + // -> icmp sgt/slt X, (C - C2) + return new ICmpInst(ChosenPred, X, ConstantInt::get(Ty, NewC)); + } + if (ICmpInst::isUnsigned(Pred) && Add->hasNoSignedWrap() && C.isNonNegative() && (C - *C2).isNonNegative() && computeConstantRange(X, /*ForSigned=*/true).add(*C2).isAllNonNegative()) @@ -5892,6 +5893,12 @@ static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets, Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1)); Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0)); break; + case Instruction::Shl: + if (Inst->hasNoSignedWrap()) + Offsets.emplace_back(Instruction::AShr, Inst->getOperand(1)); + if (Inst->hasNoUnsignedWrap()) + Offsets.emplace_back(Instruction::LShr, Inst->getOperand(1)); + break; case Instruction::Select: if (AllowRecursion) { collectOffsetOp(Inst->getOperand(1), Offsets, /*AllowRecursion=*/false); @@ -5948,9 +5955,31 @@ static Instruction *foldICmpEqualityWithOffset(ICmpInst &I, collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true); auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * { + switch (BinOpc) { + // V = shl nsw X, RHS => X = ashr V, RHS + case Instruction::AShr: { + const APInt *CV, *CRHS; + if (!(match(V, m_APInt(CV)) && match(RHS, m_APInt(CRHS)) && + CV->ashr(*CRHS).shl(*CRHS) == *CV) && + !match(V, m_NSWShl(m_Value(), m_Specific(RHS)))) + return nullptr; + break; + } + // V = shl nuw X, RHS => X = lshr V, RHS + case Instruction::LShr: { + const APInt *CV, *CRHS; + if (!(match(V, m_APInt(CV)) && match(RHS, m_APInt(CRHS)) && + CV->lshr(*CRHS).shl(*CRHS) == *CV) && + !match(V, m_NUWShl(m_Value(), m_Specific(RHS)))) + return nullptr; + break; + } + default: + break; + } + Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ); - // Avoid infinite loops by checking if RHS is an identity for the BinOp. - if (!Simplified || Simplified == V) + if (!Simplified) return nullptr; // Reject constant expressions as they don't simplify things. 
if (isa<Constant>(Simplified) && !match(Simplified, m_ImmConstant())) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index d85e4f7..9bdd8cb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -479,7 +479,7 @@ private: const Twine &NameStr = "", InsertPosition InsertBefore = nullptr) { auto *Sel = SelectInst::Create(C, S1, S2, NameStr, InsertBefore, nullptr); - setExplicitlyUnknownBranchWeightsIfProfiled(*Sel, F, DEBUG_TYPE); + setExplicitlyUnknownBranchWeightsIfProfiled(*Sel, DEBUG_TYPE, &F); return Sel; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 9815644..ba1865a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -698,8 +698,7 @@ static bool isSafeAndProfitableToSinkLoad(LoadInst *L) { Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LoadInst *FirstLI = cast<LoadInst>(PN.getIncomingValue(0)); - // Can't forward swifterror through a phi. - if (FirstLI->getOperand(0)->isSwiftError()) + if (!canReplaceOperandWithVariable(FirstLI, 0)) return nullptr; // FIXME: This is overconservative; this transform is allowed in some cases @@ -738,8 +737,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { LI->getPointerAddressSpace() != LoadAddrSpace) return nullptr; - // Can't forward swifterror through a phi. - if (LI->getOperand(0)->isSwiftError()) + if (!canReplaceOperandWithVariable(LI, 0)) return nullptr; // We can't sink the load if the loaded value could be modified between @@ -1007,7 +1005,7 @@ static bool PHIsEqualValue(PHINode *PN, Value *&NonPhiInVal, return true; // Don't scan crazily complex things. - if (ValueEqualPHIs.size() == 16) + if (ValueEqualPHIs.size() >= 16) return false; // Scan the operands to see if they are either phi nodes or are equal to diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index f5130da..c00551b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1027,10 +1027,9 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, return Result; } -static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, - InstCombiner::BuilderTy &Builder) { - if (!Cmp->hasOneUse()) - return nullptr; +static Value * +canonicalizeSaturatedAddUnsigned(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { // Match unsigned saturated add with constant. Value *Cmp0 = Cmp->getOperand(0); @@ -1130,6 +1129,95 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return nullptr; } +static Value *canonicalizeSaturatedAddSigned(ICmpInst *Cmp, Value *TVal, + Value *FVal, + InstCombiner::BuilderTy &Builder) { + // Match saturated add with constant. + Value *Cmp0 = Cmp->getOperand(0); + Value *Cmp1 = Cmp->getOperand(1); + ICmpInst::Predicate Pred = Cmp->getPredicate(); + Value *X; + const APInt *C; + + // Canonicalize INT_MAX to true value of the select. 
+ if (match(FVal, m_MaxSignedValue())) { + std::swap(TVal, FVal); + Pred = CmpInst::getInversePredicate(Pred); + } + + if (!match(TVal, m_MaxSignedValue())) + return nullptr; + + // sge maximum signed value is canonicalized to eq maximum signed value and + // requires special handling (a == INT_MAX) ? INT_MAX : a + 1 -> sadd.sat(a, + // 1) + if (Pred == ICmpInst::ICMP_EQ) { + if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) && Cmp1 == TVal) { + return Builder.CreateBinaryIntrinsic( + Intrinsic::sadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), 1)); + } + return nullptr; + } + + // (X > Y) ? INT_MAX : (X + C) --> sadd.sat(X, C) + // (X >= Y) ? INT_MAX : (X + C) --> sadd.sat(X, C) + // where Y is INT_MAX - C or INT_MAX - C - 1, and C > 0 + if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) && + isa<Constant>(Cmp1) && + match(FVal, m_Add(m_Specific(Cmp0), m_StrictlyPositive(C)))) { + APInt IntMax = + APInt::getSignedMaxValue(Cmp1->getType()->getScalarSizeInBits()); + + // For SGE, try to flip to SGT to normalize the comparison constant. + if (Pred == ICmpInst::ICMP_SGE) { + if (auto Flipped = getFlippedStrictnessPredicateAndConstant( + Pred, cast<Constant>(Cmp1))) { + Pred = Flipped->first; + Cmp1 = Flipped->second; + } + } + + // Check the pattern: X > INT_MAX - C or X > INT_MAX - C - 1 + if (Pred == ICmpInst::ICMP_SGT && + (match(Cmp1, m_SpecificIntAllowPoison(IntMax - *C)) || + match(Cmp1, m_SpecificIntAllowPoison(IntMax - *C - 1)))) + return Builder.CreateBinaryIntrinsic( + Intrinsic::sadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), *C)); + } + + // Canonicalize predicate to less-than or less-or-equal-than. + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) { + std::swap(Cmp0, Cmp1); + Pred = CmpInst::getSwappedPredicate(Pred); + } + + if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SLE) + return nullptr; + + if (match(Cmp0, m_NSWSub(m_MaxSignedValue(), m_Value(X))) && + match(FVal, m_c_Add(m_Specific(X), m_Specific(Cmp1)))) { + // (INT_MAX - X s< Y) ? INT_MAX : (X + Y) --> sadd.sat(X, Y) + // (INT_MAX - X s< Y) ? INT_MAX : (Y + X) --> sadd.sat(X, Y) + return Builder.CreateBinaryIntrinsic(Intrinsic::sadd_sat, X, Cmp1); + } + + return nullptr; +} + +static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, + InstCombiner::BuilderTy &Builder) { + if (!Cmp->hasOneUse()) + return nullptr; + + if (Value *V = canonicalizeSaturatedAddUnsigned(Cmp, TVal, FVal, Builder)) + return V; + + if (Value *V = canonicalizeSaturatedAddSigned(Cmp, TVal, FVal, Builder)) + return V; + + return nullptr; +} + /// Try to match patterns with select and subtract as absolute difference. static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal, InstCombiner::BuilderTy &Builder) { @@ -2979,14 +3067,10 @@ Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, "Op must be either i1 or vector of i1."); if (SI.getCondition()->getType() != Op->getType()) return nullptr; - if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL)) { - Instruction *MDFrom = nullptr; - if (!ProfcheckDisableMetadataFixes) - MDFrom = &SI; - return SelectInst::Create( + if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL)) + return createSelectInstWithUnknownProfile( Op, IsAnd ? V : ConstantInt::getTrue(Op->getType()), - IsAnd ? ConstantInt::getFalse(Op->getType()) : V, "", nullptr, MDFrom); - } + IsAnd ? 
ConstantInt::getFalse(Op->getType()) : V); return nullptr; } @@ -3599,6 +3683,21 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { m_Not(m_Specific(SelCond->getTrueValue()))); if (MayNeedFreeze) C = Builder.CreateFreeze(C); + if (!ProfcheckDisableMetadataFixes) { + Value *C2 = nullptr, *A2 = nullptr, *B2 = nullptr; + if (match(CondVal, m_LogicalAnd(m_Specific(C), m_Value(A2))) && + SelCond) { + return SelectInst::Create(C, A, B, "", nullptr, SelCond); + } else if (match(FalseVal, + m_LogicalAnd(m_Not(m_Value(C2)), m_Value(B2))) && + SelFVal) { + SelectInst *NewSI = SelectInst::Create(C, A, B, "", nullptr, SelFVal); + NewSI->swapProfMetadata(); + return NewSI; + } else { + return createSelectInstWithUnknownProfile(C, A, B); + } + } return SelectInst::Create(C, A, B); } @@ -3615,6 +3714,20 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { m_Not(m_Specific(SelFVal->getTrueValue()))); if (MayNeedFreeze) C = Builder.CreateFreeze(C); + if (!ProfcheckDisableMetadataFixes) { + Value *C2 = nullptr, *A2 = nullptr, *B2 = nullptr; + if (match(CondVal, m_LogicalAnd(m_Not(m_Value(C2)), m_Value(A2))) && + SelCond) { + SelectInst *NewSI = SelectInst::Create(C, B, A, "", nullptr, SelCond); + NewSI->swapProfMetadata(); + return NewSI; + } else if (match(FalseVal, m_LogicalAnd(m_Specific(C), m_Value(B2))) && + SelFVal) { + return SelectInst::Create(C, B, A, "", nullptr, SelFVal); + } else { + return createSelectInstWithUnknownProfile(C, B, A); + } + } return SelectInst::Create(C, B, A); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 18a45c6..98e2d9e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -140,8 +140,8 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, Value *Elt = EI.getIndexOperand(); // If the operand is the PHI induction variable: if (PHIInVal == PHIUser) { - // Scalarize the binary operation. Its first operand is the - // scalar PHI, and the second operand is extracted from the other + // Scalarize the binary operation. One operand is the + // scalar PHI, and the other is extracted from the other // vector operand. BinaryOperator *B0 = cast<BinaryOperator>(PHIUser); unsigned opId = (B0->getOperand(0) == PN) ? 1 : 0; @@ -149,9 +149,14 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, ExtractElementInst::Create(B0->getOperand(opId), Elt, B0->getOperand(opId)->getName() + ".Elt"), B0->getIterator()); - Value *newPHIUser = InsertNewInstWith( - BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(), - scalarPHI, Op, B0), B0->getIterator()); + // Preserve operand order for binary operation to preserve semantics of + // non-commutative operations. + Value *FirstOp = (B0->getOperand(0) == PN) ? scalarPHI : Op; + Value *SecondOp = (B0->getOperand(0) == PN) ? 
Op : scalarPHI; + Value *newPHIUser = + InsertNewInstWith(BinaryOperator::CreateWithCopiedFlags( + B0->getOpcode(), FirstOp, SecondOp, B0), + B0->getIterator()); scalarPHI->addIncoming(newPHIUser, inBB); } else { // Scalarize PHI input: diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 67f837c..c6de57c 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1758,6 +1758,9 @@ static Value *simplifyOperationIntoSelectOperand(Instruction &I, SelectInst *SI, m_Specific(Op), m_Value(V))) && isGuaranteedNotToBeUndefOrPoison(V)) { // Pass + } else if (match(Op, m_ZExt(m_Specific(SI->getCondition())))) { + V = IsTrueArm ? ConstantInt::get(Op->getType(), 1) + : ConstantInt::getNullValue(Op->getType()); } else { V = Op; } @@ -2261,11 +2264,11 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) { } Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { - if (!isa<Constant>(I.getOperand(1))) - return nullptr; + bool IsOtherParamConst = isa<Constant>(I.getOperand(1)); if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { - if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) + if (Instruction *NewSel = + FoldOpIntoSelect(I, Sel, false, !IsOtherParamConst)) return NewSel; } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) { if (Instruction *NewPhi = foldOpIntoPhi(I, PN)) @@ -3370,9 +3373,9 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { DL.getAddressSizeInBits(AS) != DL.getPointerSizeInBits(AS); bool Changed = false; GEP.replaceUsesWithIf(Y, [&](Use &U) { - bool ShouldReplace = isa<PtrToAddrInst>(U.getUser()) || - (!HasNonAddressBits && - isa<ICmpInst, PtrToIntInst>(U.getUser())); + bool ShouldReplace = + isa<PtrToAddrInst, ICmpInst>(U.getUser()) || + (!HasNonAddressBits && isa<PtrToIntInst>(U.getUser())); Changed |= ShouldReplace; return ShouldReplace; }); @@ -5624,8 +5627,15 @@ bool InstCombinerImpl::run() { for (Use &U : I->uses()) { User *User = U.getUser(); - if (User->isDroppable()) - continue; + if (User->isDroppable()) { + // Do not sink if there are dereferenceable assumes that would be + // removed. + auto II = dyn_cast<IntrinsicInst>(User); + if (II->getIntrinsicID() != Intrinsic::assume || + !II->getOperandBundle("dereferenceable")) + continue; + } + if (NumUsers > MaxSinkNumUsers) return std::nullopt; |
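Illustrative sketch (not part of the diff above): the new FoldOrOfSelectSmaxToAbs helper added to InstCombineAndOrXor.cpp matches an or of a select against a one-use smax. Roughly the following IR shape should now fold; the function and value names here are invented for the example, and the result is what the transform is expected to produce.

    ; before the fold
    declare i32 @llvm.smax.i32(i32, i32)

    define i32 @or_select_smax(i32 %x) {
      %cmp = icmp sgt i32 %x, 0
      %neg = sub i32 0, %x
      %sel = select i1 %cmp, i32 0, i32 %neg          ; 0 when x > 0, -x otherwise
      %max = call i32 @llvm.smax.i32(i32 %x, i32 0)   ; x when x > 0, 0 otherwise
      %or  = or i32 %sel, %max                        ; exactly one operand is nonzero
      ret i32 %or
    }

    ; expected after the fold: %or is replaced with
    ;   call i32 @llvm.abs.i32(i32 %x, i1 false)

The or is equivalent to abs(x) because for any x exactly one of the two operands is zero and the other carries the magnitude.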

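Similarly, a minimal sketch of one shape targeted by the new canonicalizeSaturatedAddSigned helper in InstCombineSelect.cpp, matching (X > INT_MAX - C) ? INT_MAX : (X + C) with C > 0. Names and constants are chosen only for the example.

    ; before the fold (INT_MAX - 10 = 2147483637)
    define i32 @sat_add_10(i32 %x) {
      %add = add i32 %x, 10
      %ov  = icmp sgt i32 %x, 2147483637
      %res = select i1 %ov, i32 2147483647, i32 %add
      ret i32 %res
    }

    ; expected after the fold: %res is replaced with
    ;   call i32 @llvm.sadd.sat.i32(i32 %x, i32 10)

When x exceeds INT_MAX - 10 the plain add would overflow, so selecting INT_MAX in that case is exactly the saturating-add semantics.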