Diffstat (limited to 'llvm/lib/Transforms')
26 files changed, 795 insertions, 435 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index bbbac45..7a95df4 100644 --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -907,10 +907,20 @@ static bool mergeConsecutivePartStores(ArrayRef<PartStore> Parts, StoreInst *Store = Builder.CreateAlignedStore( Val, First.Store->getPointerOperand(), First.Store->getAlign()); + // Merge various metadata onto the new store. AAMDNodes AATags = First.Store->getAAMetadata(); - for (const PartStore &Part : drop_begin(Parts)) + SmallVector<Instruction *> Stores = {First.Store}; + Stores.reserve(Parts.size()); + SmallVector<DebugLoc> DbgLocs = {First.Store->getDebugLoc()}; + DbgLocs.reserve(Parts.size()); + for (const PartStore &Part : drop_begin(Parts)) { AATags = AATags.concat(Part.Store->getAAMetadata()); + Stores.push_back(Part.Store); + DbgLocs.push_back(Part.Store->getDebugLoc()); + } Store->setAAMetadata(AATags); + Store->mergeDIAssignID(Stores); + Store->setDebugLoc(DebugLoc::getMergedLocations(DbgLocs)); // Remove the old stores. for (const PartStore &Part : Parts) diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp index 28ee444..a29faab 100644 --- a/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -1368,13 +1368,13 @@ static void ComputeCrossModuleImportForModuleFromIndexForTest( FunctionImporter::ImportMapTy &ImportList) { for (const auto &GlobalList : Index) { // Ignore entries for undefined references. - if (GlobalList.second.SummaryList.empty()) + if (GlobalList.second.getSummaryList().empty()) continue; auto GUID = GlobalList.first; - assert(GlobalList.second.SummaryList.size() == 1 && + assert(GlobalList.second.getSummaryList().size() == 1 && "Expected individual combined index to have one summary per GUID"); - auto &Summary = GlobalList.second.SummaryList[0]; + auto &Summary = GlobalList.second.getSummaryList()[0]; // Skip the summaries for the importing module. These are included to // e.g. record required linkage changes. if (Summary->modulePath() == ModulePath) @@ -1423,7 +1423,7 @@ void updateValueInfoForIndirectCalls(ModuleSummaryIndex &Index, void llvm::updateIndirectCalls(ModuleSummaryIndex &Index) { for (const auto &Entry : Index) { - for (const auto &S : Entry.second.SummaryList) { + for (const auto &S : Entry.second.getSummaryList()) { if (auto *FS = dyn_cast<FunctionSummary>(S.get())) updateValueInfoForIndirectCalls(Index, FS); } @@ -1456,7 +1456,7 @@ void llvm::computeDeadSymbolsAndUpdateIndirectCalls( // Add values flagged in the index as live roots to the worklist. for (const auto &Entry : Index) { auto VI = Index.getValueInfo(Entry); - for (const auto &S : Entry.second.SummaryList) { + for (const auto &S : Entry.second.getSummaryList()) { if (auto *FS = dyn_cast<FunctionSummary>(S.get())) updateValueInfoForIndirectCalls(Index, FS); if (S->isLive()) { @@ -2094,7 +2094,7 @@ static bool doImportingForModuleForTest( // is only enabled when testing importing via the 'opt' tool, which does // not do the ThinLink that would normally determine what values to promote. 
for (auto &I : *Index) { - for (auto &S : I.second.SummaryList) { + for (auto &S : I.second.getSummaryList()) { if (GlobalValue::isLocalLinkage(S->linkage())) S->setLinkage(GlobalValue::ExternalLinkage); } diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index be6cba3..aa1346d 100644 --- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -1271,7 +1271,7 @@ bool LowerTypeTestsModule::hasBranchTargetEnforcement() { // the module flags. if (const auto *BTE = mdconst::extract_or_null<ConstantInt>( M.getModuleFlag("branch-target-enforcement"))) - HasBranchTargetEnforcement = (BTE->getZExtValue() != 0); + HasBranchTargetEnforcement = !BTE->isZero(); else HasBranchTargetEnforcement = 0; } @@ -2130,7 +2130,7 @@ bool LowerTypeTestsModule::lower() { // A set of all functions that are address taken by a live global object. DenseSet<GlobalValue::GUID> AddressTaken; for (auto &I : *ExportSummary) - for (auto &GVS : I.second.SummaryList) + for (auto &GVS : I.second.getSummaryList()) if (GVS->isLive()) for (const auto &Ref : GVS->refs()) { AddressTaken.insert(Ref.getGUID()); @@ -2409,7 +2409,7 @@ bool LowerTypeTestsModule::lower() { } for (auto &P : *ExportSummary) { - for (auto &S : P.second.SummaryList) { + for (auto &S : P.second.getSummaryList()) { if (!ExportSummary->isGlobalValueLive(S.get())) continue; if (auto *FS = dyn_cast<FunctionSummary>(S->getBaseObject())) diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 2d5cb82..2dd0fde 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -24,7 +24,8 @@ // returns 0, or a single vtable's function returns 1, replace each virtual // call with a comparison of the vptr against that vtable's address. // -// This pass is intended to be used during the regular and thin LTO pipelines: +// This pass is intended to be used during the regular LTO, ThinLTO and +// non-LTO pipelines: // // During regular LTO, the pass determines the best optimization for each // virtual call and applies the resolutions directly to virtual calls that are @@ -48,6 +49,14 @@ // is supported. // - Import phase: (same as with hybrid case above). // +// In speculative devirtualization mode (not restricted to LTO): +// - The pass applies speculative devirtualization without requiring any +// particular vtable visibility. +// - Skips other features like virtual constant propagation, uniform return +// value optimization, unique return value optimization and branch funnels, as +// they require LTO. +// - This mode is enabled via the 'devirtualize-speculatively' flag. +// //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/WholeProgramDevirt.h" @@ -61,7 +70,9 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/ModuleSummaryAnalysis.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" @@ -145,6 +156,13 @@ static cl::opt<std::string> ClWriteSummary( "bitcode, otherwise YAML"), cl::Hidden); +// TODO: This option should eventually support vtables with public visibility, +// with or without LTO.
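// A quick way to exercise the new mode on a standalone module (an illustrative invocation, not part of this patch; assumes an input file in.ll containing virtual calls): opt -passes=wholeprogramdevirt -devirtualize-speculatively -S in.ll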
+static cl::opt<bool> ClDevirtualizeSpeculatively( + "devirtualize-speculatively", + cl::desc("Enable speculative devirtualization optimization"), + cl::init(false)); + static cl::opt<unsigned> ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden, cl::init(10), @@ -892,6 +910,8 @@ void llvm::updatePublicTypeTestCalls(Module &M, CI->eraseFromParent(); } } else { + // TODO: Don't replace public type tests when speculative devirtualization + // gets enabled in LTO mode. auto *True = ConstantInt::getTrue(M.getContext()); for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) { auto *CI = cast<CallInst>(U.getUser()); @@ -928,17 +948,17 @@ void llvm::updateVCallVisibilityInIndex( // linker, as we have no information on their eventual use. if (DynamicExportSymbols.count(P.first)) continue; - for (auto &S : P.second.SummaryList) { + // With validation enabled, we want to exclude symbols visible to regular + // objects. Local symbols will be in this group due to the current + // implementation but those with VCallVisibilityTranslationUnit will have + // already been marked in clang so are unaffected. + if (VisibleToRegularObjSymbols.count(P.first)) + continue; + for (auto &S : P.second.getSummaryList()) { auto *GVar = dyn_cast<GlobalVarSummary>(S.get()); if (!GVar || GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic) continue; - // With validation enabled, we want to exclude symbols visible to regular - // objects. Local symbols will be in this group due to the current - // implementation but those with VCallVisibilityTranslationUnit will have - // already been marked in clang so are unaffected. - if (VisibleToRegularObjSymbols.count(P.first)) - continue; GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit); } } @@ -1083,10 +1103,10 @@ bool DevirtModule::tryFindVirtualCallTargets( if (!TM.Bits->GV->isConstant()) return false; - // We cannot perform whole program devirtualization analysis on a vtable - // with public LTO visibility. - if (TM.Bits->GV->getVCallVisibility() == - GlobalObject::VCallVisibilityPublic) + // Without ClDevirtualizeSpeculatively, we cannot perform whole program + // devirtualization analysis on a vtable with public LTO visibility. + if (!ClDevirtualizeSpeculatively && TM.Bits->GV->getVCallVisibility() == + GlobalObject::VCallVisibilityPublic) return false; Function *Fn = nullptr; @@ -1105,6 +1125,12 @@ bool DevirtModule::tryFindVirtualCallTargets( if (Fn->getName() == "__cxa_pure_virtual") continue; + // In most cases empty functions will be overridden by the + // implementation of the derived class, so we can skip them. + if (ClDevirtualizeSpeculatively && Fn->getReturnType()->isVoidTy() && + Fn->getInstructionCount() <= 1) + continue; + // We can disregard unreachable functions as possible call targets, as // unreachable functions shouldn't be called. if (mustBeUnreachableFunction(Fn, ExportSummary)) @@ -1135,14 +1161,10 @@ bool DevirtIndex::tryFindVirtualCallTargets( // and therefore the same GUID. This can happen if there isn't enough // distinguishing path when compiling the source file. In that case we // conservatively return false early. 
+ if (P.VTableVI.hasLocal() && P.VTableVI.getSummaryList().size() > 1) + return false; const GlobalVarSummary *VS = nullptr; - bool LocalFound = false; for (const auto &S : P.VTableVI.getSummaryList()) { - if (GlobalValue::isLocalLinkage(S->linkage())) { - if (LocalFound) - return false; - LocalFound = true; - } auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject()); if (!CurVS->vTableFuncs().empty() || // Previously clang did not attach the necessary type metadata to @@ -1158,6 +1180,7 @@ bool DevirtIndex::tryFindVirtualCallTargets( // with public LTO visibility. if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic) return false; + break; } } // There will be no VS if all copies are available_externally having no @@ -1223,10 +1246,12 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, CallTrap->setDebugLoc(CB.getDebugLoc()); } - // If fallback checking is enabled, add support to compare the virtual - // function pointer to the devirtualized target. In case of a mismatch, - // fall back to indirect call. - if (DevirtCheckMode == WPDCheckMode::Fallback) { + // If fallback checking or speculative devirtualization is enabled, + // add support to compare the virtual function pointer to the + // devirtualized target. In case of a mismatch, fall back to indirect + // call. + if (DevirtCheckMode == WPDCheckMode::Fallback || + ClDevirtualizeSpeculatively) { MDNode *Weights = MDBuilder(M.getContext()).createLikelyBranchWeights(); // Version the indirect call site. If the called value is equal to the // given callee, 'NewInst' will be executed, otherwise the original call @@ -1383,9 +1408,8 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot, // If the summary list contains multiple summaries where at least one is // a local, give up, as we won't know which (possibly promoted) name to use. - for (const auto &S : TheFn.getSummaryList()) - if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1) - return false; + if (TheFn.hasLocal() && Size > 1) + return false; // Collect functions devirtualized at least for one call site for stats. if (PrintSummaryDevirt || AreStatisticsEnabled()) @@ -2057,15 +2081,15 @@ void DevirtModule::scanTypeTestUsers( Function *TypeTestFunc, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { // Find all virtual calls via a virtual table pointer %p under an assumption - // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p - // points to a member of the type identifier %md. Group calls by (type ID, - // offset) pair (effectively the identity of the virtual function) and store - // to CallSlots. + // of the form llvm.assume(llvm.type.test(%p, %md)) or + // llvm.assume(llvm.public.type.test(%p, %md)). + // This indicates that %p points to a member of the type identifier %md. + // Group calls by (type ID, offset) pair (effectively the identity of the + // virtual function) and store to CallSlots. for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) { auto *CI = dyn_cast<CallInst>(U.getUser()); if (!CI) continue; - // Search for virtual calls based on %p and add them to DevirtCalls. SmallVector<DevirtCallSite, 1> DevirtCalls; SmallVector<CallInst *, 1> Assumes; @@ -2348,6 +2372,12 @@ bool DevirtModule::run() { (ImportSummary && ImportSummary->partiallySplitLTOUnits())) return false; + Function *PublicTypeTestFunc = nullptr; + // If we are in speculative devirtualization mode, we can work on the public + // type test intrinsics.
+ if (ClDevirtualizeSpeculatively) + PublicTypeTestFunc = + Intrinsic::getDeclarationIfExists(&M, Intrinsic::public_type_test); Function *TypeTestFunc = Intrinsic::getDeclarationIfExists(&M, Intrinsic::type_test); Function *TypeCheckedLoadFunc = @@ -2361,8 +2391,9 @@ bool DevirtModule::run() { // module, this pass has nothing to do. But if we are exporting, we also need // to handle any users that appear only in the function summaries. if (!ExportSummary && - (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || - AssumeFunc->use_empty()) && + (((!PublicTypeTestFunc || PublicTypeTestFunc->use_empty()) && + (!TypeTestFunc || TypeTestFunc->use_empty())) || + !AssumeFunc || AssumeFunc->use_empty()) && (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) && (!TypeCheckedLoadRelativeFunc || TypeCheckedLoadRelativeFunc->use_empty())) return false; @@ -2373,6 +2404,9 @@ bool DevirtModule::run() { DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; buildTypeIdentifierMap(Bits, TypeIdMap); + if (PublicTypeTestFunc && AssumeFunc) + scanTypeTestUsers(PublicTypeTestFunc, TypeIdMap); + if (TypeTestFunc && AssumeFunc) scanTypeTestUsers(TypeTestFunc, TypeIdMap); @@ -2413,7 +2447,7 @@ bool DevirtModule::run() { } for (auto &P : *ExportSummary) { - for (auto &S : P.second.SummaryList) { + for (auto &S : P.second.getSummaryList()) { auto *FS = dyn_cast<FunctionSummary>(S.get()); if (!FS) continue; @@ -2472,8 +2506,12 @@ bool DevirtModule::run() { .WPDRes[S.first.ByteOffset]; if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos, S.first.ByteOffset, ExportSummary)) { - - if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) { + bool SingleImplDevirt = + trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res); + // Outside of speculative devirtualization mode, try to apply virtual + // constant propagation or branch funneling. + // TODO: This should eventually be enabled for non-public type tests. + if (!SingleImplDevirt && !ClDevirtualizeSpeculatively) { DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first); @@ -2549,6 +2587,11 @@ void DevirtIndex::run() { if (ExportSummary.typeIdCompatibleVtableMap().empty()) return; + // Assert that we haven't made any changes that would affect the hasLocal() + // flag on the GUID summary info. + assert(!ExportSummary.withInternalizeAndPromote() && + "Expect index-based WPD to run before internalization and promotion"); + DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID; for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) { NameByGUID[GlobalValue::getGUIDAssumingExternalLinkage(P.first)].push_back( @@ -2564,7 +2607,7 @@ void DevirtIndex::run() { // Collect information from summary about which calls to try to devirtualize. for (auto &P : ExportSummary) { - for (auto &S : P.second.SummaryList) { + for (auto &S : P.second.getSummaryList()) { auto *FS = dyn_cast<FunctionSummary>(S.get()); if (!FS) continue; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 73ec451..9bee523 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -2760,21 +2760,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // Optimize pointer differences into the same array into a size. Consider: // &A[10] - &A[0]: we should compile this to "10".
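// For illustration (a sketch, not part of this patch; assumes i32 elements, so 4 bytes per index): given %gep = getelementptr inbounds i32, ptr %A, i64 10, the expression sub (ptrtoint ptr %gep to i64), (ptrtoint ptr %A to i64) folds to 40, i.e. 10 elements * 4 bytes. The matchers below now also accept ptrtoaddr wherever ptrtoint previously appeared.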
Value *LHSOp, *RHSOp; - if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && - match(Op1, m_PtrToInt(m_Value(RHSOp)))) + if (match(Op0, m_PtrToIntOrAddr(m_Value(LHSOp))) && + match(Op1, m_PtrToIntOrAddr(m_Value(RHSOp)))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), I.hasNoUnsignedWrap())) return replaceInstUsesWith(I, Res); // trunc(p)-trunc(q) -> trunc(p-q) - if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && - match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) + if (match(Op0, m_Trunc(m_PtrToIntOrAddr(m_Value(LHSOp)))) && + match(Op1, m_Trunc(m_PtrToIntOrAddr(m_Value(RHSOp))))) if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), /* IsNUW */ false)) return replaceInstUsesWith(I, Res); - if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) && - match(Op1, m_ZExtOrSelf(m_PtrToInt(m_Value(RHSOp))))) { + auto MatchSubOfZExtOfPtrToIntOrAddr = [&]() { + if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) && + match(Op1, m_ZExt(m_PtrToIntSameSize(DL, m_Value(RHSOp))))) + return true; + if (match(Op0, m_ZExt(m_PtrToAddr(m_Value(LHSOp)))) && + match(Op1, m_ZExt(m_PtrToAddr(m_Value(RHSOp))))) + return true; + // Special case for non-canonical ptrtoint in a constant expression, + // where the zext has been folded into the ptrtoint. + if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) && + match(Op1, m_PtrToInt(m_Value(RHSOp)))) + return true; + return false; + }; + if (MatchSubOfZExtOfPtrToIntOrAddr()) { if (auto *GEP = dyn_cast<GEPOperator>(LHSOp)) { if (GEP->getPointerOperand() == RHSOp) { if (GEP->hasNoUnsignedWrap() || GEP->hasNoUnsignedSignedWrap()) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index dab200d..8d9933b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -582,6 +582,18 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1); return BinaryOperator::CreateSub(ConstCtlz, X); } + + // ctlz(~x & (x - 1)) -> bitwidth - cttz(x, false) + if (Op0->hasOneUse() && + match(Op0, + m_c_And(m_Not(m_Value(X)), m_Add(m_Deferred(X), m_AllOnes())))) { + Type *Ty = II.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + auto *Cttz = IC.Builder.CreateIntrinsic(Intrinsic::cttz, Ty, + {X, IC.Builder.getFalse()}); + auto *Bw = ConstantInt::get(Ty, APInt(BitWidth, BitWidth)); + return IC.replaceInstUsesWith(II, IC.Builder.CreateSub(Bw, Cttz)); + } } // cttz(Pow2) -> Log2(Pow2) @@ -4003,18 +4015,29 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // Try to fold intrinsic into select/phi operands. This is legal if: // * The intrinsic is speculatable. - // * The select condition is not a vector, or the intrinsic does not - // perform cross-lane operations. - if (isSafeToSpeculativelyExecuteWithVariableReplaced(&CI) && - isNotCrossLaneOperation(II)) + // * The operand is one of the following: + // - a phi. + // - a select with a scalar condition. + // - a select with a vector condition and II is not a cross-lane operation.
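// For instance (a sketch, not from this patch): llvm.umax is speculatable and operates lane-wise, so umax(select(c, a, b), d) can become select(c, umax(a, d), umax(b, d)) even when c is a vector of i1. A cross-lane operation such as llvm.vector.reduce.add must not be folded through a select with a vector condition, because each of its results depends on all input lanes.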
+ if (isSafeToSpeculativelyExecuteWithVariableReplaced(&CI)) { for (Value *Op : II->args()) { - if (auto *Sel = dyn_cast<SelectInst>(Op)) - if (Instruction *R = FoldOpIntoSelect(*II, Sel)) + if (auto *Sel = dyn_cast<SelectInst>(Op)) { + bool IsVectorCond = Sel->getCondition()->getType()->isVectorTy(); + if (IsVectorCond && !isNotCrossLaneOperation(II)) + continue; + // Don't replace a scalar select with a more expensive vector select if + // we can't simplify both arms of the select. + bool SimplifyBothArms = + !Op->getType()->isVectorTy() && II->getType()->isVectorTy(); + if (Instruction *R = FoldOpIntoSelect( + *II, Sel, /*FoldWithMultiUse=*/false, SimplifyBothArms)) return R; + } if (auto *Phi = dyn_cast<PHINode>(Op)) if (Instruction *R = foldOpIntoPhi(*II, Phi)) return R; } + } if (Instruction *Shuf = foldShuffledIntrinsicOperands(II)) return Shuf; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index cdc559b..9b9fe26 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1643,33 +1643,46 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { /// Return a Constant* for the specified floating-point constant if it fits /// in the specified FP type without changing its value. -static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { +static bool fitsInFPType(APFloat F, const fltSemantics &Sem) { bool losesInfo; - APFloat F = CFP->getValueAPF(); (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo); return !losesInfo; } -static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { - if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext())) - return nullptr; // No constant folding of this. +static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F, + bool PreferBFloat) { // See if the value can be truncated to bfloat and then reextended. - if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat())) - return Type::getBFloatTy(CFP->getContext()); + if (PreferBFloat && fitsInFPType(F, APFloat::BFloat())) + return Type::getBFloatTy(Ctx); // See if the value can be truncated to half and then reextended. - if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf())) - return Type::getHalfTy(CFP->getContext()); + if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf())) + return Type::getHalfTy(Ctx); // See if the value can be truncated to float and then reextended. - if (fitsInFPType(CFP, APFloat::IEEEsingle())) - return Type::getFloatTy(CFP->getContext()); - if (CFP->getType()->isDoubleTy()) - return nullptr; // Won't shrink. - if (fitsInFPType(CFP, APFloat::IEEEdouble())) - return Type::getDoubleTy(CFP->getContext()); + if (fitsInFPType(F, APFloat::IEEEsingle())) + return Type::getFloatTy(Ctx); + if (&F.getSemantics() == &APFloat::IEEEdouble()) + return nullptr; // Won't shrink. + // See if the value can be truncated to double and then reextended. + if (fitsInFPType(F, APFloat::IEEEdouble())) + return Type::getDoubleTy(Ctx); // Don't try to shrink to various long double types. return nullptr; } +static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { + Type *Ty = CFP->getType(); + if (Ty->getScalarType()->isPPC_FP128Ty()) + return nullptr; // No constant folding of this. 
+ + Type *ShrinkTy = + shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat); + if (ShrinkTy) + if (auto *VecTy = dyn_cast<VectorType>(Ty)) + ShrinkTy = VectorType::get(ShrinkTy, VecTy); + + return ShrinkTy; +} + // Determine if this is a vector of ConstantFPs and if so, return the minimal // type we can safely truncate all elements to. static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) { @@ -1720,10 +1733,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) { // Try to shrink scalable and fixed splat vectors. if (auto *FPC = dyn_cast<Constant>(V)) - if (isa<VectorType>(V->getType())) + if (auto *VTy = dyn_cast<VectorType>(V->getType())) if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue())) if (Type *T = shrinkFPConstant(Splat, PreferBFloat)) - return T; + return VectorType::get(T, VTy); // Try to shrink a vector of FP constants. This returns nullptr on scalable // vectors @@ -1796,10 +1809,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { Type *Ty = FPT.getType(); auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0)); if (BO && BO->hasOneUse()) { - Type *LHSMinType = - getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy()); - Type *RHSMinType = - getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy()); + bool PreferBFloat = Ty->getScalarType()->isBFloatTy(); + Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat); + Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat); unsigned OpWidth = BO->getType()->getFPMantissaWidth(); unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 943c223..ede73f8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -664,7 +664,8 @@ public: /// This also works for Cast instructions, which obviously do not have a /// second operand. Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI, - bool FoldWithMultiUse = false); + bool FoldWithMultiUse = false, + bool SimplifyBothArms = false); /// This is a convenience wrapper function for the above two functions. Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 975498f..f5130da 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -3455,27 +3455,45 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { // select a, false, b -> select !a, b, false if (match(TrueVal, m_Specific(Zero))) { Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return SelectInst::Create(NotCond, FalseVal, Zero); + Instruction *MDFrom = ProfcheckDisableMetadataFixes ? nullptr : &SI; + SelectInst *NewSI = + SelectInst::Create(NotCond, FalseVal, Zero, "", nullptr, MDFrom); + NewSI->swapProfMetadata(); + return NewSI; } // select a, b, true -> select !a, true, b if (match(FalseVal, m_Specific(One))) { Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return SelectInst::Create(NotCond, One, TrueVal); + Instruction *MDFrom = ProfcheckDisableMetadataFixes ? 
nullptr : &SI; + SelectInst *NewSI = + SelectInst::Create(NotCond, One, TrueVal, "", nullptr, MDFrom); + NewSI->swapProfMetadata(); + return NewSI; } // DeMorgan in select form: !a && !b --> !(a || b) // select !a, !b, false --> not (select a, true, b) if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) && (CondVal->hasOneUse() || TrueVal->hasOneUse()) && - !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) - return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B)); + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) { + Instruction *MDFrom = ProfcheckDisableMetadataFixes ? nullptr : &SI; + SelectInst *NewSI = + cast<SelectInst>(Builder.CreateSelect(A, One, B, "", MDFrom)); + NewSI->swapProfMetadata(); + return BinaryOperator::CreateNot(NewSI); + } // DeMorgan in select form: !a || !b --> !(a && b) // select !a, true, !b --> not (select a, b, false) if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) && (CondVal->hasOneUse() || FalseVal->hasOneUse()) && - !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) - return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero)); + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) { + Instruction *MDFrom = ProfcheckDisableMetadataFixes ? nullptr : &SI; + SelectInst *NewSI = + cast<SelectInst>(Builder.CreateSelect(A, B, Zero, "", MDFrom)); + NewSI->swapProfMetadata(); + return BinaryOperator::CreateNot(NewSI); + } // select (select a, true, b), true, b -> select a, true, b if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && @@ -4679,5 +4697,28 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { cast<IntrinsicInst>(TrueVal)->getParamAlign(0).valueOrOne(), CondVal, FalseVal)); + unsigned BitWidth = SI.getType()->getScalarSizeInBits(); + CmpPredicate Pred; + Value *CmpLHS, *CmpRHS; + + // Canonicalize sign function ashr patterns: + // select (icmp slt X, 1), ashr X, bitwidth-1, 1 -> scmp(X, 0) + // select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0) + if (match(&SI, m_Select(m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)), + m_Value(TrueVal), m_Value(FalseVal))) && + ((Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_One()) && + match(TrueVal, + m_AShr(m_Specific(CmpLHS), m_SpecificInt(BitWidth - 1))) && + match(FalseVal, m_One())) || + (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_Zero()) && + match(TrueVal, m_One()) && + match(FalseVal, + m_AShr(m_Specific(CmpLHS), m_SpecificInt(BitWidth - 1)))))) { + + Function *Scmp = Intrinsic::getOrInsertDeclaration( + SI.getModule(), Intrinsic::scmp, {SI.getType(), SI.getType()}); + return CallInst::Create(Scmp, {CmpLHS, ConstantInt::get(SI.getType(), 0)}); + } + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 3f11cae..9c8de45 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1777,7 +1777,8 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI, } Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, - bool FoldWithMultiUse) { + bool FoldWithMultiUse, + bool SimplifyBothArms) { // Don't modify shared select instructions unless set FoldWithMultiUse if (!SI->hasOneUse() &&
!FoldWithMultiUse) return nullptr; @@ -1821,6 +1822,9 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, if (!NewTV && !NewFV) return nullptr; + if (SimplifyBothArms && !(NewTV && NewFV)) + return nullptr; + // Create an instruction for the arm that did not fold. if (!NewTV) NewTV = foldOperationIntoSelectOperand(Op, SI, TV, *this); @@ -2323,6 +2327,18 @@ Constant *InstCombinerImpl::unshuffleConstant(ArrayRef<int> ShMask, Constant *C, return ConstantVector::get(NewVecC); } +// Get the result of `Vector Op Splat` (or Splat Op Vector if \p SplatLHS). +static Constant *constantFoldBinOpWithSplat(unsigned Opcode, Constant *Vector, + Constant *Splat, bool SplatLHS, + const DataLayout &DL) { + ElementCount EC = cast<VectorType>(Vector->getType())->getElementCount(); + Constant *LHS = ConstantVector::getSplat(EC, Splat); + Constant *RHS = Vector; + if (!SplatLHS) + std::swap(LHS, RHS); + return ConstantFoldBinaryOpOperands(Opcode, LHS, RHS, DL); +} + Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { if (!isa<VectorType>(Inst.getType())) return nullptr; @@ -2334,6 +2350,37 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { assert(cast<VectorType>(RHS->getType())->getElementCount() == cast<VectorType>(Inst.getType())->getElementCount()); + auto foldConstantsThroughSubVectorInsertSplat = + [&](Value *MaybeSubVector, Value *MaybeSplat, + bool SplatLHS) -> Instruction * { + Value *Idx; + Constant *Splat, *SubVector, *Dest; + if (!match(MaybeSplat, m_ConstantSplat(m_Constant(Splat))) || + !match(MaybeSubVector, + m_VectorInsert(m_Constant(Dest), m_Constant(SubVector), + m_Value(Idx)))) + return nullptr; + SubVector = + constantFoldBinOpWithSplat(Opcode, SubVector, Splat, SplatLHS, DL); + Dest = constantFoldBinOpWithSplat(Opcode, Dest, Splat, SplatLHS, DL); + if (!SubVector || !Dest) + return nullptr; + auto *InsertVector = + Builder.CreateInsertVector(Dest->getType(), Dest, SubVector, Idx); + return replaceInstUsesWith(Inst, InsertVector); + }; + + // If one operand is a constant splat and the other operand is a + // `vector.insert` where both the destination and subvector are constant, + // apply the operation to both the destination and subvector, returning a new + // constant `vector.insert`. This helps constant folding for scalable vectors. + if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat( + /*MaybeSubVector=*/LHS, /*MaybeSplat=*/RHS, /*SplatLHS=*/false)) + return Folded; + if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat( + /*MaybeSubVector=*/RHS, /*MaybeSplat=*/LHS, /*SplatLHS=*/true)) + return Folded; + // If both operands of the binop are vector concatenations, then perform the // narrow binop on each pair of the source operands followed by concatenation // of the results. 
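// Illustrative IR for the vector.insert fold above (a sketch, not from the patch): %v = call <vscale x 4 x i32> @llvm.vector.insert.nxv4i32.v4i32(<vscale x 4 x i32> zeroinitializer, <4 x i32> <i32 1, i32 2, i32 3, i32 4>, i64 0) followed by %r = add <vscale x 4 x i32> %v, splat (i32 1) becomes a single constant vector.insert of <i32 2, i32 3, i32 4, i32 5> into a splat (i32 1) destination, leaving no add at run time.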
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp index 40720ae..8181e4e 100644 --- a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp +++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp @@ -31,10 +31,12 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" +#include "llvm/Support/AllocToken.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" @@ -53,47 +55,14 @@ #include <variant> using namespace llvm; +using TokenMode = AllocTokenMode; #define DEBUG_TYPE "alloc-token" namespace { -//===--- Constants --------------------------------------------------------===// - -enum class TokenMode : unsigned { - /// Incrementally increasing token ID. - Increment = 0, - - /// Simple mode that returns a statically-assigned random token ID. - Random = 1, - - /// Token ID based on allocated type hash. - TypeHash = 2, - - /// Token ID based on allocated type hash, where the top half ID-space is - /// reserved for types that contain pointers and the bottom half for types - /// that do not contain pointers. - TypeHashPointerSplit = 3, -}; - //===--- Command-line options ---------------------------------------------===// -cl::opt<TokenMode> ClMode( - "alloc-token-mode", cl::Hidden, cl::desc("Token assignment mode"), - cl::init(TokenMode::TypeHashPointerSplit), - cl::values( - clEnumValN(TokenMode::Increment, "increment", - "Incrementally increasing token ID"), - clEnumValN(TokenMode::Random, "random", - "Statically-assigned random token ID"), - clEnumValN(TokenMode::TypeHash, "typehash", - "Token ID based on allocated type hash"), - clEnumValN( - TokenMode::TypeHashPointerSplit, "typehashpointersplit", - "Token ID based on allocated type hash, where the top half " - "ID-space is reserved for types that contain pointers and the " - "bottom half for types that do not contain pointers. "))); - cl::opt<std::string> ClFuncPrefix("alloc-token-prefix", cl::desc("The allocation function prefix"), cl::Hidden, cl::init("__alloc_token_")); @@ -131,7 +100,7 @@ cl::opt<uint64_t> ClFallbackToken( //===--- Statistics -------------------------------------------------------===// -STATISTIC(NumFunctionsInstrumented, "Functions instrumented"); +STATISTIC(NumFunctionsModified, "Functions modified"); STATISTIC(NumAllocationsInstrumented, "Allocations instrumented"); //===----------------------------------------------------------------------===// @@ -140,9 +109,19 @@ STATISTIC(NumAllocationsInstrumented, "Allocations instrumented"); /// /// Expected format is: !{<type-name>, <contains-pointer>} MDNode *getAllocTokenMetadata(const CallBase &CB) { - MDNode *Ret = CB.getMetadata(LLVMContext::MD_alloc_token); - if (!Ret) - return nullptr; + MDNode *Ret = nullptr; + if (auto *II = dyn_cast<IntrinsicInst>(&CB); + II && II->getIntrinsicID() == Intrinsic::alloc_token_id) { + auto *MDV = cast<MetadataAsValue>(II->getArgOperand(0)); + Ret = cast<MDNode>(MDV->getMetadata()); + // If the intrinsic has an empty MDNode, type inference failed. 
+ if (Ret->getNumOperands() == 0) + return nullptr; + } else { + Ret = CB.getMetadata(LLVMContext::MD_alloc_token); + if (!Ret) + return nullptr; + } assert(Ret->getNumOperands() == 2 && "bad !alloc_token"); assert(isa<MDString>(Ret->getOperand(0))); assert(isa<ConstantAsMetadata>(Ret->getOperand(1))); @@ -206,22 +185,19 @@ public: using ModeBase::ModeBase; uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) { - const auto [N, H] = getHash(CB, ORE); - return N ? boundedToken(H) : H; - } -protected: - std::pair<MDNode *, uint64_t> getHash(const CallBase &CB, - OptimizationRemarkEmitter &ORE) { if (MDNode *N = getAllocTokenMetadata(CB)) { MDString *S = cast<MDString>(N->getOperand(0)); - return {N, getStableSipHash(S->getString())}; + AllocTokenMetadata Metadata{S->getString(), containsPointer(N)}; + if (auto Token = getAllocToken(TokenMode::TypeHash, Metadata, MaxTokens)) + return *Token; } // Fallback. remarkNoMetadata(CB, ORE); - return {nullptr, ClFallbackToken}; + return ClFallbackToken; } +protected: /// Remark that there was no precise type information. static void remarkNoMetadata(const CallBase &CB, OptimizationRemarkEmitter &ORE) { @@ -242,20 +218,18 @@ public: using TypeHashMode::TypeHashMode; uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) { - if (MaxTokens == 1) - return 0; - const uint64_t HalfTokens = MaxTokens / 2; - const auto [N, H] = getHash(CB, ORE); - if (!N) { - // Pick the fallback token (ClFallbackToken), which by default is 0, - // meaning it'll fall into the pointer-less bucket. Override by setting - // -alloc-token-fallback if that is the wrong choice. - return H; + if (MDNode *N = getAllocTokenMetadata(CB)) { + MDString *S = cast<MDString>(N->getOperand(0)); + AllocTokenMetadata Metadata{S->getString(), containsPointer(N)}; + if (auto Token = getAllocToken(TokenMode::TypeHashPointerSplit, Metadata, + MaxTokens)) + return *Token; } - uint64_t Hash = H % HalfTokens; // base hash - if (containsPointer(N)) - Hash += HalfTokens; - return Hash; + // Pick the fallback token (ClFallbackToken), which by default is 0, meaning + // it'll fall into the pointer-less bucket. Override by setting + // -alloc-token-fallback if that is the wrong choice. + remarkNoMetadata(CB, ORE); + return ClFallbackToken; } }; @@ -275,7 +249,7 @@ public: : Options(transformOptionsFromCl(std::move(Opts))), Mod(M), FAM(MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()), Mode(IncrementMode(*IntPtrTy, *Options.MaxTokens)) { - switch (ClMode.getValue()) { + switch (Options.Mode) { case TokenMode::Increment: break; case TokenMode::Random: @@ -315,6 +289,9 @@ private: FunctionCallee getTokenAllocFunction(const CallBase &CB, uint64_t TokenID, LibFunc OriginalFunc); + /// Lower alloc_token_* intrinsics. + void replaceIntrinsicInst(IntrinsicInst *II, OptimizationRemarkEmitter &ORE); + /// Return the token ID from metadata in the call. uint64_t getToken(const CallBase &CB, OptimizationRemarkEmitter &ORE) { return std::visit([&](auto &&Mode) { return Mode(CB, ORE); }, Mode); @@ -336,21 +313,32 @@ bool AllocToken::instrumentFunction(Function &F) { // Do not apply any instrumentation for naked functions. if (F.hasFnAttribute(Attribute::Naked)) return false; - if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) - return false; // Don't touch available_externally functions, their actual body is elsewhere. 
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; - // Only instrument functions that have the sanitize_alloc_token attribute. - if (!F.hasFnAttribute(Attribute::SanitizeAllocToken)) - return false; auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); SmallVector<std::pair<CallBase *, LibFunc>, 4> AllocCalls; + SmallVector<IntrinsicInst *, 4> IntrinsicInsts; + + // Only instrument functions that have the sanitize_alloc_token attribute. + const bool InstrumentFunction = + F.hasFnAttribute(Attribute::SanitizeAllocToken) && + !F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation); // Collect all allocation calls to avoid iterator invalidation. for (Instruction &I : instructions(F)) { + // Collect all alloc_token_* intrinsics. + if (auto *II = dyn_cast<IntrinsicInst>(&I); + II && II->getIntrinsicID() == Intrinsic::alloc_token_id) { + IntrinsicInsts.emplace_back(II); + continue; + } + + if (!InstrumentFunction) + continue; + auto *CB = dyn_cast<CallBase>(&I); if (!CB) continue; @@ -359,11 +347,21 @@ } bool Modified = false; - for (auto &[CB, Func] : AllocCalls) - Modified |= replaceAllocationCall(CB, Func, ORE, TLI); - if (Modified) - NumFunctionsInstrumented++; + if (!AllocCalls.empty()) { + for (auto &[CB, Func] : AllocCalls) + Modified |= replaceAllocationCall(CB, Func, ORE, TLI); + if (Modified) + NumFunctionsModified++; + } + + if (!IntrinsicInsts.empty()) { + for (auto *II : IntrinsicInsts) + replaceIntrinsicInst(II, ORE); + Modified = true; + NumFunctionsModified++; + } + return Modified; } @@ -381,7 +379,7 @@ AllocToken::shouldInstrumentCall(const CallBase &CB, if (TLI.getLibFunc(*Callee, Func)) { if (isInstrumentableLibFunc(Func, CB, TLI)) return Func; - } else if (Options.Extended && getAllocTokenMetadata(CB)) { + } else if (Options.Extended && CB.getMetadata(LLVMContext::MD_alloc_token)) { return NotLibFunc; } @@ -528,6 +526,16 @@ FunctionCallee AllocToken::getTokenAllocFunction(const CallBase &CB, return TokenAlloc; } +void AllocToken::replaceIntrinsicInst(IntrinsicInst *II, + OptimizationRemarkEmitter &ORE) { + assert(II->getIntrinsicID() == Intrinsic::alloc_token_id); + + uint64_t TokenID = getToken(*II, ORE); + Value *V = ConstantInt::get(IntPtrTy, TokenID); + II->replaceAllUsesWith(V); + II->eraseFromParent(); +} + } // namespace AllocTokenPass::AllocTokenPass(AllocTokenOptions Opts) diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index b6cbecb..10b03bb 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -226,6 +226,7 @@ static const Align kMinOriginAlignment = Align(4); static const Align kShadowTLSAlignment = Align(8); // These constants must be kept in sync with the ones in msan.h. +// TODO: increase size to match SVE/SVE2/SME/SME2 limits static const unsigned kParamTLSSize = 800; static const unsigned kRetvalTLSSize = 800; @@ -1544,6 +1545,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } + static bool isAArch64SVCount(Type *Ty) { + if (TargetExtType *TTy = dyn_cast<TargetExtType>(Ty)) + return TTy->getName() == "aarch64.svcount"; + return false; + } + + // This is intended to match the "AArch64 Predicate-as-Counter Type" (aka + // 'target("aarch64.svcount")'), but not e.g., <vscale x 4 x i32>.
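// Examples (a sketch): target("aarch64.svcount") is scalable and not a VectorType, so the predicate below returns true for it; <vscale x 4 x i32> is scalable but is a VectorType, so it returns false; a plain i32 also returns false.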
+ static bool isScalableNonVectorType(Type *Ty) { + if (!isAArch64SVCount(Ty)) + LLVM_DEBUG(dbgs() << "isScalableNonVectorType: Unexpected type " << *Ty + << "\n"); + + return Ty->isScalableTy() && !isa<VectorType>(Ty); + } + void materializeChecks() { #ifndef NDEBUG // For assert below. @@ -1672,6 +1689,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { LLVM_DEBUG(dbgs() << "getShadowTy: " << *ST << " ===> " << *Res << "\n"); return Res; } + if (isScalableNonVectorType(OrigTy)) { + LLVM_DEBUG(dbgs() << "getShadowTy: Scalable non-vector type: " << *OrigTy + << "\n"); + return OrigTy; + } + uint32_t TypeSize = DL.getTypeSizeInBits(OrigTy); return IntegerType::get(*MS.C, TypeSize); } @@ -2185,8 +2208,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { << *OrigIns << "\n"); return; } -#ifndef NDEBUG + Type *ShadowTy = Shadow->getType(); + if (isScalableNonVectorType(ShadowTy)) { + LLVM_DEBUG(dbgs() << "Skipping check of scalable non-vector " << *Shadow + << " before " << *OrigIns << "\n"); + return; + } +#ifndef NDEBUG assert((isa<IntegerType>(ShadowTy) || isa<VectorType>(ShadowTy) || isa<StructType>(ShadowTy) || isa<ArrayType>(ShadowTy)) && "Can only insert checks for integer, vector, and aggregate shadow " @@ -6972,6 +7001,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // an extra "select". This results in much more compact IR. // Sa = select Sb, poisoned, (select b, Sc, Sd) Sa1 = getPoisonedShadow(getShadowTy(I.getType())); + } else if (isScalableNonVectorType(I.getType())) { + // This is intended to handle target("aarch64.svcount"), which can't be + // handled in the else branch because of incompatibility with CreateXor + // ("The supported LLVM operations on this type are limited to load, + // store, phi, select and alloca instructions"). + + // TODO: this currently underapproximates. Use Arm SVE EOR in the else + // branch as needed instead. + Sa1 = getCleanShadow(getShadowTy(I.getType())); } else { // Sa = select Sb, [ (c^d) | Sc | Sd ], [ b ? 
Sc : Sd ] // If Sb (condition is poisoned), look for bits in c and d that are equal diff --git a/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp index d18c0d0..80e77e09 100644 --- a/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp @@ -2020,7 +2020,6 @@ static void moveFastMathFlags(Function &F, F.removeFnAttr(attr); \ FMF.set##setter(); \ } - MOVE_FLAG("unsafe-fp-math", Fast) MOVE_FLAG("no-infs-fp-math", NoInfs) MOVE_FLAG("no-nans-fp-math", NoNaNs) MOVE_FLAG("no-signed-zeros-fp-math", NoSignedZeros) diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp index 4acc3f2..d347ced 100644 --- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp +++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp @@ -614,6 +614,16 @@ static Decomposition decompose(Value *V, return {V, IsKnownNonNegative}; } + if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && + canUseSExt(CI)) { + Preconditions.emplace_back( + CmpInst::ICMP_UGE, Op0, + ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); + if (auto Decomp = MergeResults(Op0, CI, true)) + return *Decomp; + return {V, IsKnownNonNegative}; + } + if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) { if (!isKnownNonNegative(Op0, DL)) Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0, @@ -627,16 +637,6 @@ static Decomposition decompose(Value *V, return {V, IsKnownNonNegative}; } - if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() && - canUseSExt(CI)) { - Preconditions.emplace_back( - CmpInst::ICMP_UGE, Op0, - ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1)); - if (auto Decomp = MergeResults(Op0, CI, true)) - return *Decomp; - return {V, IsKnownNonNegative}; - } - // Decompose or as an add if there are no common bits between the operands. if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) { if (auto Decomp = MergeResults(Op0, CI, IsSigned)) diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp index a83cbd17a7..f273e9d 100644 --- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp +++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp @@ -64,10 +64,10 @@ using namespace llvm; -namespace { - #define DEBUG_TYPE "mergeicmps" +namespace { + // A BCE atom "Binary Compare Expression Atom" represents an integer load // that is a constant offset from a base value, e.g. `a` or `o.c` in the example // at the top. @@ -128,11 +128,12 @@ private: unsigned Order = 1; DenseMap<const Value*, int> BaseToIndex; }; +} // namespace // If this value is a load from a constant offset w.r.t. a base address, and // there are no other users of the load or address, returns the base address and // the offset. -BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { +static BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { auto *const LoadI = dyn_cast<LoadInst>(Val); if (!LoadI) return {}; @@ -175,6 +176,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset); } +namespace { // A comparison between two BCE atoms, e.g. `a == o.a` in the example at the // top. 
// Note: the terminology is misleading: the comparison is symmetric, so there @@ -239,6 +241,7 @@ class BCECmpBlock { private: BCECmp Cmp; }; +} // namespace bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, AliasAnalysis &AA) const { @@ -302,9 +305,9 @@ bool BCECmpBlock::doesOtherWork() const { // Visit the given comparison. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, - const ICmpInst::Predicate ExpectedPredicate, - BaseIdentifier &BaseId) { +static std::optional<BCECmp> +visitICmp(const ICmpInst *const CmpI, + const ICmpInst::Predicate ExpectedPredicate, BaseIdentifier &BaseId) { // The comparison can only be used once: // - For intermediate blocks, as a branch condition. // - For the final block, as an incoming value for the Phi. @@ -332,10 +335,9 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI, // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. -std::optional<BCECmpBlock> visitCmpBlock(Value *const Val, - BasicBlock *const Block, - const BasicBlock *const PhiBlock, - BaseIdentifier &BaseId) { +static std::optional<BCECmpBlock> +visitCmpBlock(Value *const Val, BasicBlock *const Block, + const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) { if (Block->empty()) return std::nullopt; auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); @@ -397,6 +399,7 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, Comparisons.push_back(std::move(Comparison)); } +namespace { // A chain of comparisons. class BCECmpChain { public: @@ -420,6 +423,7 @@ private: // The original entry block (before sorting); BasicBlock *EntryBlock_; }; +} // namespace static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { return First.Lhs().BaseId == Second.Lhs().BaseId && @@ -742,9 +746,8 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, return true; } -std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, - BasicBlock *const LastBlock, - int NumBlocks) { +static std::vector<BasicBlock *> +getOrderedBlocks(PHINode &Phi, BasicBlock *const LastBlock, int NumBlocks) { // Walk up from the last block to find other blocks. 
std::vector<BasicBlock *> Blocks(NumBlocks); assert(LastBlock && "invalid last block"); @@ -777,8 +780,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, return Blocks; } -bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA, - DomTreeUpdater &DTU) { +static bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, + AliasAnalysis &AA, DomTreeUpdater &DTU) { LLVM_DEBUG(dbgs() << "processPhi()\n"); if (Phi.getNumIncomingValues() <= 1) { LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); @@ -874,6 +877,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI, return MadeChange; } +namespace { class MergeICmpsLegacyPass : public FunctionPass { public: static char ID; diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp index 8714741a..9829d4d 100644 --- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -1793,3 +1793,13 @@ bool llvm::hasOnlySimpleTerminator(const Function &F) { } return true; } + +Printable llvm::printBasicBlock(const BasicBlock *BB) { + return Printable([BB](raw_ostream &OS) { + if (!BB) { + OS << "<nullptr>"; + return; + } + BB->printAsOperand(OS); + }); +} diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp index 978d5a2..371d9e6 100644 --- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -260,9 +260,16 @@ bool PredicateInfoBuilder::stackIsInScope(const ValueDFSStack &Stack, // next to the defs they must go with so that we can know it's time to pop // the stack when we hit the end of the phi uses for a given def. const ValueDFS &Top = *Stack.back().V; - if (Top.LocalNum == LN_Last && Top.PInfo) { - if (!VDUse.U) - return false; + assert(Top.PInfo && "RenameStack should only contain predicate infos (defs)"); + if (Top.LocalNum == LN_Last) { + if (!VDUse.U) { + assert(VDUse.PInfo && "A non-use VDUse should have a predicate info"); + // We should reserve adjacent LN_Last defs for the same phi use. + return VDUse.LocalNum == LN_Last && + // If the two phi defs have the same edge, they must be designated + // for the same succ BB. + getBlockEdge(Top.PInfo) == getBlockEdge(VDUse.PInfo); + } auto *PHI = dyn_cast<PHINode>(VDUse.U->getUser()); if (!PHI) return false; diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index febdc54..3356516 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1011,6 +1011,10 @@ public: /// \returns True if instruction \p I can be truncated to a smaller bitwidth /// for vectorization factor \p VF. bool canTruncateToMinimalBitwidth(Instruction *I, ElementCount VF) const { + // Truncs must truncate at most to their destination type. + if (isa_and_nonnull<TruncInst>(I) && MinBWs.contains(I) && + I->getType()->getScalarSizeInBits() < MinBWs.lookup(I)) + return false; return VF.isVector() && MinBWs.contains(I) && !isProfitableToScalarize(I, VF) && !isScalarAfterVectorization(I, VF); @@ -9855,6 +9859,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Get user vectorization factor and interleave count. ElementCount UserVF = Hints.getWidth(); unsigned UserIC = Hints.getInterleave(); + if (UserIC > 1 && !LVL.isSafeForAnyVectorWidth()) + UserIC = 1; // Plan how to best vectorize. 
LVP.plan(UserVF, UserIC); @@ -9919,7 +9925,15 @@ bool LoopVectorizePass::processLoop(Loop *L) { VectorizeLoop = false; } - if (!LVP.hasPlanWithVF(VF.Width) && UserIC > 1) { + if (UserIC == 1 && Hints.getInterleave() > 1) { + assert(!LVL.isSafeForAnyVectorWidth() && + "UserIC should only be ignored due to unsafe dependencies"); + LLVM_DEBUG(dbgs() << "LV: Ignoring user-specified interleave count.\n"); + IntDiagMsg = {"InterleavingUnsafe", + "Ignoring user-specified interleave count due to possibly " + "unsafe dependencies in the loop."}; + InterleaveLoop = false; + } else if (!LVP.hasPlanWithVF(VF.Width) && UserIC > 1) { // Tell the user interleaving was avoided up-front, despite being explicitly // requested. LLVM_DEBUG(dbgs() << "LV: Ignoring UserIC, because vectorization and " diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 3f18bd7..cdb9e7e 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -5577,62 +5577,79 @@ private: } // Decrement the unscheduled counter and insert to ready list if // ready. - auto DecrUnschedForInst = [&](Instruction *I, TreeEntry *UserTE, - unsigned OpIdx) { - if (!ScheduleCopyableDataMap.empty()) { - const EdgeInfo EI = {UserTE, OpIdx}; - if (ScheduleCopyableData *CD = getScheduleCopyableData(EI, I)) { - DecrUnsched(CD, /*IsControl=*/false); - return; - } - } - auto It = OperandsUses.find(I); - assert(It != OperandsUses.end() && "Operand not found"); - if (It->second > 0) { - --It->getSecond(); - assert(TotalOpCount > 0 && "No more operands to decrement"); - --TotalOpCount; - if (ScheduleData *OpSD = getScheduleData(I)) - DecrUnsched(OpSD, /*IsControl=*/false); - } - }; + auto DecrUnschedForInst = + [&](Instruction *I, TreeEntry *UserTE, unsigned OpIdx, + SmallDenseSet<std::pair<const ScheduleEntity *, unsigned>> + &Checked) { + if (!ScheduleCopyableDataMap.empty()) { + const EdgeInfo EI = {UserTE, OpIdx}; + if (ScheduleCopyableData *CD = + getScheduleCopyableData(EI, I)) { + if (!Checked.insert(std::make_pair(CD, OpIdx)).second) + return; + DecrUnsched(CD, /*IsControl=*/false); + return; + } + } + auto It = OperandsUses.find(I); + assert(It != OperandsUses.end() && "Operand not found"); + if (It->second > 0) { + --It->getSecond(); + assert(TotalOpCount > 0 && "No more operands to decrement"); + --TotalOpCount; + if (ScheduleData *OpSD = getScheduleData(I)) { + if (!Checked.insert(std::make_pair(OpSD, OpIdx)).second) + return; + DecrUnsched(OpSD, /*IsControl=*/false); + } + } + }; for (ScheduleBundle *Bundle : Bundles) { if (ScheduleCopyableDataMap.empty() && TotalOpCount == 0) break; // Need to search for the lane since the tree entry can be // reordered. - int Lane = std::distance(Bundle->getTreeEntry()->Scalars.begin(), - find(Bundle->getTreeEntry()->Scalars, In)); - assert(Lane >= 0 && "Lane not set"); - if (isa<StoreInst>(In) && - !Bundle->getTreeEntry()->ReorderIndices.empty()) - Lane = Bundle->getTreeEntry()->ReorderIndices[Lane]; - assert(Lane < static_cast<int>( - Bundle->getTreeEntry()->Scalars.size()) && - "Couldn't find extract lane"); - - // Since vectorization tree is being built recursively this - // assertion ensures that the tree entry has all operands set before - // reaching this code. Couple of exceptions known at the moment are - // extracts where their second (immediate) operand is not added. - // Since immediates do not affect scheduler behavior this is - // considered okay. 
- assert(In && - (isa<ExtractValueInst, ExtractElementInst, CallBase>(In) || - In->getNumOperands() == - Bundle->getTreeEntry()->getNumOperands() || - Bundle->getTreeEntry()->isCopyableElement(In)) && - "Missed TreeEntry operands?"); - - for (unsigned OpIdx : - seq<unsigned>(Bundle->getTreeEntry()->getNumOperands())) - if (auto *I = dyn_cast<Instruction>( - Bundle->getTreeEntry()->getOperand(OpIdx)[Lane])) { - LLVM_DEBUG(dbgs() << "SLP: check for readiness (def): " << *I - << "\n"); - DecrUnschedForInst(I, Bundle->getTreeEntry(), OpIdx); - } + auto *It = find(Bundle->getTreeEntry()->Scalars, In); + SmallDenseSet<std::pair<const ScheduleEntity *, unsigned>> Checked; + do { + int Lane = + std::distance(Bundle->getTreeEntry()->Scalars.begin(), It); + assert(Lane >= 0 && "Lane not set"); + if (isa<StoreInst>(In) && + !Bundle->getTreeEntry()->ReorderIndices.empty()) + Lane = Bundle->getTreeEntry()->ReorderIndices[Lane]; + assert(Lane < static_cast<int>( + Bundle->getTreeEntry()->Scalars.size()) && + "Couldn't find extract lane"); + + // Since vectorization tree is being built recursively this + // assertion ensures that the tree entry has all operands set + // before reaching this code. Couple of exceptions known at the + // moment are extracts where their second (immediate) operand is + // not added. Since immediates do not affect scheduler behavior + // this is considered okay. + assert(In && + (isa<ExtractValueInst, ExtractElementInst, CallBase>(In) || + In->getNumOperands() == + Bundle->getTreeEntry()->getNumOperands() || + Bundle->getTreeEntry()->isCopyableElement(In)) && + "Missed TreeEntry operands?"); + + for (unsigned OpIdx : + seq<unsigned>(Bundle->getTreeEntry()->getNumOperands())) + if (auto *I = dyn_cast<Instruction>( + Bundle->getTreeEntry()->getOperand(OpIdx)[Lane])) { + LLVM_DEBUG(dbgs() << "SLP: check for readiness (def): " + << *I << "\n"); + DecrUnschedForInst(I, Bundle->getTreeEntry(), OpIdx, Checked); + } + // If the parent node is schedulable, it will be handled correctly. + if (!Bundle->getTreeEntry()->doesNotNeedToSchedule()) + break; + It = std::find(std::next(It), + Bundle->getTreeEntry()->Scalars.end(), In); + } while (It != Bundle->getTreeEntry()->Scalars.end()); } } else { // If BundleMember is a stand-alone instruction, no operand reordering diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp index d167009..428a8f4 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp @@ -217,32 +217,6 @@ VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() { return Parent->getEnclosingBlockWithPredecessors(); } -bool VPBlockUtils::isHeader(const VPBlockBase *VPB, - const VPDominatorTree &VPDT) { - auto *VPBB = dyn_cast<VPBasicBlock>(VPB); - if (!VPBB) - return false; - - // If VPBB is in a region R, VPBB is a loop header if R is a loop region with - // VPBB as its entry, i.e., free of predecessors. - if (auto *R = VPBB->getParent()) - return !R->isReplicator() && !VPBB->hasPredecessors(); - - // A header dominates its second predecessor (the latch), with the other - // predecessor being the preheader - return VPB->getPredecessors().size() == 2 && - VPDT.dominates(VPB, VPB->getPredecessors()[1]); } - -bool VPBlockUtils::isLatch(const VPBlockBase *VPB, - const VPDominatorTree &VPDT) { - // A latch has a header as its second successor, with its other successor - // leaving the loop. A preheader OTOH has a header as its first (and only) - // successor.
- return VPB->getNumSuccessors() == 2 && - VPBlockUtils::isHeader(VPB->getSuccessors()[1], VPDT); -} - VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() { iterator It = begin(); while (It != end() && It->isPhi()) @@ -768,8 +742,12 @@ static std::pair<VPBlockBase *, VPBlockBase *> cloneFrom(VPBlockBase *Entry) { VPRegionBlock *VPRegionBlock::clone() { const auto &[NewEntry, NewExiting] = cloneFrom(getEntry()); - auto *NewRegion = getPlan()->createVPRegionBlock(NewEntry, NewExiting, - getName(), isReplicator()); + VPlan &Plan = *getPlan(); + VPRegionBlock *NewRegion = + isReplicator() + ? Plan.createReplicateRegion(NewEntry, NewExiting, getName()) + : Plan.createLoopRegion(getName(), NewEntry, NewExiting); + for (VPBlockBase *Block : vp_depth_first_shallow(NewEntry)) Block->setParent(NewRegion); return NewRegion; diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index fed04eb..2591df8 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2712,7 +2712,8 @@ public: static inline bool classof(const VPRecipeBase *R) { return R->getVPDefID() == VPRecipeBase::VPReductionSC || - R->getVPDefID() == VPRecipeBase::VPReductionEVLSC; + R->getVPDefID() == VPRecipeBase::VPReductionEVLSC || + R->getVPDefID() == VPRecipeBase::VPPartialReductionSC; } static inline bool classof(const VPUser *U) { @@ -2783,7 +2784,10 @@ public: Opcode(Opcode), VFScaleFactor(ScaleFactor) { [[maybe_unused]] auto *AccumulatorRecipe = getChainOp()->getDefiningRecipe(); - assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) || + // When cloning as part of a VPExpressionRecipe, the chain op could have + // been replaced by a temporary VPValue, so it doesn't have a defining recipe. + assert((!AccumulatorRecipe || + isa<VPReductionPHIRecipe>(AccumulatorRecipe) || isa<VPPartialReductionRecipe>(AccumulatorRecipe)) && "Unexpected operand order for partial reduction recipe"); } @@ -3093,6 +3097,11 @@ public: /// removed before codegen. void decompose(); + unsigned getVFScaleFactor() const { + auto *PR = dyn_cast<VPPartialReductionRecipe>(ExpressionRecipes.back()); + return PR ? PR->getVFScaleFactor() : 1; + } + /// Method for generating code, must not be called as this recipe is abstract. void execute(VPTransformState &State) override { llvm_unreachable("recipe must be removed before execute"); } @@ -4163,11 +4172,6 @@ class VPlan { /// definitions are VPValues that hold a pointer to their underlying IR. SmallVector<VPValue *, 16> VPLiveIns; - /// Mapping from SCEVs to the VPValues representing their expansions. - /// NOTE: This mapping is temporary and will be removed once all users have - /// been modeled in VPlan directly. - DenseMap<const SCEV *, VPValue *> SCEVToExpansion; - /// Blocks allocated and owned by the VPlan. They will be deleted once the /// VPlan is destroyed. SmallVector<VPBlockBase *> CreatedBlocks; @@ -4415,15 +4419,6 @@ public: LLVM_DUMP_METHOD void dump() const; #endif - VPValue *getSCEVExpansion(const SCEV *S) const { - return SCEVToExpansion.lookup(S); - } - - void addSCEVExpansion(const SCEV *S, VPValue *V) { - assert(!SCEVToExpansion.contains(S) && "SCEV already expanded"); - SCEVToExpansion[S] = V; - } - /// Clone the current VPlan, update all VPValues of the new VPlan and cloned /// recipes to refer to the clones, and return it. VPlan *duplicate(); @@ -4438,22 +4433,24 @@ public: return VPB; } - /// Create a new VPRegionBlock with \p Entry, \p Exiting and \p Name. If \p - /// IsReplicator is true, the region is a replicate region.
The returned block - /// is owned by the VPlan and deleted once the VPlan is destroyed. - VPRegionBlock *createVPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting, - const std::string &Name = "", - bool IsReplicator = false) { - auto *VPB = new VPRegionBlock(Entry, Exiting, Name, IsReplicator); + /// Create a new loop region with \p Name and entry and exiting blocks set + /// to \p Entry and \p Exiting respectively, if set. The returned block is + /// owned by the VPlan and deleted once the VPlan is destroyed. + VPRegionBlock *createLoopRegion(const std::string &Name = "", + VPBlockBase *Entry = nullptr, + VPBlockBase *Exiting = nullptr) { + auto *VPB = Entry ? new VPRegionBlock(Entry, Exiting, Name) + : new VPRegionBlock(Name); CreatedBlocks.push_back(VPB); return VPB; } - /// Create a new loop VPRegionBlock with \p Name and entry and exiting blocks set - /// to nullptr. The returned block is owned by the VPlan and deleted once the - /// VPlan is destroyed. - VPRegionBlock *createVPRegionBlock(const std::string &Name = "") { - auto *VPB = new VPRegionBlock(Name); + /// Create a new replicate region with \p Entry, \p Exiting and \p Name. The + /// returned block is owned by the VPlan and deleted once the VPlan is + /// destroyed. + VPRegionBlock *createReplicateRegion(VPBlockBase *Entry, VPBlockBase *Exiting, + const std::string &Name = "") { + auto *VPB = new VPRegionBlock(Entry, Exiting, Name, true); CreatedBlocks.push_back(VPB); return VPB; } diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp index 332791a..65688a3 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp @@ -406,7 +406,7 @@ static void createLoopRegion(VPlan &Plan, VPBlockBase *HeaderVPB) { // LatchExitVPB, taking care to preserve the original predecessor & successor // order of blocks. Set region entry and exiting after both HeaderVPB and // LatchVPBB have been disconnected from their predecessors/successors. - auto *R = Plan.createVPRegionBlock(); + auto *R = Plan.createLoopRegion(); VPBlockUtils::insertOnEdge(LatchVPBB, LatchExitVPB, R); VPBlockUtils::disconnectBlocks(LatchVPBB, R); VPBlockUtils::connectBlocks(PreheaderVPBB, R); diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 1f1b42b..931a5b7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const { return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects(); case VPBlendSC: case VPReductionEVLSC: + case VPPartialReductionSC: case VPReductionSC: case VPScalarIVStepsSC: case VPVectorPointerSC: @@ -300,14 +301,23 @@ InstructionCost VPPartialReductionRecipe::computeCost(ElementCount VF, VPCostContext &Ctx) const { std::optional<unsigned> Opcode; - VPValue *Op = getOperand(0); - VPRecipeBase *OpR = Op->getDefiningRecipe(); - - // If the partial reduction is predicated, a select will be operand 0 - if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) { - OpR = Op->getDefiningRecipe(); + VPValue *Op = getVecOp(); + uint64_t MulConst; + // If the partial reduction is predicated, a select will be operand 1. + // If it isn't predicated and the mul isn't operating on a constant, then it + // should have been turned into a VPExpressionRecipe. 
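// Editorial aside, not part of the patch (hypothetical VPlan shapes): the
// condition below keeps the detailed costing only for two vector-operand
// forms,
//   VecOp = select(mask, mul(...), identity)   ; predicated partial reduction
//   VecOp = mul(x, C)                          ; multiply by a constant
// while any other shape takes the early getPartialReductionCost() return, on
// the expectation that it was already bundled into a VPExpressionRecipe.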
+ // FIXME: Replace the entire function with this once all partial reduction + // variants are bundled into VPExpressionRecipe. + if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) && + !match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) { + auto *PhiType = Ctx.Types.inferScalarType(getChainOp()); + auto *InputType = Ctx.Types.inferScalarType(getVecOp()); + return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType, + PhiType, VF, TTI::PR_None, + TTI::PR_None, {}, Ctx.CostKind); } + VPRecipeBase *OpR = Op->getDefiningRecipe(); Type *InputTypeA = nullptr, *InputTypeB = nullptr; TTI::PartialReductionExtendKind ExtAType = TTI::PR_None, ExtBType = TTI::PR_None; @@ -2856,11 +2866,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind()); switch (ExpressionType) { case ExpressionTypes::ExtendedReduction: { - return Ctx.TTI.getExtendedReductionCost( - Opcode, - cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == - Instruction::ZExt, - RedTy, SrcVecTy, std::nullopt, Ctx.CostKind); + unsigned Opcode = RecurrenceDescriptor::getOpcode( + cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind()); + auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + return isa<VPPartialReductionRecipe>(ExpressionRecipes.back()) + ? Ctx.TTI.getPartialReductionCost( + Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr, + RedTy, VF, + TargetTransformInfo::getPartialReductionExtendKind( + ExtR->getOpcode()), + TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind) + : Ctx.TTI.getExtendedReductionCost( + Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy, + SrcVecTy, std::nullopt, Ctx.CostKind); } case ExpressionTypes::MulAccReduction: return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy, @@ -2871,6 +2889,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, Opcode = Instruction::Sub; [[fallthrough]]; case ExpressionTypes::ExtMulAccReduction: { + if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) { + auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]); + auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]); + auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); + return Ctx.TTI.getPartialReductionCost( + Opcode, Ctx.Types.inferScalarType(getOperand(0)), + Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF, + TargetTransformInfo::getPartialReductionExtendKind( + Ext0R->getOpcode()), + TargetTransformInfo::getPartialReductionExtendKind( + Ext1R->getOpcode()), + Mul->getOpcode(), Ctx.CostKind); + } return Ctx.TTI.getMulAccReductionCost( cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() == Instruction::ZExt, @@ -2910,12 +2941,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, O << " = "; auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back()); unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); + bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); switch (ExpressionType) { case ExpressionTypes::ExtendedReduction: { getOperand(1)->printAsOperand(O, SlotTracker); - O << " +"; - O << " reduce." << Instruction::getOpcodeName(Opcode) << " ("; + O << " + " << (IsPartialReduction ? "partial." 
: "") << "reduce."; + O << Instruction::getOpcodeName(Opcode) << " ("; getOperand(0)->printAsOperand(O, SlotTracker); Red->printFlags(O); @@ -2931,8 +2963,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, } case ExpressionTypes::ExtNegatedMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); - O << " + reduce." - << Instruction::getOpcodeName( + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) << " (sub (0, mul"; auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]); @@ -2956,9 +2988,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, case ExpressionTypes::MulAccReduction: case ExpressionTypes::ExtMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); - O << " + "; - O << "reduce." - << Instruction::getOpcodeName( + O << " + " << (IsPartialReduction ? "partial." : "") << "reduce."; + O << Instruction::getOpcodeName( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) << " ("; O << "mul"; diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index e060e70..84817d7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -372,7 +372,7 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe, auto *Exiting = Plan.createVPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); VPRegionBlock *Region = - Plan.createVPRegionBlock(Entry, Exiting, RegionName, true); + Plan.createReplicateRegion(Entry, Exiting, RegionName); // Note: first set Entry as region entry and then connect successors starting // from it in order, to propagate the "parent" of each VPBasicBlock. @@ -943,12 +943,40 @@ static void recursivelyDeleteDeadRecipes(VPValue *V) { } } +/// Get any instruction opcode or intrinsic ID data embedded in recipe \p R. +/// Returns an optional pair, where the first element indicates whether it is +/// an intrinsic ID. +static std::optional<std::pair<bool, unsigned>> +getOpcodeOrIntrinsicID(const VPSingleDefRecipe *R) { + return TypeSwitch<const VPSingleDefRecipe *, + std::optional<std::pair<bool, unsigned>>>(R) + .Case<VPInstruction, VPWidenRecipe, VPWidenCastRecipe, + VPWidenSelectRecipe, VPWidenGEPRecipe, VPReplicateRecipe>( + [](auto *I) { return std::make_pair(false, I->getOpcode()); }) + .Case<VPWidenIntrinsicRecipe>([](auto *I) { + return std::make_pair(true, I->getVectorIntrinsicID()); + }) + .Case<VPVectorPointerRecipe, VPPredInstPHIRecipe>([](auto *I) { + // For recipes that do not directly map to LLVM IR instructions, + // assign opcodes after the last VPInstruction opcode (which is also + // after the last IR Instruction opcode), based on the VPDefID. + return std::make_pair(false, + VPInstruction::OpsEnd + 1 + I->getVPDefID()); + }) + .Default([](auto *) { return std::nullopt; }); +} + /// Try to fold \p R using InstSimplifyFolder. Will succeed and return a -/// non-nullptr Value for a handled \p Opcode if corresponding \p Operands are -/// foldable live-ins. -static Value *tryToFoldLiveIns(const VPRecipeBase &R, unsigned Opcode, - ArrayRef<VPValue *> Operands, - const DataLayout &DL, VPTypeAnalysis &TypeInfo) { +/// non-nullptr VPValue for a handled opcode or intrinsic ID if corresponding \p +/// Operands are foldable live-ins. 
+static VPValue *tryToFoldLiveIns(VPSingleDefRecipe &R, + ArrayRef<VPValue *> Operands, + const DataLayout &DL, + VPTypeAnalysis &TypeInfo) { + auto OpcodeOrIID = getOpcodeOrIntrinsicID(&R); + if (!OpcodeOrIID) + return nullptr; + SmallVector<Value *, 4> Ops; for (VPValue *Op : Operands) { if (!Op->isLiveIn() || !Op->getLiveInIRValue()) @@ -956,43 +984,57 @@ static Value *tryToFoldLiveIns(const VPRecipeBase &R, unsigned Opcode, Ops.push_back(Op->getLiveInIRValue()); } - InstSimplifyFolder Folder(DL); - if (Instruction::isBinaryOp(Opcode)) - return Folder.FoldBinOp(static_cast<Instruction::BinaryOps>(Opcode), Ops[0], + auto FoldToIRValue = [&]() -> Value * { + InstSimplifyFolder Folder(DL); + if (OpcodeOrIID->first) { + if (R.getNumOperands() != 2) + return nullptr; + unsigned ID = OpcodeOrIID->second; + return Folder.FoldBinaryIntrinsic(ID, Ops[0], Ops[1], + TypeInfo.inferScalarType(&R)); + } + unsigned Opcode = OpcodeOrIID->second; + if (Instruction::isBinaryOp(Opcode)) + return Folder.FoldBinOp(static_cast<Instruction::BinaryOps>(Opcode), + Ops[0], Ops[1]); + if (Instruction::isCast(Opcode)) + return Folder.FoldCast(static_cast<Instruction::CastOps>(Opcode), Ops[0], + TypeInfo.inferScalarType(R.getVPSingleValue())); + switch (Opcode) { + case VPInstruction::LogicalAnd: + return Folder.FoldSelect(Ops[0], Ops[1], + ConstantInt::getNullValue(Ops[1]->getType())); + case VPInstruction::Not: + return Folder.FoldBinOp(Instruction::BinaryOps::Xor, Ops[0], + Constant::getAllOnesValue(Ops[0]->getType())); + case Instruction::Select: + return Folder.FoldSelect(Ops[0], Ops[1], Ops[2]); + case Instruction::ICmp: + case Instruction::FCmp: + return Folder.FoldCmp(cast<VPRecipeWithIRFlags>(R).getPredicate(), Ops[0], Ops[1]); - if (Instruction::isCast(Opcode)) - return Folder.FoldCast(static_cast<Instruction::CastOps>(Opcode), Ops[0], - TypeInfo.inferScalarType(R.getVPSingleValue())); - switch (Opcode) { - case VPInstruction::LogicalAnd: - return Folder.FoldSelect(Ops[0], Ops[1], - ConstantInt::getNullValue(Ops[1]->getType())); - case VPInstruction::Not: - return Folder.FoldBinOp(Instruction::BinaryOps::Xor, Ops[0], - Constant::getAllOnesValue(Ops[0]->getType())); - case Instruction::Select: - return Folder.FoldSelect(Ops[0], Ops[1], Ops[2]); - case Instruction::ICmp: - case Instruction::FCmp: - return Folder.FoldCmp(cast<VPRecipeWithIRFlags>(R).getPredicate(), Ops[0], - Ops[1]); - case Instruction::GetElementPtr: { - auto &RFlags = cast<VPRecipeWithIRFlags>(R); - auto *GEP = cast<GetElementPtrInst>(RFlags.getUnderlyingInstr()); - return Folder.FoldGEP(GEP->getSourceElementType(), Ops[0], drop_begin(Ops), - RFlags.getGEPNoWrapFlags()); - } - case VPInstruction::PtrAdd: - case VPInstruction::WidePtrAdd: - return Folder.FoldGEP(IntegerType::getInt8Ty(TypeInfo.getContext()), Ops[0], - Ops[1], - cast<VPRecipeWithIRFlags>(R).getGEPNoWrapFlags()); - // An extract of a live-in is an extract of a broadcast, so return the - // broadcasted element. 
- case Instruction::ExtractElement: - assert(!Ops[0]->getType()->isVectorTy() && "Live-ins should be scalar"); - return Ops[0]; - } + case Instruction::GetElementPtr: { + auto &RFlags = cast<VPRecipeWithIRFlags>(R); + auto *GEP = cast<GetElementPtrInst>(RFlags.getUnderlyingInstr()); + return Folder.FoldGEP(GEP->getSourceElementType(), Ops[0], + drop_begin(Ops), RFlags.getGEPNoWrapFlags()); + } + case VPInstruction::PtrAdd: + case VPInstruction::WidePtrAdd: + return Folder.FoldGEP(IntegerType::getInt8Ty(TypeInfo.getContext()), + Ops[0], Ops[1], + cast<VPRecipeWithIRFlags>(R).getGEPNoWrapFlags()); + // An extract of a live-in is an extract of a broadcast, so return the + // broadcasted element. + case Instruction::ExtractElement: + assert(!Ops[0]->getType()->isVectorTy() && "Live-ins should be scalar"); + return Ops[0]; + } + return nullptr; + }; + + if (Value *V = FoldToIRValue()) + return R.getParent()->getPlan()->getOrAddLiveIn(V); return nullptr; } @@ -1006,19 +1048,10 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) { // Simplification of live-in IR values for SingleDef recipes using // InstSimplifyFolder. - if (TypeSwitch<VPRecipeBase *, bool>(&R) - .Case<VPInstruction, VPWidenRecipe, VPWidenCastRecipe, - VPReplicateRecipe, VPWidenSelectRecipe>([&](auto *I) { - const DataLayout &DL = - Plan->getScalarHeader()->getIRBasicBlock()->getDataLayout(); - Value *V = tryToFoldLiveIns(*I, I->getOpcode(), I->operands(), DL, - TypeInfo); - if (V) - I->replaceAllUsesWith(Plan->getOrAddLiveIn(V)); - return V; - }) - .Default([](auto *) { return false; })) - return; + const DataLayout &DL = + Plan->getScalarHeader()->getIRBasicBlock()->getDataLayout(); + if (VPValue *V = tryToFoldLiveIns(*Def, Def->operands(), DL, TypeInfo)) + return Def->replaceAllUsesWith(V); // Fold PredPHI LiveIn -> LiveIn. if (auto *PredPHI = dyn_cast<VPPredInstPHIRecipe>(&R)) { @@ -1478,11 +1511,8 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, if (!Plan.getVectorLoopRegion()) return false; - if (!Plan.getTripCount()->isLiveIn()) - return false; - auto *TC = dyn_cast_if_present<ConstantInt>( - Plan.getTripCount()->getUnderlyingValue()); - if (!TC || !BestVF.isFixed()) + const APInt *TC; + if (!BestVF.isFixed() || !match(Plan.getTripCount(), m_APInt(TC))) return false; // Calculate the minimum power-of-2 bit width that can fit the known TC, VF @@ -1495,7 +1525,7 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan, return std::max<unsigned>(PowerOf2Ceil(MaxVal.getActiveBits()), 8); }; unsigned NewBitWidth = - ComputeBitWidth(TC->getValue(), BestVF.getKnownMinValue() * BestUF); + ComputeBitWidth(*TC, BestVF.getKnownMinValue() * BestUF); LLVMContext &Ctx = Plan.getContext(); auto *NewIVTy = IntegerType::get(Ctx, NewBitWidth); @@ -1999,29 +2029,6 @@ struct VPCSEDenseMapInfo : public DenseMapInfo<VPSingleDefRecipe *> { return Def == getEmptyKey() || Def == getTombstoneKey(); } - /// Get any instruction opcode or intrinsic ID data embedded in recipe \p R. - /// Returns an optional pair, where the first element indicates whether it is - /// an intrinsic ID. 
- static std::optional<std::pair<bool, unsigned>> - getOpcodeOrIntrinsicID(const VPSingleDefRecipe *R) { - return TypeSwitch<const VPSingleDefRecipe *, - std::optional<std::pair<bool, unsigned>>>(R) - .Case<VPInstruction, VPWidenRecipe, VPWidenCastRecipe, - VPWidenSelectRecipe, VPWidenGEPRecipe, VPReplicateRecipe>( - [](auto *I) { return std::make_pair(false, I->getOpcode()); }) - .Case<VPWidenIntrinsicRecipe>([](auto *I) { - return std::make_pair(true, I->getVectorIntrinsicID()); - }) - .Case<VPVectorPointerRecipe, VPPredInstPHIRecipe>([](auto *I) { - // For recipes that do not directly map to LLVM IR instructions, - // assign opcodes after the last VPInstruction opcode (which is also - // after the last IR Instruction opcode), based on the VPDefID. - return std::make_pair(false, - VPInstruction::OpsEnd + 1 + I->getVPDefID()); - }) - .Default([](auto *) { return std::nullopt; }); - } - /// If recipe \p R will lower to a GEP with a non-i8 source element type, /// return that source element type. static Type *getGEPSourceElementType(const VPSingleDefRecipe *R) { @@ -2092,8 +2099,8 @@ struct VPCSEDenseMapInfo : public DenseMapInfo<VPSingleDefRecipe *> { // Recipes in replicate regions implicitly depend on predicate. If either // recipe is in a replicate region, only consider them equal if both have // the same parent. - const VPRegionBlock *RegionL = L->getParent()->getParent(); - const VPRegionBlock *RegionR = R->getParent()->getParent(); + const VPRegionBlock *RegionL = L->getRegion(); + const VPRegionBlock *RegionR = R->getRegion(); if (((RegionL && RegionL->isReplicator()) || (RegionR && RegionR->isReplicator())) && L->getParent() != R->getParent()) @@ -3522,18 +3529,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VPValue *VecOp = Red->getVecOp(); // Clamp the range if using extended-reduction is profitable. - auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt, - Type *SrcTy) -> bool { + auto IsExtendedRedValidAndClampRange = + [&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; - InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost( - Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags(), - CostKind); + + InstructionCost ExtRedCost; InstructionCost ExtCost = cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); + + if (isa<VPPartialReductionRecipe>(Red)) { + TargetTransformInfo::PartialReductionExtendKind ExtKind = + TargetTransformInfo::getPartialReductionExtendKind(ExtOpc); + // FIXME: Move partial reduction creation, costing and clamping + // here from LoopVectorize.cpp. 
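            // Editorial aside, not part of the patch: a partial reduction
            // accumulates into fewer lanes than the input VF, e.g.
            //   acc(<4 x i32>) += sext(<16 x i8> %a to <16 x i32>)
            // folds 16 input lanes into a 4-lane accumulator (VF scale
            // factor 4), hence the dedicated getPartialReductionCost() hook
            // used here instead of getExtendedReductionCost().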
+ ExtRedCost = Ctx.TTI.getPartialReductionCost( + Opcode, SrcTy, nullptr, RedTy, VF, ExtKind, + llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind); + } else { + ExtRedCost = Ctx.TTI.getExtendedReductionCost( + Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy, + Red->getFastMathFlags(), CostKind); + } return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost; }, Range); @@ -3544,8 +3564,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) && IsExtendedRedValidAndClampRange( RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()), - cast<VPWidenCastRecipe>(VecOp)->getOpcode() == - Instruction::CastOps::ZExt, + cast<VPWidenCastRecipe>(VecOp)->getOpcode(), Ctx.Types.inferScalarType(A))) return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red); @@ -3563,6 +3582,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx, static VPExpressionRecipe * tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, VPCostContext &Ctx, VFRange &Range) { + bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red); + unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()); if (Opcode != Instruction::Add && Opcode != Instruction::Sub) return nullptr; @@ -3571,16 +3592,41 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, // Clamp the range if using multiply-accumulate-reduction is profitable. auto IsMulAccValidAndClampRange = - [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, - VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool { + [&](VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1, + VPWidenCastRecipe *OuterExt) -> bool { return LoopVectorizationPlanner::getDecisionAndClampRange( [&](ElementCount VF) { TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; Type *SrcTy = Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy; - auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); - InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost( - isZExt, Opcode, RedTy, SrcVecTy, CostKind); + InstructionCost MulAccCost; + + if (IsPartialReduction) { + Type *SrcTy2 = + Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr; + // FIXME: Move partial reduction creation, costing and clamping + // here from LoopVectorize.cpp. + MulAccCost = Ctx.TTI.getPartialReductionCost( + Opcode, SrcTy, SrcTy2, RedTy, VF, + Ext0 ? TargetTransformInfo::getPartialReductionExtendKind( + Ext0->getOpcode()) + : TargetTransformInfo::PR_None, + Ext1 ? TargetTransformInfo::getPartialReductionExtendKind( + Ext1->getOpcode()) + : TargetTransformInfo::PR_None, + Mul->getOpcode(), CostKind); + } else { + // Only partial reductions support mixed extends at the moment. + if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode()) + return false; + + bool IsZExt = + !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt; + auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF)); + MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy, + SrcVecTy, CostKind); + } + InstructionCost MulCost = Mul->computeCost(VF, Ctx); InstructionCost RedCost = Red->computeCost(VF, Ctx); InstructionCost ExtCost = 0; @@ -3614,14 +3660,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe()); auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe()); - // Match reduce.add(mul(ext, ext)). 
- if (RecipeA && RecipeB && - (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) && - match(RecipeA, m_ZExtOrSExt(m_VPValue())) && + // Match reduce.add/sub(mul(ext, ext)). + if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) && match(RecipeB, m_ZExtOrSExt(m_VPValue())) && - IsMulAccValidAndClampRange(RecipeA->getOpcode() == - Instruction::CastOps::ZExt, - Mul, RecipeA, RecipeB, nullptr)) { + IsMulAccValidAndClampRange(Mul, RecipeA, RecipeB, nullptr)) { if (Sub) return new VPExpressionRecipe(RecipeA, RecipeB, Mul, cast<VPWidenRecipe>(Sub), Red); @@ -3629,8 +3671,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, } // Match reduce.add(mul). // TODO: Add an expression type for this variant with a negated mul - if (!Sub && - IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) + if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr)) return new VPExpressionRecipe(Mul, Red); } // TODO: Add an expression type for negated versions of other expression @@ -3650,9 +3691,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe()); if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) && Ext0->getOpcode() == Ext1->getOpcode() && - IsMulAccValidAndClampRange(Ext0->getOpcode() == - Instruction::CastOps::ZExt, - Mul, Ext0, Ext1, Ext)) { + IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) { auto *NewExt0 = new VPWidenCastRecipe( Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0, *Ext0, Ext0->getDebugLoc()); @@ -3867,8 +3906,7 @@ void VPlanTransforms::materializePacksAndUnpacks(VPlan &Plan) { // required lanes implicitly. // TODO: Remove once replicate regions are unrolled completely. 
auto IsCandidateUnpackUser = [Def](VPUser *U) { - VPRegionBlock *ParentRegion = - cast<VPRecipeBase>(U)->getParent()->getParent(); + VPRegionBlock *ParentRegion = cast<VPRecipeBase>(U)->getRegion(); return U->usesScalars(Def) && (!ParentRegion || !ParentRegion->isReplicator()); }; @@ -4053,7 +4091,7 @@ static bool canNarrowLoad(VPWidenRecipe *WideMember0, unsigned OpIdx, static bool isConsecutiveInterleaveGroup(VPInterleaveRecipe *InterleaveR, unsigned VF, VPTypeAnalysis &TypeInfo, unsigned VectorRegWidth) { - if (!InterleaveR) + if (!InterleaveR || InterleaveR->getMask()) return false; Type *GroupElementTy = nullptr; @@ -4091,7 +4129,7 @@ static bool isAlreadyNarrow(VPValue *VPV) { void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF, unsigned VectorRegWidth) { VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion(); - if (!VectorLoop) + if (!VectorLoop || VectorLoop->getEntry()->getNumSuccessors() != 0) return; VPTypeAnalysis TypeInfo(Plan); diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp index 10801c0..fe66f13 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp @@ -8,6 +8,7 @@ #include "VPlanUtils.h" #include "VPlanCFG.h" +#include "VPlanDominatorTree.h" #include "VPlanPatternMatch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -31,8 +32,6 @@ bool vputils::onlyScalarValuesUsed(const VPValue *Def) { } VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr) { - if (auto *Expanded = Plan.getSCEVExpansion(Expr)) - return Expanded; VPValue *Expanded = nullptr; if (auto *E = dyn_cast<SCEVConstant>(Expr)) Expanded = Plan.getOrAddLiveIn(E->getValue()); @@ -49,7 +48,6 @@ VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr) { Plan.getEntry()->appendRecipe(Expanded->getDefiningRecipe()); } } - Plan.addSCEVExpansion(Expr, Expanded); return Expanded; } @@ -150,6 +148,8 @@ unsigned vputils::getVFScaleFactor(VPRecipeBase *R) { return RR->getVFScaleFactor(); if (auto *RR = dyn_cast<VPPartialReductionRecipe>(R)) return RR->getVFScaleFactor(); + if (auto *ER = dyn_cast<VPExpressionRecipe>(R)) + return ER->getVFScaleFactor(); assert( (!isa<VPInstruction>(R) || cast<VPInstruction>(R)->getOpcode() != VPInstruction::ReductionStartVector) && @@ -253,3 +253,29 @@ vputils::getRecipesForUncountableExit(VPlan &Plan, return UncountableCondition; } + +bool VPBlockUtils::isHeader(const VPBlockBase *VPB, + const VPDominatorTree &VPDT) { + auto *VPBB = dyn_cast<VPBasicBlock>(VPB); + if (!VPBB) + return false; + + // If VPBB is in a region R, VPBB is a loop header if R is a loop region with + // VPBB as its entry, i.e., free of predecessors. + if (auto *R = VPBB->getParent()) + return !R->isReplicator() && !VPBB->hasPredecessors(); + + // A header dominates its second predecessor (the latch), with the other + // predecessor being the preheader + return VPB->getPredecessors().size() == 2 && + VPDT.dominates(VPB, VPB->getPredecessors()[1]); +} + +bool VPBlockUtils::isLatch(const VPBlockBase *VPB, + const VPDominatorTree &VPDT) { + // A latch has a header as its second successor, with its other successor + // leaving the loop. A preheader OTOH has a header as its first (and only) + // successor. 
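  // Editorial sketch, not part of the patch (assumed canonical loop shape):
  //   preheader -> header -> ... -> latch -> {exit, header}
  // the header's two predecessors are [preheader, latch] and the latch's
  // second successor is the header, which is what isHeader/isLatch encode.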
+ return VPB->getNumSuccessors() == 2 && + VPBlockUtils::isHeader(VPB->getSuccessors()[1], VPDT); +} diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h index 0678bc90..83e3fca 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanValue.h +++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h @@ -41,10 +41,10 @@ class VPRecipeBase; class VPInterleaveBase; class VPPhiAccessors; -// This is the base class of the VPlan Def/Use graph, used for modeling the data -// flow into, within and out of the VPlan. VPValues can stand for live-ins -// coming from the input IR and instructions which VPlan will generate if -// executed. +/// This is the base class of the VPlan Def/Use graph, used for modeling the +/// data flow into, within and out of the VPlan. VPValues can stand for live-ins +/// coming from the input IR and instructions which VPlan will generate if +/// executed. class LLVM_ABI_FOR_TEST VPValue { friend class VPDef; friend struct VPDoubleValueDef; @@ -57,7 +57,7 @@ class LLVM_ABI_FOR_TEST VPValue { SmallVector<VPUser *, 1> Users; protected: - // Hold the underlying Value, if any, attached to this VPValue. + /// Hold the underlying Value, if any, attached to this VPValue. Value *UnderlyingVal; /// Pointer to the VPDef that defines this VPValue. If it is nullptr, the |
