Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--  llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp | 12
-rw-r--r--  llvm/lib/Transforms/IPO/FunctionImport.cpp | 12
-rw-r--r--  llvm/lib/Transforms/IPO/LowerTypeTests.cpp | 6
-rw-r--r--  llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 84
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 54
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 30
-rw-r--r--  llvm/lib/Transforms/Instrumentation/AllocToken.cpp | 148
-rw-r--r--  llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp | 1
-rw-r--r--  llvm/lib/Transforms/Utils/BasicBlockUtils.cpp | 10
-rw-r--r--  llvm/lib/Transforms/Utils/PredicateInfo.cpp | 13
-rw-r--r--  llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 19
-rw-r--r--  llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 117
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlan.cpp | 35
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlan.h | 38
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp | 2
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 136
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanTransforms.h | 22
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanUtils.cpp | 27
-rw-r--r--  llvm/lib/Transforms/Vectorize/VPlanValue.h | 10
19 files changed, 483 insertions, 293 deletions
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index bbbac45..7a95df4 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -907,10 +907,20 @@ static bool mergeConsecutivePartStores(ArrayRef<PartStore> Parts,
StoreInst *Store = Builder.CreateAlignedStore(
Val, First.Store->getPointerOperand(), First.Store->getAlign());
+ // Merge various metadata onto the new store.
AAMDNodes AATags = First.Store->getAAMetadata();
- for (const PartStore &Part : drop_begin(Parts))
+ SmallVector<Instruction *> Stores = {First.Store};
+ Stores.reserve(Parts.size());
+ SmallVector<DebugLoc> DbgLocs = {First.Store->getDebugLoc()};
+ DbgLocs.reserve(Parts.size());
+ for (const PartStore &Part : drop_begin(Parts)) {
AATags = AATags.concat(Part.Store->getAAMetadata());
+ Stores.push_back(Part.Store);
+ DbgLocs.push_back(Part.Store->getDebugLoc());
+ }
Store->setAAMetadata(AATags);
+ Store->mergeDIAssignID(Stores);
+ Store->setDebugLoc(DebugLoc::getMergedLocations(DbgLocs));
// Remove the old stores.
for (const PartStore &Part : Parts)
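For illustration only, a standalone sketch (plain C++, names hypothetical) of what merging consecutive part stores means at the source level; the hunk above only changes which metadata the merged store inherits from the originals (concatenated AA tags, merged DIAssignID, and a merged debug location).

    #include <cstdint>
    #include <cstring>

    // Two adjacent 32-bit stores ...
    void storeParts(uint8_t *P, uint32_t Lo, uint32_t Hi) {
      std::memcpy(P, &Lo, sizeof(Lo));     // first part store
      std::memcpy(P + 4, &Hi, sizeof(Hi)); // second, consecutive part store
    }

    // ... become a single 64-bit store (little-endian layout assumed).
    void storeMerged(uint8_t *P, uint32_t Lo, uint32_t Hi) {
      uint64_t V = (uint64_t(Hi) << 32) | Lo;
      std::memcpy(P, &V, sizeof(V)); // one wide store carrying the merged metadata
    }
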
diff --git a/llvm/lib/Transforms/IPO/FunctionImport.cpp b/llvm/lib/Transforms/IPO/FunctionImport.cpp
index 28ee444..a29faab 100644
--- a/llvm/lib/Transforms/IPO/FunctionImport.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionImport.cpp
@@ -1368,13 +1368,13 @@ static void ComputeCrossModuleImportForModuleFromIndexForTest(
FunctionImporter::ImportMapTy &ImportList) {
for (const auto &GlobalList : Index) {
// Ignore entries for undefined references.
- if (GlobalList.second.SummaryList.empty())
+ if (GlobalList.second.getSummaryList().empty())
continue;
auto GUID = GlobalList.first;
- assert(GlobalList.second.SummaryList.size() == 1 &&
+ assert(GlobalList.second.getSummaryList().size() == 1 &&
"Expected individual combined index to have one summary per GUID");
- auto &Summary = GlobalList.second.SummaryList[0];
+ auto &Summary = GlobalList.second.getSummaryList()[0];
// Skip the summaries for the importing module. These are included to
// e.g. record required linkage changes.
if (Summary->modulePath() == ModulePath)
@@ -1423,7 +1423,7 @@ void updateValueInfoForIndirectCalls(ModuleSummaryIndex &Index,
void llvm::updateIndirectCalls(ModuleSummaryIndex &Index) {
for (const auto &Entry : Index) {
- for (const auto &S : Entry.second.SummaryList) {
+ for (const auto &S : Entry.second.getSummaryList()) {
if (auto *FS = dyn_cast<FunctionSummary>(S.get()))
updateValueInfoForIndirectCalls(Index, FS);
}
@@ -1456,7 +1456,7 @@ void llvm::computeDeadSymbolsAndUpdateIndirectCalls(
// Add values flagged in the index as live roots to the worklist.
for (const auto &Entry : Index) {
auto VI = Index.getValueInfo(Entry);
- for (const auto &S : Entry.second.SummaryList) {
+ for (const auto &S : Entry.second.getSummaryList()) {
if (auto *FS = dyn_cast<FunctionSummary>(S.get()))
updateValueInfoForIndirectCalls(Index, FS);
if (S->isLive()) {
@@ -2094,7 +2094,7 @@ static bool doImportingForModuleForTest(
// is only enabled when testing importing via the 'opt' tool, which does
// not do the ThinLink that would normally determine what values to promote.
for (auto &I : *Index) {
- for (auto &S : I.second.SummaryList) {
+ for (auto &S : I.second.getSummaryList()) {
if (GlobalValue::isLocalLinkage(S->linkage()))
S->setLinkage(GlobalValue::ExternalLinkage);
}
diff --git a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
index be6cba3..aa1346d 100644
--- a/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
+++ b/llvm/lib/Transforms/IPO/LowerTypeTests.cpp
@@ -1271,7 +1271,7 @@ bool LowerTypeTestsModule::hasBranchTargetEnforcement() {
// the module flags.
if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
M.getModuleFlag("branch-target-enforcement")))
- HasBranchTargetEnforcement = (BTE->getZExtValue() != 0);
+ HasBranchTargetEnforcement = !BTE->isZero();
else
HasBranchTargetEnforcement = 0;
}
@@ -2130,7 +2130,7 @@ bool LowerTypeTestsModule::lower() {
// A set of all functions that are address taken by a live global object.
DenseSet<GlobalValue::GUID> AddressTaken;
for (auto &I : *ExportSummary)
- for (auto &GVS : I.second.SummaryList)
+ for (auto &GVS : I.second.getSummaryList())
if (GVS->isLive())
for (const auto &Ref : GVS->refs()) {
AddressTaken.insert(Ref.getGUID());
@@ -2409,7 +2409,7 @@ bool LowerTypeTestsModule::lower() {
}
for (auto &P : *ExportSummary) {
- for (auto &S : P.second.SummaryList) {
+ for (auto &S : P.second.getSummaryList()) {
if (!ExportSummary->isGlobalValueLive(S.get()))
continue;
if (auto *FS = dyn_cast<FunctionSummary>(S->getBaseObject()))
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 2d5cb82..a0f7ec6 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -24,7 +24,8 @@
// returns 0, or a single vtable's function returns 1, replace each virtual
// call with a comparison of the vptr against that vtable's address.
//
-// This pass is intended to be used during the regular and thin LTO pipelines:
+// This pass is intended to be used during the regular LTO, ThinLTO and
+// non-LTO pipelines:
//
// During regular LTO, the pass determines the best optimization for each
// virtual call and applies the resolutions directly to virtual calls that are
@@ -48,6 +49,14 @@
// is supported.
// - Import phase: (same as with hybrid case above).
//
+// During speculative devirtualization mode (not restricted to LTO):
+// - The pass applies speculative devirtualization without requiring any type
+//   of visibility.
+// - Skips other features like virtual constant propagation, uniform return
+//   value optimization, unique return value optimization and branch funnels,
+//   as they need LTO.
+// - This mode is enabled via the 'devirtualize-speculatively' flag.
+//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
@@ -61,7 +70,9 @@
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/ModuleSummaryAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
@@ -145,6 +156,13 @@ static cl::opt<std::string> ClWriteSummary(
"bitcode, otherwise YAML"),
cl::Hidden);
+// TODO: This option should eventually support any public-visibility vtables
+// with or without LTO.
+static cl::opt<bool> ClDevirtualizeSpeculatively(
+ "devirtualize-speculatively",
+ cl::desc("Enable speculative devirtualization optimization"),
+ cl::init(false));
+
static cl::opt<unsigned>
ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
cl::init(10),
@@ -892,6 +910,8 @@ void llvm::updatePublicTypeTestCalls(Module &M,
CI->eraseFromParent();
}
} else {
+ // TODO: Don't replace public type tests when speculative devirtualization
+ // gets enabled in LTO mode.
auto *True = ConstantInt::getTrue(M.getContext());
for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) {
auto *CI = cast<CallInst>(U.getUser());
@@ -928,7 +948,7 @@ void llvm::updateVCallVisibilityInIndex(
// linker, as we have no information on their eventual use.
if (DynamicExportSymbols.count(P.first))
continue;
- for (auto &S : P.second.SummaryList) {
+ for (auto &S : P.second.getSummaryList()) {
auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
if (!GVar ||
GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
@@ -1083,10 +1103,10 @@ bool DevirtModule::tryFindVirtualCallTargets(
if (!TM.Bits->GV->isConstant())
return false;
- // We cannot perform whole program devirtualization analysis on a vtable
- // with public LTO visibility.
- if (TM.Bits->GV->getVCallVisibility() ==
- GlobalObject::VCallVisibilityPublic)
+ // Without ClDevirtualizeSpeculatively, we cannot perform whole program
+ // devirtualization analysis on a vtable with public LTO visibility.
+ if (!ClDevirtualizeSpeculatively && TM.Bits->GV->getVCallVisibility() ==
+ GlobalObject::VCallVisibilityPublic)
return false;
Function *Fn = nullptr;
@@ -1105,6 +1125,12 @@ bool DevirtModule::tryFindVirtualCallTargets(
if (Fn->getName() == "__cxa_pure_virtual")
continue;
+ // In most cases empty functions will be overridden by the
+ // implementation of the derived class, so we can skip them.
+ if (ClDevirtualizeSpeculatively && Fn->getReturnType()->isVoidTy() &&
+ Fn->getInstructionCount() <= 1)
+ continue;
+
// We can disregard unreachable functions as possible call targets, as
// unreachable functions shouldn't be called.
if (mustBeUnreachableFunction(Fn, ExportSummary))
@@ -1223,10 +1249,12 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
CallTrap->setDebugLoc(CB.getDebugLoc());
}
- // If fallback checking is enabled, add support to compare the virtual
- // function pointer to the devirtualized target. In case of a mismatch,
- // fall back to indirect call.
- if (DevirtCheckMode == WPDCheckMode::Fallback) {
+ // If fallback checking or speculative devirtualization are enabled,
+ // add support to compare the virtual function pointer to the
+ // devirtualized target. In case of a mismatch, fall back to indirect
+ // call.
+ if (DevirtCheckMode == WPDCheckMode::Fallback ||
+ ClDevirtualizeSpeculatively) {
MDNode *Weights = MDBuilder(M.getContext()).createLikelyBranchWeights();
// Version the indirect call site. If the called value is equal to the
// given callee, 'NewInst' will be executed, otherwise the original call
@@ -2057,15 +2085,15 @@ void DevirtModule::scanTypeTestUsers(
Function *TypeTestFunc,
DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
// Find all virtual calls via a virtual table pointer %p under an assumption
- // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
- // points to a member of the type identifier %md. Group calls by (type ID,
- // offset) pair (effectively the identity of the virtual function) and store
- // to CallSlots.
+ // of the form llvm.assume(llvm.type.test(%p, %md)) or
+ // llvm.assume(llvm.public.type.test(%p, %md)).
+ // This indicates that %p points to a member of the type identifier %md.
+ // Group calls by (type ID, offset) pair (effectively the identity of the
+ // virtual function) and store to CallSlots.
for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) {
auto *CI = dyn_cast<CallInst>(U.getUser());
if (!CI)
continue;
-
// Search for virtual calls based on %p and add them to DevirtCalls.
SmallVector<DevirtCallSite, 1> DevirtCalls;
SmallVector<CallInst *, 1> Assumes;
@@ -2348,6 +2376,12 @@ bool DevirtModule::run() {
(ImportSummary && ImportSummary->partiallySplitLTOUnits()))
return false;
+ Function *PublicTypeTestFunc = nullptr;
+ // If we are in speculative devirtualization mode, we can work on the public
+ // type test intrinsics.
+ if (ClDevirtualizeSpeculatively)
+ PublicTypeTestFunc =
+ Intrinsic::getDeclarationIfExists(&M, Intrinsic::public_type_test);
Function *TypeTestFunc =
Intrinsic::getDeclarationIfExists(&M, Intrinsic::type_test);
Function *TypeCheckedLoadFunc =
@@ -2361,8 +2395,9 @@ bool DevirtModule::run() {
// module, this pass has nothing to do. But if we are exporting, we also need
// to handle any users that appear only in the function summaries.
if (!ExportSummary &&
- (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
- AssumeFunc->use_empty()) &&
+ (((!PublicTypeTestFunc || PublicTypeTestFunc->use_empty()) &&
+ (!TypeTestFunc || TypeTestFunc->use_empty())) ||
+ !AssumeFunc || AssumeFunc->use_empty()) &&
(!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) &&
(!TypeCheckedLoadRelativeFunc ||
TypeCheckedLoadRelativeFunc->use_empty()))
@@ -2373,6 +2408,9 @@ bool DevirtModule::run() {
DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
buildTypeIdentifierMap(Bits, TypeIdMap);
+ if (PublicTypeTestFunc && AssumeFunc)
+ scanTypeTestUsers(PublicTypeTestFunc, TypeIdMap);
+
if (TypeTestFunc && AssumeFunc)
scanTypeTestUsers(TypeTestFunc, TypeIdMap);
@@ -2413,7 +2451,7 @@ bool DevirtModule::run() {
}
for (auto &P : *ExportSummary) {
- for (auto &S : P.second.SummaryList) {
+ for (auto &S : P.second.getSummaryList()) {
auto *FS = dyn_cast<FunctionSummary>(S.get());
if (!FS)
continue;
@@ -2472,8 +2510,12 @@ bool DevirtModule::run() {
.WPDRes[S.first.ByteOffset];
if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
S.first.ByteOffset, ExportSummary)) {
-
- if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
+ bool SingleImplDevirt =
+ trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res);
+      // Outside of speculative devirtualization mode, try to apply virtual
+      // constant propagation or branch funneling.
+ // TODO: This should eventually be enabled for non-public type tests.
+ if (!SingleImplDevirt && !ClDevirtualizeSpeculatively) {
DidVirtualConstProp |=
tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
@@ -2564,7 +2606,7 @@ void DevirtIndex::run() {
// Collect information from summary about which calls to try to devirtualize.
for (auto &P : ExportSummary) {
- for (auto &S : P.second.SummaryList) {
+ for (auto &S : P.second.getSummaryList()) {
auto *FS = dyn_cast<FunctionSummary>(S.get());
if (!FS)
continue;
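As a rough illustration of the guarded call described above (hypothetical types, not taken from the patch): speculative devirtualization compares the virtual function pointer against the single likely target and only then calls it directly; otherwise it falls back to the indirect call, with the direct path weighted as likely.

    #include <typeinfo>

    struct Base {
      virtual ~Base() = default;
      virtual void run() {}
    };
    struct Likely : Base {
      void run() override {}
    };

    // The pass compares the vptr itself; typeid is the closest portable C++
    // stand-in for that exact-type check.
    void dispatch(Base *B) {
      if (typeid(*B) == typeid(Likely))
        static_cast<Likely *>(B)->Likely::run(); // direct, inlinable call
      else
        B->run();                                // fallback: indirect virtual call
    }
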
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index cdc559b..9b9fe26 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1643,33 +1643,46 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
/// Return a Constant* for the specified floating-point constant if it fits
/// in the specified FP type without changing its value.
-static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
+static bool fitsInFPType(APFloat F, const fltSemantics &Sem) {
bool losesInfo;
- APFloat F = CFP->getValueAPF();
(void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
return !losesInfo;
}
-static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
- if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
- return nullptr; // No constant folding of this.
+static Type *shrinkFPConstant(LLVMContext &Ctx, const APFloat &F,
+ bool PreferBFloat) {
// See if the value can be truncated to bfloat and then reextended.
- if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
- return Type::getBFloatTy(CFP->getContext());
+ if (PreferBFloat && fitsInFPType(F, APFloat::BFloat()))
+ return Type::getBFloatTy(Ctx);
// See if the value can be truncated to half and then reextended.
- if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
- return Type::getHalfTy(CFP->getContext());
+ if (!PreferBFloat && fitsInFPType(F, APFloat::IEEEhalf()))
+ return Type::getHalfTy(Ctx);
// See if the value can be truncated to float and then reextended.
- if (fitsInFPType(CFP, APFloat::IEEEsingle()))
- return Type::getFloatTy(CFP->getContext());
- if (CFP->getType()->isDoubleTy())
- return nullptr; // Won't shrink.
- if (fitsInFPType(CFP, APFloat::IEEEdouble()))
- return Type::getDoubleTy(CFP->getContext());
+ if (fitsInFPType(F, APFloat::IEEEsingle()))
+ return Type::getFloatTy(Ctx);
+ if (&F.getSemantics() == &APFloat::IEEEdouble())
+ return nullptr; // Won't shrink.
+ // See if the value can be truncated to double and then reextended.
+ if (fitsInFPType(F, APFloat::IEEEdouble()))
+ return Type::getDoubleTy(Ctx);
// Don't try to shrink to various long double types.
return nullptr;
}
+static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
+ Type *Ty = CFP->getType();
+ if (Ty->getScalarType()->isPPC_FP128Ty())
+ return nullptr; // No constant folding of this.
+
+ Type *ShrinkTy =
+ shrinkFPConstant(CFP->getContext(), CFP->getValueAPF(), PreferBFloat);
+ if (ShrinkTy)
+ if (auto *VecTy = dyn_cast<VectorType>(Ty))
+ ShrinkTy = VectorType::get(ShrinkTy, VecTy);
+
+ return ShrinkTy;
+}
+
// Determine if this is a vector of ConstantFPs and if so, return the minimal
// type we can safely truncate all elements to.
static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
@@ -1720,10 +1733,10 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
// Try to shrink scalable and fixed splat vectors.
if (auto *FPC = dyn_cast<Constant>(V))
- if (isa<VectorType>(V->getType()))
+ if (auto *VTy = dyn_cast<VectorType>(V->getType()))
if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
- return T;
+ return VectorType::get(T, VTy);
// Try to shrink a vector of FP constants. This returns nullptr on scalable
// vectors
@@ -1796,10 +1809,9 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
Type *Ty = FPT.getType();
auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
if (BO && BO->hasOneUse()) {
- Type *LHSMinType =
- getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
- Type *RHSMinType =
- getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
+ bool PreferBFloat = Ty->getScalarType()->isBFloatTy();
+ Type *LHSMinType = getMinimumFPType(BO->getOperand(0), PreferBFloat);
+ Type *RHSMinType = getMinimumFPType(BO->getOperand(1), PreferBFloat);
unsigned OpWidth = BO->getType()->getFPMantissaWidth();
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
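The rewritten shrinkFPConstant still relies on the same lossless round-trip test that fitsInFPType performs via APFloat::convert; a minimal standalone sketch of the idea (plain C++, double-to-float only, NaN payload subtleties ignored):

    // A constant may be narrowed only if converting down and back up
    // reproduces exactly the same value.
    bool fitsInFloat(double D) {
      float F = static_cast<float>(D);    // truncate to the narrower type
      return static_cast<double>(F) == D; // lossless round-trip => shrinkable
    }
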
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 975498f..5aa8de3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3455,27 +3455,45 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
// select a, false, b -> select !a, b, false
if (match(TrueVal, m_Specific(Zero))) {
Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return SelectInst::Create(NotCond, FalseVal, Zero);
+ Instruction *MDFrom = ProfcheckDisableMetadataFixes ? nullptr : &SI;
+ SelectInst *NewSI =
+ SelectInst::Create(NotCond, FalseVal, Zero, "", nullptr, MDFrom);
+ NewSI->swapProfMetadata();
+ return NewSI;
}
// select a, b, true -> select !a, true, b
if (match(FalseVal, m_Specific(One))) {
Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return SelectInst::Create(NotCond, One, TrueVal);
+ Instruction *MDFrom = ProfcheckDisableMetadataFixes ? nullptr : &SI;
+ SelectInst *NewSI =
+ SelectInst::Create(NotCond, One, TrueVal, "", nullptr, MDFrom);
+ NewSI->swapProfMetadata();
+ return NewSI;
}
// DeMorgan in select form: !a && !b --> !(a || b)
// select !a, !b, false --> not (select a, true, b)
if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
(CondVal->hasOneUse() || TrueVal->hasOneUse()) &&
- !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
- return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B));
+ !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) {
+ Instruction *MDFrom = ProfcheckDisableMetadataFixes ? nullptr : &SI;
+ SelectInst *NewSI =
+ cast<SelectInst>(Builder.CreateSelect(A, One, B, "", MDFrom));
+ NewSI->swapProfMetadata();
+ return BinaryOperator::CreateNot(NewSI);
+ }
// DeMorgan in select form: !a || !b --> !(a && b)
// select !a, true, !b --> not (select a, b, false)
if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
(CondVal->hasOneUse() || FalseVal->hasOneUse()) &&
- !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
- return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero));
+ !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) {
+ Instruction *MDFrom = ProfcheckDisableMetadataFixes ? nullptr : &SI;
+ SelectInst *NewSI =
+ cast<SelectInst>(Builder.CreateSelect(A, B, Zero, "", MDFrom));
+ NewSI->swapProfMetadata();
+ return BinaryOperator::CreateNot(NewSI);
+ }
// select (select a, true, b), true, b -> select a, true, b
if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) &&
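A small standalone sketch (hypothetical struct, not LLVM's metadata representation) of why each transform above calls swapProfMetadata(): inverting the select's condition exchanges the roles of the two arms, so any recorded branch weights have to be exchanged with them.

    #include <cstdint>
    #include <utility>

    struct BranchWeights { uint64_t TrueWeight, FalseWeight; };

    // select a, X, Y  ==>  select !a, Y, X : same behaviour, but the counts
    // recorded for the true/false arms now belong to the opposite arm.
    BranchWeights profileForInvertedSelect(BranchWeights W) {
      std::swap(W.TrueWeight, W.FalseWeight);
      return W;
    }
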
diff --git a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
index 40720ae..8181e4e 100644
--- a/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AllocToken.cpp
@@ -31,10 +31,12 @@
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
+#include "llvm/Support/AllocToken.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
@@ -53,47 +55,14 @@
#include <variant>
using namespace llvm;
+using TokenMode = AllocTokenMode;
#define DEBUG_TYPE "alloc-token"
namespace {
-//===--- Constants --------------------------------------------------------===//
-
-enum class TokenMode : unsigned {
- /// Incrementally increasing token ID.
- Increment = 0,
-
- /// Simple mode that returns a statically-assigned random token ID.
- Random = 1,
-
- /// Token ID based on allocated type hash.
- TypeHash = 2,
-
- /// Token ID based on allocated type hash, where the top half ID-space is
- /// reserved for types that contain pointers and the bottom half for types
- /// that do not contain pointers.
- TypeHashPointerSplit = 3,
-};
-
//===--- Command-line options ---------------------------------------------===//
-cl::opt<TokenMode> ClMode(
- "alloc-token-mode", cl::Hidden, cl::desc("Token assignment mode"),
- cl::init(TokenMode::TypeHashPointerSplit),
- cl::values(
- clEnumValN(TokenMode::Increment, "increment",
- "Incrementally increasing token ID"),
- clEnumValN(TokenMode::Random, "random",
- "Statically-assigned random token ID"),
- clEnumValN(TokenMode::TypeHash, "typehash",
- "Token ID based on allocated type hash"),
- clEnumValN(
- TokenMode::TypeHashPointerSplit, "typehashpointersplit",
- "Token ID based on allocated type hash, where the top half "
- "ID-space is reserved for types that contain pointers and the "
- "bottom half for types that do not contain pointers. ")));
-
cl::opt<std::string> ClFuncPrefix("alloc-token-prefix",
cl::desc("The allocation function prefix"),
cl::Hidden, cl::init("__alloc_token_"));
@@ -131,7 +100,7 @@ cl::opt<uint64_t> ClFallbackToken(
//===--- Statistics -------------------------------------------------------===//
-STATISTIC(NumFunctionsInstrumented, "Functions instrumented");
+STATISTIC(NumFunctionsModified, "Functions modified");
STATISTIC(NumAllocationsInstrumented, "Allocations instrumented");
//===----------------------------------------------------------------------===//
@@ -140,9 +109,19 @@ STATISTIC(NumAllocationsInstrumented, "Allocations instrumented");
///
/// Expected format is: !{<type-name>, <contains-pointer>}
MDNode *getAllocTokenMetadata(const CallBase &CB) {
- MDNode *Ret = CB.getMetadata(LLVMContext::MD_alloc_token);
- if (!Ret)
- return nullptr;
+ MDNode *Ret = nullptr;
+ if (auto *II = dyn_cast<IntrinsicInst>(&CB);
+ II && II->getIntrinsicID() == Intrinsic::alloc_token_id) {
+ auto *MDV = cast<MetadataAsValue>(II->getArgOperand(0));
+ Ret = cast<MDNode>(MDV->getMetadata());
+ // If the intrinsic has an empty MDNode, type inference failed.
+ if (Ret->getNumOperands() == 0)
+ return nullptr;
+ } else {
+ Ret = CB.getMetadata(LLVMContext::MD_alloc_token);
+ if (!Ret)
+ return nullptr;
+ }
assert(Ret->getNumOperands() == 2 && "bad !alloc_token");
assert(isa<MDString>(Ret->getOperand(0)));
assert(isa<ConstantAsMetadata>(Ret->getOperand(1)));
@@ -206,22 +185,19 @@ public:
using ModeBase::ModeBase;
uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
- const auto [N, H] = getHash(CB, ORE);
- return N ? boundedToken(H) : H;
- }
-protected:
- std::pair<MDNode *, uint64_t> getHash(const CallBase &CB,
- OptimizationRemarkEmitter &ORE) {
if (MDNode *N = getAllocTokenMetadata(CB)) {
MDString *S = cast<MDString>(N->getOperand(0));
- return {N, getStableSipHash(S->getString())};
+ AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
+ if (auto Token = getAllocToken(TokenMode::TypeHash, Metadata, MaxTokens))
+ return *Token;
}
// Fallback.
remarkNoMetadata(CB, ORE);
- return {nullptr, ClFallbackToken};
+ return ClFallbackToken;
}
+protected:
/// Remark that there was no precise type information.
static void remarkNoMetadata(const CallBase &CB,
OptimizationRemarkEmitter &ORE) {
@@ -242,20 +218,18 @@ public:
using TypeHashMode::TypeHashMode;
uint64_t operator()(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
- if (MaxTokens == 1)
- return 0;
- const uint64_t HalfTokens = MaxTokens / 2;
- const auto [N, H] = getHash(CB, ORE);
- if (!N) {
- // Pick the fallback token (ClFallbackToken), which by default is 0,
- // meaning it'll fall into the pointer-less bucket. Override by setting
- // -alloc-token-fallback if that is the wrong choice.
- return H;
+ if (MDNode *N = getAllocTokenMetadata(CB)) {
+ MDString *S = cast<MDString>(N->getOperand(0));
+ AllocTokenMetadata Metadata{S->getString(), containsPointer(N)};
+ if (auto Token = getAllocToken(TokenMode::TypeHashPointerSplit, Metadata,
+ MaxTokens))
+ return *Token;
}
- uint64_t Hash = H % HalfTokens; // base hash
- if (containsPointer(N))
- Hash += HalfTokens;
- return Hash;
+ // Pick the fallback token (ClFallbackToken), which by default is 0, meaning
+ // it'll fall into the pointer-less bucket. Override by setting
+ // -alloc-token-fallback if that is the wrong choice.
+ remarkNoMetadata(CB, ORE);
+ return ClFallbackToken;
}
};
@@ -275,7 +249,7 @@ public:
: Options(transformOptionsFromCl(std::move(Opts))), Mod(M),
FAM(MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
Mode(IncrementMode(*IntPtrTy, *Options.MaxTokens)) {
- switch (ClMode.getValue()) {
+ switch (Options.Mode) {
case TokenMode::Increment:
break;
case TokenMode::Random:
@@ -315,6 +289,9 @@ private:
FunctionCallee getTokenAllocFunction(const CallBase &CB, uint64_t TokenID,
LibFunc OriginalFunc);
+ /// Lower alloc_token_* intrinsics.
+ void replaceIntrinsicInst(IntrinsicInst *II, OptimizationRemarkEmitter &ORE);
+
/// Return the token ID from metadata in the call.
uint64_t getToken(const CallBase &CB, OptimizationRemarkEmitter &ORE) {
return std::visit([&](auto &&Mode) { return Mode(CB, ORE); }, Mode);
@@ -336,21 +313,32 @@ bool AllocToken::instrumentFunction(Function &F) {
// Do not apply any instrumentation for naked functions.
if (F.hasFnAttribute(Attribute::Naked))
return false;
- if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation))
- return false;
// Don't touch available_externally functions, their actual body is elsewhere.
if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage)
return false;
- // Only instrument functions that have the sanitize_alloc_token attribute.
- if (!F.hasFnAttribute(Attribute::SanitizeAllocToken))
- return false;
auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F);
SmallVector<std::pair<CallBase *, LibFunc>, 4> AllocCalls;
+ SmallVector<IntrinsicInst *, 4> IntrinsicInsts;
+
+ // Only instrument functions that have the sanitize_alloc_token attribute.
+ const bool InstrumentFunction =
+ F.hasFnAttribute(Attribute::SanitizeAllocToken) &&
+ !F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation);
// Collect all allocation calls to avoid iterator invalidation.
for (Instruction &I : instructions(F)) {
+ // Collect all alloc_token_* intrinsics.
+ if (auto *II = dyn_cast<IntrinsicInst>(&I);
+ II && II->getIntrinsicID() == Intrinsic::alloc_token_id) {
+ IntrinsicInsts.emplace_back(II);
+ continue;
+ }
+
+ if (!InstrumentFunction)
+ continue;
+
auto *CB = dyn_cast<CallBase>(&I);
if (!CB)
continue;
@@ -359,11 +347,21 @@ bool AllocToken::instrumentFunction(Function &F) {
}
bool Modified = false;
- for (auto &[CB, Func] : AllocCalls)
- Modified |= replaceAllocationCall(CB, Func, ORE, TLI);
- if (Modified)
- NumFunctionsInstrumented++;
+ if (!AllocCalls.empty()) {
+ for (auto &[CB, Func] : AllocCalls)
+ Modified |= replaceAllocationCall(CB, Func, ORE, TLI);
+ if (Modified)
+ NumFunctionsModified++;
+ }
+
+ if (!IntrinsicInsts.empty()) {
+ for (auto *II : IntrinsicInsts)
+ replaceIntrinsicInst(II, ORE);
+ Modified = true;
+ NumFunctionsModified++;
+ }
+
return Modified;
}
@@ -381,7 +379,7 @@ AllocToken::shouldInstrumentCall(const CallBase &CB,
if (TLI.getLibFunc(*Callee, Func)) {
if (isInstrumentableLibFunc(Func, CB, TLI))
return Func;
- } else if (Options.Extended && getAllocTokenMetadata(CB)) {
+ } else if (Options.Extended && CB.getMetadata(LLVMContext::MD_alloc_token)) {
return NotLibFunc;
}
@@ -528,6 +526,16 @@ FunctionCallee AllocToken::getTokenAllocFunction(const CallBase &CB,
return TokenAlloc;
}
+void AllocToken::replaceIntrinsicInst(IntrinsicInst *II,
+ OptimizationRemarkEmitter &ORE) {
+ assert(II->getIntrinsicID() == Intrinsic::alloc_token_id);
+
+ uint64_t TokenID = getToken(*II, ORE);
+ Value *V = ConstantInt::get(IntPtrTy, TokenID);
+ II->replaceAllUsesWith(V);
+ II->eraseFromParent();
+}
+
} // namespace
AllocTokenPass::AllocTokenPass(AllocTokenOptions Opts)
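The pointer-split token assignment that this file now obtains from getAllocToken() in llvm/Support/AllocToken follows the arithmetic visible in the deleted lines above; a standalone sketch of that mode (function name hypothetical):

    #include <cstdint>

    // Pointer-free types hash into [0, MaxTokens/2), pointer-containing types
    // into the upper half of the token ID space.
    uint64_t pointerSplitToken(uint64_t TypeHash, bool ContainsPointer,
                               uint64_t MaxTokens) {
      if (MaxTokens == 1)
        return 0;
      const uint64_t HalfTokens = MaxTokens / 2;
      uint64_t Token = TypeHash % HalfTokens; // base bucket from the type hash
      if (ContainsPointer)
        Token += HalfTokens;                  // shift into the pointer half
      return Token;
    }
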
diff --git a/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp
index d18c0d0..80e77e09 100644
--- a/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/NumericalStabilitySanitizer.cpp
@@ -2020,7 +2020,6 @@ static void moveFastMathFlags(Function &F,
F.removeFnAttr(attr); \
FMF.set##setter(); \
}
- MOVE_FLAG("unsafe-fp-math", Fast)
MOVE_FLAG("no-infs-fp-math", NoInfs)
MOVE_FLAG("no-nans-fp-math", NoNaNs)
MOVE_FLAG("no-signed-zeros-fp-math", NoSignedZeros)
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index 8714741a..9829d4d 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -1793,3 +1793,13 @@ bool llvm::hasOnlySimpleTerminator(const Function &F) {
}
return true;
}
+
+Printable llvm::printBasicBlock(const BasicBlock *BB) {
+ return Printable([BB](raw_ostream &OS) {
+ if (!BB) {
+ OS << "<nullptr>";
+ return;
+ }
+ BB->printAsOperand(OS);
+ });
+}
diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
index 978d5a2..371d9e6 100644
--- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -260,9 +260,16 @@ bool PredicateInfoBuilder::stackIsInScope(const ValueDFSStack &Stack,
// next to the defs they must go with so that we can know it's time to pop
// the stack when we hit the end of the phi uses for a given def.
const ValueDFS &Top = *Stack.back().V;
- if (Top.LocalNum == LN_Last && Top.PInfo) {
- if (!VDUse.U)
- return false;
+ assert(Top.PInfo && "RenameStack should only contain predicate infos (defs)");
+ if (Top.LocalNum == LN_Last) {
+ if (!VDUse.U) {
+ assert(VDUse.PInfo && "A non-use VDUse should have a predicate info");
+ // We should reserve adjacent LN_Last defs for the same phi use.
+ return VDUse.LocalNum == LN_Last &&
+ // If the two phi defs have the same edge, they must be designated
+ // for the same succ BB.
+ getBlockEdge(Top.PInfo) == getBlockEdge(VDUse.PInfo);
+ }
auto *PHI = dyn_cast<PHINode>(VDUse.U->getUser());
if (!PHI)
return false;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 1cc9173..d2c100c9 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7231,9 +7231,6 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
return DenseMap<const SCEV *, Value *>();
}
- VPlanTransforms::narrowInterleaveGroups(
- BestVPlan, BestVF,
- TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector));
VPlanTransforms::removeDeadRecipes(BestVPlan);
VPlanTransforms::convertToConcreteRecipes(BestVPlan);
@@ -8202,6 +8199,10 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
if (CM.foldTailWithEVL())
VPlanTransforms::runPass(VPlanTransforms::addExplicitVectorLength,
*Plan, CM.getMaxSafeElements());
+
+ if (auto P = VPlanTransforms::narrowInterleaveGroups(*Plan, TTI))
+ VPlans.push_back(std::move(P));
+
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
VPlans.push_back(std::move(Plan));
}
@@ -9859,6 +9860,8 @@ bool LoopVectorizePass::processLoop(Loop *L) {
// Get user vectorization factor and interleave count.
ElementCount UserVF = Hints.getWidth();
unsigned UserIC = Hints.getInterleave();
+ if (UserIC > 1 && !LVL.isSafeForAnyVectorWidth())
+ UserIC = 1;
// Plan how to best vectorize.
LVP.plan(UserVF, UserIC);
@@ -9923,7 +9926,15 @@ bool LoopVectorizePass::processLoop(Loop *L) {
VectorizeLoop = false;
}
- if (!LVP.hasPlanWithVF(VF.Width) && UserIC > 1) {
+ if (UserIC == 1 && Hints.getInterleave() > 1) {
+ assert(!LVL.isSafeForAnyVectorWidth() &&
+ "UserIC should only be ignored due to unsafe dependencies");
+ LLVM_DEBUG(dbgs() << "LV: Ignoring user-specified interleave count.\n");
+ IntDiagMsg = {"InterleavingUnsafe",
+ "Ignoring user-specified interleave count due to possibly "
+ "unsafe dependencies in the loop."};
+ InterleaveLoop = false;
+ } else if (!LVP.hasPlanWithVF(VF.Width) && UserIC > 1) {
// Tell the user interleaving was avoided up-front, despite being explicitly
// requested.
LLVM_DEBUG(dbgs() << "LV: Ignoring UserIC, because vectorization and "
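To make the new UserIC handling concrete, a loop of the kind it targets (hypothetical example, not from the patch): the loop-carried dependence limits the safe vectorization width, so a user-requested interleave count is now dropped with the InterleavingUnsafe remark instead of being honoured.

    // The value stored to A[i] is read again three iterations later, so only a
    // few iterations may run concurrently; isSafeForAnyVectorWidth() is false
    // and the requested interleave_count(4) is reset to 1.
    void accumulate(int *A, int N) {
    #pragma clang loop interleave_count(4)
      for (int i = 3; i < N; ++i)
        A[i] = A[i - 3] + 1;
    }
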
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 3f18bd7..cdb9e7e 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -5577,62 +5577,79 @@ private:
}
// Decrement the unscheduled counter and insert to ready list if
// ready.
- auto DecrUnschedForInst = [&](Instruction *I, TreeEntry *UserTE,
- unsigned OpIdx) {
- if (!ScheduleCopyableDataMap.empty()) {
- const EdgeInfo EI = {UserTE, OpIdx};
- if (ScheduleCopyableData *CD = getScheduleCopyableData(EI, I)) {
- DecrUnsched(CD, /*IsControl=*/false);
- return;
- }
- }
- auto It = OperandsUses.find(I);
- assert(It != OperandsUses.end() && "Operand not found");
- if (It->second > 0) {
- --It->getSecond();
- assert(TotalOpCount > 0 && "No more operands to decrement");
- --TotalOpCount;
- if (ScheduleData *OpSD = getScheduleData(I))
- DecrUnsched(OpSD, /*IsControl=*/false);
- }
- };
+ auto DecrUnschedForInst =
+ [&](Instruction *I, TreeEntry *UserTE, unsigned OpIdx,
+ SmallDenseSet<std::pair<const ScheduleEntity *, unsigned>>
+ &Checked) {
+ if (!ScheduleCopyableDataMap.empty()) {
+ const EdgeInfo EI = {UserTE, OpIdx};
+ if (ScheduleCopyableData *CD =
+ getScheduleCopyableData(EI, I)) {
+ if (!Checked.insert(std::make_pair(CD, OpIdx)).second)
+ return;
+ DecrUnsched(CD, /*IsControl=*/false);
+ return;
+ }
+ }
+ auto It = OperandsUses.find(I);
+ assert(It != OperandsUses.end() && "Operand not found");
+ if (It->second > 0) {
+ --It->getSecond();
+ assert(TotalOpCount > 0 && "No more operands to decrement");
+ --TotalOpCount;
+ if (ScheduleData *OpSD = getScheduleData(I)) {
+ if (!Checked.insert(std::make_pair(OpSD, OpIdx)).second)
+ return;
+ DecrUnsched(OpSD, /*IsControl=*/false);
+ }
+ }
+ };
for (ScheduleBundle *Bundle : Bundles) {
if (ScheduleCopyableDataMap.empty() && TotalOpCount == 0)
break;
// Need to search for the lane since the tree entry can be
// reordered.
- int Lane = std::distance(Bundle->getTreeEntry()->Scalars.begin(),
- find(Bundle->getTreeEntry()->Scalars, In));
- assert(Lane >= 0 && "Lane not set");
- if (isa<StoreInst>(In) &&
- !Bundle->getTreeEntry()->ReorderIndices.empty())
- Lane = Bundle->getTreeEntry()->ReorderIndices[Lane];
- assert(Lane < static_cast<int>(
- Bundle->getTreeEntry()->Scalars.size()) &&
- "Couldn't find extract lane");
-
- // Since vectorization tree is being built recursively this
- // assertion ensures that the tree entry has all operands set before
- // reaching this code. Couple of exceptions known at the moment are
- // extracts where their second (immediate) operand is not added.
- // Since immediates do not affect scheduler behavior this is
- // considered okay.
- assert(In &&
- (isa<ExtractValueInst, ExtractElementInst, CallBase>(In) ||
- In->getNumOperands() ==
- Bundle->getTreeEntry()->getNumOperands() ||
- Bundle->getTreeEntry()->isCopyableElement(In)) &&
- "Missed TreeEntry operands?");
-
- for (unsigned OpIdx :
- seq<unsigned>(Bundle->getTreeEntry()->getNumOperands()))
- if (auto *I = dyn_cast<Instruction>(
- Bundle->getTreeEntry()->getOperand(OpIdx)[Lane])) {
- LLVM_DEBUG(dbgs() << "SLP: check for readiness (def): " << *I
- << "\n");
- DecrUnschedForInst(I, Bundle->getTreeEntry(), OpIdx);
- }
+ auto *It = find(Bundle->getTreeEntry()->Scalars, In);
+ SmallDenseSet<std::pair<const ScheduleEntity *, unsigned>> Checked;
+ do {
+ int Lane =
+ std::distance(Bundle->getTreeEntry()->Scalars.begin(), It);
+ assert(Lane >= 0 && "Lane not set");
+ if (isa<StoreInst>(In) &&
+ !Bundle->getTreeEntry()->ReorderIndices.empty())
+ Lane = Bundle->getTreeEntry()->ReorderIndices[Lane];
+ assert(Lane < static_cast<int>(
+ Bundle->getTreeEntry()->Scalars.size()) &&
+ "Couldn't find extract lane");
+
+ // Since vectorization tree is being built recursively this
+ // assertion ensures that the tree entry has all operands set
+ // before reaching this code. Couple of exceptions known at the
+ // moment are extracts where their second (immediate) operand is
+ // not added. Since immediates do not affect scheduler behavior
+ // this is considered okay.
+ assert(In &&
+ (isa<ExtractValueInst, ExtractElementInst, CallBase>(In) ||
+ In->getNumOperands() ==
+ Bundle->getTreeEntry()->getNumOperands() ||
+ Bundle->getTreeEntry()->isCopyableElement(In)) &&
+ "Missed TreeEntry operands?");
+
+ for (unsigned OpIdx :
+ seq<unsigned>(Bundle->getTreeEntry()->getNumOperands()))
+ if (auto *I = dyn_cast<Instruction>(
+ Bundle->getTreeEntry()->getOperand(OpIdx)[Lane])) {
+ LLVM_DEBUG(dbgs() << "SLP: check for readiness (def): "
+ << *I << "\n");
+ DecrUnschedForInst(I, Bundle->getTreeEntry(), OpIdx, Checked);
+ }
+              // If parent node is schedulable, it will be handled correctly.
+ if (!Bundle->getTreeEntry()->doesNotNeedToSchedule())
+ break;
+ It = std::find(std::next(It),
+ Bundle->getTreeEntry()->Scalars.end(), In);
+ } while (It != Bundle->getTreeEntry()->Scalars.end());
}
} else {
// If BundleMember is a stand-alone instruction, no operand reordering
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index d167009..c95c887 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -217,32 +217,6 @@ VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() {
return Parent->getEnclosingBlockWithPredecessors();
}
-bool VPBlockUtils::isHeader(const VPBlockBase *VPB,
- const VPDominatorTree &VPDT) {
- auto *VPBB = dyn_cast<VPBasicBlock>(VPB);
- if (!VPBB)
- return false;
-
- // If VPBB is in a region R, VPBB is a loop header if R is a loop region with
- // VPBB as its entry, i.e., free of predecessors.
- if (auto *R = VPBB->getParent())
- return !R->isReplicator() && !VPBB->hasPredecessors();
-
- // A header dominates its second predecessor (the latch), with the other
- // predecessor being the preheader
- return VPB->getPredecessors().size() == 2 &&
- VPDT.dominates(VPB, VPB->getPredecessors()[1]);
-}
-
-bool VPBlockUtils::isLatch(const VPBlockBase *VPB,
- const VPDominatorTree &VPDT) {
- // A latch has a header as its second successor, with its other successor
- // leaving the loop. A preheader OTOH has a header as its first (and only)
- // successor.
- return VPB->getNumSuccessors() == 2 &&
- VPBlockUtils::isHeader(VPB->getSuccessors()[1], VPDT);
-}
-
VPBasicBlock::iterator VPBasicBlock::getFirstNonPhi() {
iterator It = begin();
while (It != end() && It->isPhi())
@@ -768,8 +742,12 @@ static std::pair<VPBlockBase *, VPBlockBase *> cloneFrom(VPBlockBase *Entry) {
VPRegionBlock *VPRegionBlock::clone() {
const auto &[NewEntry, NewExiting] = cloneFrom(getEntry());
- auto *NewRegion = getPlan()->createVPRegionBlock(NewEntry, NewExiting,
- getName(), isReplicator());
+ VPlan &Plan = *getPlan();
+ VPRegionBlock *NewRegion =
+ isReplicator()
+ ? Plan.createReplicateRegion(NewEntry, NewExiting, getName())
+ : Plan.createLoopRegion(getName(), NewEntry, NewExiting);
+
for (VPBlockBase *Block : vp_depth_first_shallow(NewEntry))
Block->setParent(NewRegion);
return NewRegion;
@@ -1213,6 +1191,7 @@ VPlan *VPlan::duplicate() {
}
Old2NewVPValues[&VectorTripCount] = &NewPlan->VectorTripCount;
Old2NewVPValues[&VF] = &NewPlan->VF;
+ Old2NewVPValues[&UF] = &NewPlan->UF;
Old2NewVPValues[&VFxUF] = &NewPlan->VFxUF;
if (BackedgeTakenCount) {
NewPlan->BackedgeTakenCount = new VPValue();
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index fed04eb..167ba55 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -4152,6 +4152,9 @@ class VPlan {
/// Represents the vectorization factor of the loop.
VPValue VF;
+ /// Represents the symbolic unroll factor of the loop.
+ VPValue UF;
+
/// Represents the loop-invariant VF * UF of the vector loop region.
VPValue VFxUF;
@@ -4305,6 +4308,9 @@ public:
VPValue &getVF() { return VF; };
const VPValue &getVF() const { return VF; };
+ /// Returns the symbolic UF of the vector loop region.
+ VPValue &getSymbolicUF() { return UF; };
+
/// Returns VF * UF of the vector loop region.
VPValue &getVFxUF() { return VFxUF; }
@@ -4314,6 +4320,12 @@ public:
void addVF(ElementCount VF) { VFs.insert(VF); }
+ /// Remove \p VF from the plan.
+ void removeVF(ElementCount VF) {
+ assert(hasVF(VF) && "tried to remove VF not present in plan");
+ VFs.remove(VF);
+ }
+
void setVF(ElementCount VF) {
assert(hasVF(VF) && "Cannot set VF not already in plan");
VFs.clear();
@@ -4438,22 +4450,24 @@ public:
return VPB;
}
- /// Create a new VPRegionBlock with \p Entry, \p Exiting and \p Name. If \p
- /// IsReplicator is true, the region is a replicate region. The returned block
- /// is owned by the VPlan and deleted once the VPlan is destroyed.
- VPRegionBlock *createVPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exiting,
- const std::string &Name = "",
- bool IsReplicator = false) {
- auto *VPB = new VPRegionBlock(Entry, Exiting, Name, IsReplicator);
+ /// Create a new loop region with \p Name and entry and exiting blocks set
+ /// to \p Entry and \p Exiting respectively, if set. The returned block is
+ /// owned by the VPlan and deleted once the VPlan is destroyed.
+ VPRegionBlock *createLoopRegion(const std::string &Name = "",
+ VPBlockBase *Entry = nullptr,
+ VPBlockBase *Exiting = nullptr) {
+ auto *VPB = Entry ? new VPRegionBlock(Entry, Exiting, Name)
+ : new VPRegionBlock(Name);
CreatedBlocks.push_back(VPB);
return VPB;
}
- /// Create a new loop VPRegionBlock with \p Name and entry and exiting blocks set
- /// to nullptr. The returned block is owned by the VPlan and deleted once the
- /// VPlan is destroyed.
- VPRegionBlock *createVPRegionBlock(const std::string &Name = "") {
- auto *VPB = new VPRegionBlock(Name);
+ /// Create a new replicate region with \p Entry, \p Exiting and \p Name. The
+ /// returned block is owned by the VPlan and deleted once the VPlan is
+ /// destroyed.
+ VPRegionBlock *createReplicateRegion(VPBlockBase *Entry, VPBlockBase *Exiting,
+ const std::string &Name = "") {
+ auto *VPB = new VPRegionBlock(Entry, Exiting, Name, true);
CreatedBlocks.push_back(VPB);
return VPB;
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
index 332791a..65688a3 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp
@@ -406,7 +406,7 @@ static void createLoopRegion(VPlan &Plan, VPBlockBase *HeaderVPB) {
// LatchExitVPB, taking care to preserve the original predecessor & successor
// order of blocks. Set region entry and exiting after both HeaderVPB and
// LatchVPBB have been disconnected from their predecessors/successors.
- auto *R = Plan.createVPRegionBlock();
+ auto *R = Plan.createLoopRegion();
VPBlockUtils::insertOnEdge(LatchVPBB, LatchExitVPB, R);
VPBlockUtils::disconnectBlocks(LatchVPBB, R);
VPBlockUtils::connectBlocks(PreheaderVPBB, R);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index e060e70..48cf763 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -372,7 +372,7 @@ static VPRegionBlock *createReplicateRegion(VPReplicateRecipe *PredRecipe,
auto *Exiting =
Plan.createVPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe);
VPRegionBlock *Region =
- Plan.createVPRegionBlock(Entry, Exiting, RegionName, true);
+ Plan.createReplicateRegion(Entry, Exiting, RegionName);
// Note: first set Entry as region entry and then connect successors starting
// from it in order, to propagate the "parent" of each VPBasicBlock.
@@ -1478,11 +1478,8 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan,
if (!Plan.getVectorLoopRegion())
return false;
- if (!Plan.getTripCount()->isLiveIn())
- return false;
- auto *TC = dyn_cast_if_present<ConstantInt>(
- Plan.getTripCount()->getUnderlyingValue());
- if (!TC || !BestVF.isFixed())
+ const APInt *TC;
+ if (!BestVF.isFixed() || !match(Plan.getTripCount(), m_APInt(TC)))
return false;
// Calculate the minimum power-of-2 bit width that can fit the known TC, VF
@@ -1495,7 +1492,7 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan,
return std::max<unsigned>(PowerOf2Ceil(MaxVal.getActiveBits()), 8);
};
unsigned NewBitWidth =
- ComputeBitWidth(TC->getValue(), BestVF.getKnownMinValue() * BestUF);
+ ComputeBitWidth(*TC, BestVF.getKnownMinValue() * BestUF);
LLVMContext &Ctx = Plan.getContext();
auto *NewIVTy = IntegerType::get(Ctx, NewBitWidth);
@@ -2092,8 +2089,8 @@ struct VPCSEDenseMapInfo : public DenseMapInfo<VPSingleDefRecipe *> {
// Recipes in replicate regions implicitly depend on predicate. If either
// recipe is in a replicate region, only consider them equal if both have
// the same parent.
- const VPRegionBlock *RegionL = L->getParent()->getParent();
- const VPRegionBlock *RegionR = R->getParent()->getParent();
+ const VPRegionBlock *RegionL = L->getRegion();
+ const VPRegionBlock *RegionR = R->getRegion();
if (((RegionL && RegionL->isReplicator()) ||
(RegionR && RegionR->isReplicator())) &&
L->getParent() != R->getParent())
@@ -3867,8 +3864,7 @@ void VPlanTransforms::materializePacksAndUnpacks(VPlan &Plan) {
// required lanes implicitly.
// TODO: Remove once replicate regions are unrolled completely.
auto IsCandidateUnpackUser = [Def](VPUser *U) {
- VPRegionBlock *ParentRegion =
- cast<VPRecipeBase>(U)->getParent()->getParent();
+ VPRegionBlock *ParentRegion = cast<VPRecipeBase>(U)->getRegion();
return U->usesScalars(Def) &&
(!ParentRegion || !ParentRegion->isReplicator());
};
@@ -3960,6 +3956,9 @@ void VPlanTransforms::materializeVFAndVFxUF(VPlan &Plan, VPBasicBlock *VectorPH,
// used.
// TODO: Assert that they aren't used.
+ VPValue *UF = Plan.getOrAddLiveIn(ConstantInt::get(TCTy, Plan.getUF()));
+ Plan.getSymbolicUF().replaceAllUsesWith(UF);
+
// If there are no users of the runtime VF, compute VFxUF by constant folding
// the multiplication of VF and UF.
if (VF.getNumUsers() == 0) {
@@ -3979,7 +3978,6 @@ void VPlanTransforms::materializeVFAndVFxUF(VPlan &Plan, VPBasicBlock *VectorPH,
}
VF.replaceAllUsesWith(RuntimeVF);
- VPValue *UF = Plan.getOrAddLiveIn(ConstantInt::get(TCTy, Plan.getUF()));
VPValue *MulByUF = Builder.createNaryOp(Instruction::Mul, {RuntimeVF, UF});
VFxUF.replaceAllUsesWith(MulByUF);
}
@@ -4047,14 +4045,14 @@ static bool canNarrowLoad(VPWidenRecipe *WideMember0, unsigned OpIdx,
return false;
}
-/// Returns true if \p IR is a full interleave group with factor and number of
-/// members both equal to \p VF. The interleave group must also access the full
-/// vector width \p VectorRegWidth.
-static bool isConsecutiveInterleaveGroup(VPInterleaveRecipe *InterleaveR,
- unsigned VF, VPTypeAnalysis &TypeInfo,
- unsigned VectorRegWidth) {
- if (!InterleaveR)
- return false;
+/// Returns VF from \p VFs if \p InterleaveR is a full interleave group with
+/// factor and number of members both equal to VF. The interleave group must
+/// also access the full vector width.
+static std::optional<ElementCount> isConsecutiveInterleaveGroup(
+ VPInterleaveRecipe *InterleaveR, ArrayRef<ElementCount> VFs,
+ VPTypeAnalysis &TypeInfo, const TargetTransformInfo &TTI) {
+ if (!InterleaveR || InterleaveR->getMask())
+ return std::nullopt;
Type *GroupElementTy = nullptr;
if (InterleaveR->getStoredValues().empty()) {
@@ -4063,7 +4061,7 @@ static bool isConsecutiveInterleaveGroup(VPInterleaveRecipe *InterleaveR,
[&TypeInfo, GroupElementTy](VPValue *Op) {
return TypeInfo.inferScalarType(Op) == GroupElementTy;
}))
- return false;
+ return std::nullopt;
} else {
GroupElementTy =
TypeInfo.inferScalarType(InterleaveR->getStoredValues()[0]);
@@ -4071,13 +4069,27 @@ static bool isConsecutiveInterleaveGroup(VPInterleaveRecipe *InterleaveR,
[&TypeInfo, GroupElementTy](VPValue *Op) {
return TypeInfo.inferScalarType(Op) == GroupElementTy;
}))
- return false;
+ return std::nullopt;
}
- unsigned GroupSize = GroupElementTy->getScalarSizeInBits() * VF;
- auto IG = InterleaveR->getInterleaveGroup();
- return IG->getFactor() == VF && IG->getNumMembers() == VF &&
- GroupSize == VectorRegWidth;
+ auto GetVectorWidthForVF = [&TTI](ElementCount VF) {
+ TypeSize Size = TTI.getRegisterBitWidth(
+ VF.isFixed() ? TargetTransformInfo::RGK_FixedWidthVector
+ : TargetTransformInfo::RGK_ScalableVector);
+ assert(Size.isScalable() == VF.isScalable() &&
+           "if Size is scalable, VF must be scalable too and vice versa");
+ return Size.getKnownMinValue();
+ };
+
+ for (ElementCount VF : VFs) {
+ unsigned MinVal = VF.getKnownMinValue();
+ unsigned GroupSize = GroupElementTy->getScalarSizeInBits() * MinVal;
+ auto IG = InterleaveR->getInterleaveGroup();
+ if (IG->getFactor() == MinVal && IG->getNumMembers() == MinVal &&
+ GroupSize == GetVectorWidthForVF(VF))
+ return {VF};
+ }
+ return std::nullopt;
}
/// Returns true if \p VPValue is a narrow VPValue.
@@ -4088,16 +4100,18 @@ static bool isAlreadyNarrow(VPValue *VPV) {
return RepR && RepR->isSingleScalar();
}
-void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
- unsigned VectorRegWidth) {
+std::unique_ptr<VPlan>
+VPlanTransforms::narrowInterleaveGroups(VPlan &Plan,
+ const TargetTransformInfo &TTI) {
+ using namespace llvm::VPlanPatternMatch;
VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion();
+
if (!VectorLoop)
- return;
+ return nullptr;
VPTypeAnalysis TypeInfo(Plan);
-
- unsigned VFMinVal = VF.getKnownMinValue();
SmallVector<VPInterleaveRecipe *> StoreGroups;
+ std::optional<ElementCount> VFToOptimize;
for (auto &R : *VectorLoop->getEntryBasicBlock()) {
if (isa<VPCanonicalIVPHIRecipe>(&R) || match(&R, m_BranchOnCount()))
continue;
@@ -4111,30 +4125,33 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
// * recipes writing to memory except interleave groups
// Only support plans with a canonical induction phi.
if (R.isPhi())
- return;
+ return nullptr;
auto *InterleaveR = dyn_cast<VPInterleaveRecipe>(&R);
if (R.mayWriteToMemory() && !InterleaveR)
- return;
-
- // Do not narrow interleave groups if there are VectorPointer recipes and
- // the plan was unrolled. The recipe implicitly uses VF from
- // VPTransformState.
- // TODO: Remove restriction once the VF for the VectorPointer offset is
- // modeled explicitly as operand.
- if (isa<VPVectorPointerRecipe>(&R) && Plan.getUF() > 1)
- return;
+ return nullptr;
// All other ops are allowed, but we reject uses that cannot be converted
// when checking all allowed consumers (store interleave groups) below.
if (!InterleaveR)
continue;
- // Bail out on non-consecutive interleave groups.
- if (!isConsecutiveInterleaveGroup(InterleaveR, VFMinVal, TypeInfo,
- VectorRegWidth))
- return;
-
+    // Try to find a single VF where all interleave groups are consecutive
+    // and saturate the full vector width. If we already have a candidate VF,
+    // check if it is applicable to the current InterleaveR, otherwise look
+    // for a suitable VF across the Plan's VFs.
+    //
+ if (VFToOptimize) {
+ if (!isConsecutiveInterleaveGroup(InterleaveR, {*VFToOptimize}, TypeInfo,
+ TTI))
+ return nullptr;
+ } else {
+ if (auto VF = isConsecutiveInterleaveGroup(
+ InterleaveR, to_vector(Plan.vectorFactors()), TypeInfo, TTI))
+ VFToOptimize = *VF;
+ else
+ return nullptr;
+ }
// Skip read interleave groups.
if (InterleaveR->getStoredValues().empty())
continue;
@@ -4168,24 +4185,34 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
auto *WideMember0 = dyn_cast_or_null<VPWidenRecipe>(
InterleaveR->getStoredValues()[0]->getDefiningRecipe());
if (!WideMember0)
- return;
+ return nullptr;
for (const auto &[I, V] : enumerate(InterleaveR->getStoredValues())) {
auto *R = dyn_cast_or_null<VPWidenRecipe>(V->getDefiningRecipe());
if (!R || R->getOpcode() != WideMember0->getOpcode() ||
R->getNumOperands() > 2)
- return;
+ return nullptr;
if (any_of(enumerate(R->operands()),
[WideMember0, Idx = I](const auto &P) {
const auto &[OpIdx, OpV] = P;
return !canNarrowLoad(WideMember0, OpIdx, OpV, Idx);
}))
- return;
+ return nullptr;
}
StoreGroups.push_back(InterleaveR);
}
if (StoreGroups.empty())
- return;
+ return nullptr;
+
+ // All interleave groups in Plan can be narrowed for VFToOptimize. Split the
+ // original Plan into 2: a) a new clone which contains all VFs of Plan, except
+ // VFToOptimize, and b) the original Plan with VFToOptimize as single VF.
+ std::unique_ptr<VPlan> NewPlan;
+ if (size(Plan.vectorFactors()) != 1) {
+ NewPlan = std::unique_ptr<VPlan>(Plan.duplicate());
+ Plan.setVF(*VFToOptimize);
+ NewPlan->removeVF(*VFToOptimize);
+ }
// Convert InterleaveGroup \p R to a single VPWidenLoadRecipe.
SmallPtrSet<VPValue *, 4> NarrowedOps;
@@ -4256,9 +4283,8 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
auto *Inc = cast<VPInstruction>(CanIV->getBackedgeValue());
VPBuilder PHBuilder(Plan.getVectorPreheader());
- VPValue *UF = Plan.getOrAddLiveIn(
- ConstantInt::get(CanIV->getScalarType(), 1 * Plan.getUF()));
- if (VF.isScalable()) {
+ VPValue *UF = &Plan.getSymbolicUF();
+ if (VFToOptimize->isScalable()) {
VPValue *VScale = PHBuilder.createElementCount(
CanIV->getScalarType(), ElementCount::getScalable(1));
VPValue *VScaleUF = PHBuilder.createNaryOp(Instruction::Mul, {VScale, UF});
@@ -4270,6 +4296,10 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
Plan.getOrAddLiveIn(ConstantInt::get(CanIV->getScalarType(), 1)));
}
removeDeadRecipes(Plan);
+ assert(none_of(*VectorLoop->getEntryBasicBlock(),
+ IsaPred<VPVectorPointerRecipe>) &&
+ "All VPVectorPointerRecipes should have been removed");
+ return NewPlan;
}
/// Add branch weight metadata, if the \p Plan's middle block is terminated by a
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index b28559b..ca8d956 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -341,14 +341,20 @@ struct VPlanTransforms {
static DenseMap<const SCEV *, Value *> expandSCEVs(VPlan &Plan,
ScalarEvolution &SE);
- /// Try to convert a plan with interleave groups with VF elements to a plan
- /// with the interleave groups replaced by wide loads and stores processing VF
- /// elements, if all transformed interleave groups access the full vector
- /// width (checked via \o VectorRegWidth). This effectively is a very simple
- /// form of loop-aware SLP, where we use interleave groups to identify
- /// candidates.
- static void narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
- unsigned VectorRegWidth);
+ /// Try to find a single VF among \p Plan's VFs for which all interleave
+ /// groups (with known minimum VF elements) can be replaced by wide loads and
+ /// stores processing VF elements, if all transformed interleave groups access
+ /// the full vector width (checked via the maximum vector register width). If
+ /// the transformation can be applied, the original \p Plan will be split in
+ /// 2:
+ /// 1. The original Plan with the single VF containing the optimized recipes
+ /// using wide loads instead of interleave groups.
+ /// 2. A new clone which contains all VFs of Plan except the optimized VF.
+ ///
+ /// This effectively is a very simple form of loop-aware SLP, where we use
+ /// interleave groups to identify candidates.
+ static std::unique_ptr<VPlan>
+ narrowInterleaveGroups(VPlan &Plan, const TargetTransformInfo &TTI);
/// Predicate and linearize the control-flow in the only loop region of
/// \p Plan. If \p FoldTail is true, create a mask guarding the loop
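Caller-side pattern implied by the new signature (mirroring the LoopVectorize.cpp hunk earlier in this diff; the variable name Remainder is illustrative): the transform may hand back a second plan covering the VFs that were not optimized, and both plans are kept.

    // If narrowing succeeds for one VF, the original Plan is restricted to that
    // VF and a clone covering the remaining VFs is returned; keep both.
    if (std::unique_ptr<VPlan> Remainder =
            VPlanTransforms::narrowInterleaveGroups(*Plan, TTI))
      VPlans.push_back(std::move(Remainder));
    VPlans.push_back(std::move(Plan));
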
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
index 10801c0..32e4b88 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -8,6 +8,7 @@
#include "VPlanUtils.h"
#include "VPlanCFG.h"
+#include "VPlanDominatorTree.h"
#include "VPlanPatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
@@ -253,3 +254,29 @@ vputils::getRecipesForUncountableExit(VPlan &Plan,
return UncountableCondition;
}
+
+bool VPBlockUtils::isHeader(const VPBlockBase *VPB,
+ const VPDominatorTree &VPDT) {
+ auto *VPBB = dyn_cast<VPBasicBlock>(VPB);
+ if (!VPBB)
+ return false;
+
+ // If VPBB is in a region R, VPBB is a loop header if R is a loop region with
+ // VPBB as its entry, i.e., free of predecessors.
+ if (auto *R = VPBB->getParent())
+ return !R->isReplicator() && !VPBB->hasPredecessors();
+
+ // A header dominates its second predecessor (the latch), with the other
+ // predecessor being the preheader
+ return VPB->getPredecessors().size() == 2 &&
+ VPDT.dominates(VPB, VPB->getPredecessors()[1]);
+}
+
+bool VPBlockUtils::isLatch(const VPBlockBase *VPB,
+ const VPDominatorTree &VPDT) {
+ // A latch has a header as its second successor, with its other successor
+ // leaving the loop. A preheader OTOH has a header as its first (and only)
+ // successor.
+ return VPB->getNumSuccessors() == 2 &&
+ VPBlockUtils::isHeader(VPB->getSuccessors()[1], VPDT);
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
index 0678bc90..83e3fca 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -41,10 +41,10 @@ class VPRecipeBase;
class VPInterleaveBase;
class VPPhiAccessors;
-// This is the base class of the VPlan Def/Use graph, used for modeling the data
-// flow into, within and out of the VPlan. VPValues can stand for live-ins
-// coming from the input IR and instructions which VPlan will generate if
-// executed.
+/// This is the base class of the VPlan Def/Use graph, used for modeling the
+/// data flow into, within and out of the VPlan. VPValues can stand for live-ins
+/// coming from the input IR and instructions which VPlan will generate if
+/// executed.
class LLVM_ABI_FOR_TEST VPValue {
friend class VPDef;
friend struct VPDoubleValueDef;
@@ -57,7 +57,7 @@ class LLVM_ABI_FOR_TEST VPValue {
SmallVector<VPUser *, 1> Users;
protected:
- // Hold the underlying Value, if any, attached to this VPValue.
+ /// Hold the underlying Value, if any, attached to this VPValue.
Value *UnderlyingVal;
/// Pointer to the VPDef that defines this VPValue. If it is nullptr, the