aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r--llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp31
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp25
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp35
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineInternal.h3
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp26
-rw-r--r--llvm/lib/Transforms/InstCombine/InstructionCombining.cpp49
-rw-r--r--llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp40
-rw-r--r--llvm/lib/Transforms/Scalar/ConstraintElimination.cpp20
-rw-r--r--llvm/lib/Transforms/Scalar/MergeICmps.cpp34
-rw-r--r--llvm/lib/Transforms/Vectorize/LoopVectorize.cpp7
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.cpp1
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlan.h39
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp67
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp362
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanTransforms.h22
-rw-r--r--llvm/lib/Transforms/Vectorize/VPlanUtils.cpp5
16 files changed, 466 insertions, 300 deletions
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index a0f7ec6..2dd0fde 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -948,17 +948,17 @@ void llvm::updateVCallVisibilityInIndex(
// linker, as we have no information on their eventual use.
if (DynamicExportSymbols.count(P.first))
continue;
+ // With validation enabled, we want to exclude symbols visible to regular
+ // objects. Local symbols will be in this group due to the current
+ // implementation but those with VCallVisibilityTranslationUnit will have
+ // already been marked in clang so are unaffected.
+ if (VisibleToRegularObjSymbols.count(P.first))
+ continue;
for (auto &S : P.second.getSummaryList()) {
auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
if (!GVar ||
GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
continue;
- // With validation enabled, we want to exclude symbols visible to regular
- // objects. Local symbols will be in this group due to the current
- // implementation but those with VCallVisibilityTranslationUnit will have
- // already been marked in clang so are unaffected.
- if (VisibleToRegularObjSymbols.count(P.first))
- continue;
GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
}
}
@@ -1161,14 +1161,10 @@ bool DevirtIndex::tryFindVirtualCallTargets(
// and therefore the same GUID. This can happen if there isn't enough
// distinguishing path when compiling the source file. In that case we
// conservatively return false early.
+ if (P.VTableVI.hasLocal() && P.VTableVI.getSummaryList().size() > 1)
+ return false;
const GlobalVarSummary *VS = nullptr;
- bool LocalFound = false;
for (const auto &S : P.VTableVI.getSummaryList()) {
- if (GlobalValue::isLocalLinkage(S->linkage())) {
- if (LocalFound)
- return false;
- LocalFound = true;
- }
auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject());
if (!CurVS->vTableFuncs().empty() ||
// Previously clang did not attach the necessary type metadata to
@@ -1184,6 +1180,7 @@ bool DevirtIndex::tryFindVirtualCallTargets(
// with public LTO visibility.
if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
return false;
+ break;
}
}
// There will be no VS if all copies are available_externally having no
@@ -1411,9 +1408,8 @@ bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
// If the summary list contains multiple summaries where at least one is
// a local, give up, as we won't know which (possibly promoted) name to use.
- for (const auto &S : TheFn.getSummaryList())
- if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
- return false;
+ if (TheFn.hasLocal() && Size > 1)
+ return false;
// Collect functions devirtualized at least for one call site for stats.
if (PrintSummaryDevirt || AreStatisticsEnabled())
@@ -2591,6 +2587,11 @@ void DevirtIndex::run() {
if (ExportSummary.typeIdCompatibleVtableMap().empty())
return;
+ // Assert that we haven't made any changes that would affect the hasLocal()
+ // flag on the GUID summary info.
+ assert(!ExportSummary.withInternalizeAndPromote() &&
+ "Expect index-based WPD to run before internalization and promotion");
+
DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
NameByGUID[GlobalValue::getGUIDAssumingExternalLinkage(P.first)].push_back(
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 73ec451..9bee523 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2760,21 +2760,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// Optimize pointer differences into the same array into a size. Consider:
// &A[10] - &A[0]: we should compile this to "10".
Value *LHSOp, *RHSOp;
- if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
- match(Op1, m_PtrToInt(m_Value(RHSOp))))
+ if (match(Op0, m_PtrToIntOrAddr(m_Value(LHSOp))) &&
+ match(Op1, m_PtrToIntOrAddr(m_Value(RHSOp))))
if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
I.hasNoUnsignedWrap()))
return replaceInstUsesWith(I, Res);
// trunc(p)-trunc(q) -> trunc(p-q)
- if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
- match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
+ if (match(Op0, m_Trunc(m_PtrToIntOrAddr(m_Value(LHSOp)))) &&
+ match(Op1, m_Trunc(m_PtrToIntOrAddr(m_Value(RHSOp)))))
if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
/* IsNUW */ false))
return replaceInstUsesWith(I, Res);
- if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) &&
- match(Op1, m_ZExtOrSelf(m_PtrToInt(m_Value(RHSOp))))) {
+ auto MatchSubOfZExtOfPtrToIntOrAddr = [&]() {
+ if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) &&
+ match(Op1, m_ZExt(m_PtrToIntSameSize(DL, m_Value(RHSOp)))))
+ return true;
+ if (match(Op0, m_ZExt(m_PtrToAddr(m_Value(LHSOp)))) &&
+ match(Op1, m_ZExt(m_PtrToAddr(m_Value(RHSOp)))))
+ return true;
+ // Special case for non-canonical ptrtoint in constant expression,
+ // where the zext has been folded into the ptrtoint.
+ if (match(Op0, m_ZExt(m_PtrToIntSameSize(DL, m_Value(LHSOp)))) &&
+ match(Op1, m_PtrToInt(m_Value(RHSOp))))
+ return true;
+ return false;
+ };
+ if (MatchSubOfZExtOfPtrToIntOrAddr()) {
if (auto *GEP = dyn_cast<GEPOperator>(LHSOp)) {
if (GEP->getPointerOperand() == RHSOp) {
if (GEP->hasNoUnsignedWrap() || GEP->hasNoUnsignedSignedWrap()) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index dab200d..8d9933b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -582,6 +582,18 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1);
return BinaryOperator::CreateSub(ConstCtlz, X);
}
+
+ // ctlz(~x & (x - 1)) -> bitwidth - cttz(x, false)
+ if (Op0->hasOneUse() &&
+ match(Op0,
+ m_c_And(m_Not(m_Value(X)), m_Add(m_Deferred(X), m_AllOnes())))) {
+ Type *Ty = II.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ auto *Cttz = IC.Builder.CreateIntrinsic(Intrinsic::cttz, Ty,
+ {X, IC.Builder.getFalse()});
+ auto *Bw = ConstantInt::get(Ty, APInt(BitWidth, BitWidth));
+ return IC.replaceInstUsesWith(II, IC.Builder.CreateSub(Bw, Cttz));
+ }
}
// cttz(Pow2) -> Log2(Pow2)
@@ -4003,18 +4015,29 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
// Try to fold intrinsic into select/phi operands. This is legal if:
// * The intrinsic is speculatable.
- // * The select condition is not a vector, or the intrinsic does not
- // perform cross-lane operations.
- if (isSafeToSpeculativelyExecuteWithVariableReplaced(&CI) &&
- isNotCrossLaneOperation(II))
+ // * The operand is one of the following:
+ // - a phi.
+ // - a select with a scalar condition.
+ // - a select with a vector condition and II is not a cross lane operation.
+ if (isSafeToSpeculativelyExecuteWithVariableReplaced(&CI)) {
for (Value *Op : II->args()) {
- if (auto *Sel = dyn_cast<SelectInst>(Op))
- if (Instruction *R = FoldOpIntoSelect(*II, Sel))
+ if (auto *Sel = dyn_cast<SelectInst>(Op)) {
+ bool IsVectorCond = Sel->getCondition()->getType()->isVectorTy();
+ if (IsVectorCond && !isNotCrossLaneOperation(II))
+ continue;
+ // Don't replace a scalar select with a more expensive vector select if
+ // we can't simplify both arms of the select.
+ bool SimplifyBothArms =
+ !Op->getType()->isVectorTy() && II->getType()->isVectorTy();
+ if (Instruction *R = FoldOpIntoSelect(
+ *II, Sel, /*FoldWithMultiUse=*/false, SimplifyBothArms))
return R;
+ }
if (auto *Phi = dyn_cast<PHINode>(Op))
if (Instruction *R = foldOpIntoPhi(*II, Phi))
return R;
}
+ }
if (Instruction *Shuf = foldShuffledIntrinsicOperands(II))
return Shuf;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 943c223..ede73f8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -664,7 +664,8 @@ public:
/// This also works for Cast instructions, which obviously do not have a
/// second operand.
Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
- bool FoldWithMultiUse = false);
+ bool FoldWithMultiUse = false,
+ bool SimplifyBothArms = false);
/// This is a convenience wrapper function for the above two functions.
Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 5aa8de3..f5130da 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4697,5 +4697,31 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
cast<IntrinsicInst>(TrueVal)->getParamAlign(0).valueOrOne(),
CondVal, FalseVal));
+ // Canonicalize sign function ashr pattern: select (icmp slt X, 1), ashr X,
+ // bitwidth-1, 1 -> scmp(X, 0)
+ // Also handles: select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0)
+ unsigned BitWidth = SI.getType()->getScalarSizeInBits();
+ CmpPredicate Pred;
+ Value *CmpLHS, *CmpRHS;
+
+ // Canonicalize sign function ashr patterns:
+ // select (icmp slt X, 1), ashr X, bitwidth-1, 1 -> scmp(X, 0)
+ // select (icmp sgt X, 0), 1, ashr X, bitwidth-1 -> scmp(X, 0)
+ if (match(&SI, m_Select(m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)),
+ m_Value(TrueVal), m_Value(FalseVal))) &&
+ ((Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_One()) &&
+ match(TrueVal,
+ m_AShr(m_Specific(CmpLHS), m_SpecificInt(BitWidth - 1))) &&
+ match(FalseVal, m_One())) ||
+ (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_Zero()) &&
+ match(TrueVal, m_One()) &&
+ match(FalseVal,
+ m_AShr(m_Specific(CmpLHS), m_SpecificInt(BitWidth - 1)))))) {
+
+ Function *Scmp = Intrinsic::getOrInsertDeclaration(
+ SI.getModule(), Intrinsic::scmp, {SI.getType(), SI.getType()});
+ return CallInst::Create(Scmp, {CmpLHS, ConstantInt::get(SI.getType(), 0)});
+ }
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 3f11cae..9c8de45 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1777,7 +1777,8 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
}
Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
- bool FoldWithMultiUse) {
+ bool FoldWithMultiUse,
+ bool SimplifyBothArms) {
// Don't modify shared select instructions unless set FoldWithMultiUse
if (!SI->hasOneUse() && !FoldWithMultiUse)
return nullptr;
@@ -1821,6 +1822,9 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
if (!NewTV && !NewFV)
return nullptr;
+ if (SimplifyBothArms && !(NewTV && NewFV))
+ return nullptr;
+
// Create an instruction for the arm that did not fold.
if (!NewTV)
NewTV = foldOperationIntoSelectOperand(Op, SI, TV, *this);
@@ -2323,6 +2327,18 @@ Constant *InstCombinerImpl::unshuffleConstant(ArrayRef<int> ShMask, Constant *C,
return ConstantVector::get(NewVecC);
}
+// Get the result of `Vector Op Splat` (or Splat Op Vector if \p SplatLHS).
+static Constant *constantFoldBinOpWithSplat(unsigned Opcode, Constant *Vector,
+ Constant *Splat, bool SplatLHS,
+ const DataLayout &DL) {
+ ElementCount EC = cast<VectorType>(Vector->getType())->getElementCount();
+ Constant *LHS = ConstantVector::getSplat(EC, Splat);
+ Constant *RHS = Vector;
+ if (!SplatLHS)
+ std::swap(LHS, RHS);
+ return ConstantFoldBinaryOpOperands(Opcode, LHS, RHS, DL);
+}
+
Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
if (!isa<VectorType>(Inst.getType()))
return nullptr;
@@ -2334,6 +2350,37 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
assert(cast<VectorType>(RHS->getType())->getElementCount() ==
cast<VectorType>(Inst.getType())->getElementCount());
+ auto foldConstantsThroughSubVectorInsertSplat =
+ [&](Value *MaybeSubVector, Value *MaybeSplat,
+ bool SplatLHS) -> Instruction * {
+ Value *Idx;
+ Constant *Splat, *SubVector, *Dest;
+ if (!match(MaybeSplat, m_ConstantSplat(m_Constant(Splat))) ||
+ !match(MaybeSubVector,
+ m_VectorInsert(m_Constant(Dest), m_Constant(SubVector),
+ m_Value(Idx))))
+ return nullptr;
+ SubVector =
+ constantFoldBinOpWithSplat(Opcode, SubVector, Splat, SplatLHS, DL);
+ Dest = constantFoldBinOpWithSplat(Opcode, Dest, Splat, SplatLHS, DL);
+ if (!SubVector || !Dest)
+ return nullptr;
+ auto *InsertVector =
+ Builder.CreateInsertVector(Dest->getType(), Dest, SubVector, Idx);
+ return replaceInstUsesWith(Inst, InsertVector);
+ };
+
+ // If one operand is a constant splat and the other operand is a
+ // `vector.insert` where both the destination and subvector are constant,
+ // apply the operation to both the destination and subvector, returning a new
+ // constant `vector.insert`. This helps constant folding for scalable vectors.
+ if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat(
+ /*MaybeSubVector=*/LHS, /*MaybeSplat=*/RHS, /*SplatLHS=*/false))
+ return Folded;
+ if (Instruction *Folded = foldConstantsThroughSubVectorInsertSplat(
+ /*MaybeSubVector=*/RHS, /*MaybeSplat=*/LHS, /*SplatLHS=*/true))
+ return Folded;
+
// If both operands of the binop are vector concatenations, then perform the
// narrow binop on each pair of the source operands followed by concatenation
// of the results.
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index b6cbecb..10b03bb 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -226,6 +226,7 @@ static const Align kMinOriginAlignment = Align(4);
static const Align kShadowTLSAlignment = Align(8);
// These constants must be kept in sync with the ones in msan.h.
+// TODO: increase size to match SVE/SVE2/SME/SME2 limits
static const unsigned kParamTLSSize = 800;
static const unsigned kRetvalTLSSize = 800;
@@ -1544,6 +1545,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
}
}
+ static bool isAArch64SVCount(Type *Ty) {
+ if (TargetExtType *TTy = dyn_cast<TargetExtType>(Ty))
+ return TTy->getName() == "aarch64.svcount";
+ return false;
+ }
+
+ // This is intended to match the "AArch64 Predicate-as-Counter Type" (aka
+ // 'target("aarch64.svcount")', but not e.g., <vscale x 4 x i32>.
+ static bool isScalableNonVectorType(Type *Ty) {
+ if (!isAArch64SVCount(Ty))
+ LLVM_DEBUG(dbgs() << "isScalableNonVectorType: Unexpected type " << *Ty
+ << "\n");
+
+ return Ty->isScalableTy() && !isa<VectorType>(Ty);
+ }
+
void materializeChecks() {
#ifndef NDEBUG
// For assert below.
@@ -1672,6 +1689,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
LLVM_DEBUG(dbgs() << "getShadowTy: " << *ST << " ===> " << *Res << "\n");
return Res;
}
+ if (isScalableNonVectorType(OrigTy)) {
+ LLVM_DEBUG(dbgs() << "getShadowTy: Scalable non-vector type: " << *OrigTy
+ << "\n");
+ return OrigTy;
+ }
+
uint32_t TypeSize = DL.getTypeSizeInBits(OrigTy);
return IntegerType::get(*MS.C, TypeSize);
}
@@ -2185,8 +2208,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
<< *OrigIns << "\n");
return;
}
-#ifndef NDEBUG
+
Type *ShadowTy = Shadow->getType();
+ if (isScalableNonVectorType(ShadowTy)) {
+ LLVM_DEBUG(dbgs() << "Skipping check of scalable non-vector " << *Shadow
+ << " before " << *OrigIns << "\n");
+ return;
+ }
+#ifndef NDEBUG
assert((isa<IntegerType>(ShadowTy) || isa<VectorType>(ShadowTy) ||
isa<StructType>(ShadowTy) || isa<ArrayType>(ShadowTy)) &&
"Can only insert checks for integer, vector, and aggregate shadow "
@@ -6972,6 +7001,15 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
// an extra "select". This results in much more compact IR.
// Sa = select Sb, poisoned, (select b, Sc, Sd)
Sa1 = getPoisonedShadow(getShadowTy(I.getType()));
+ } else if (isScalableNonVectorType(I.getType())) {
+ // This is intended to handle target("aarch64.svcount"), which can't be
+ // handled in the else branch because of incompatibility with CreateXor
+ // ("The supported LLVM operations on this type are limited to load,
+ // store, phi, select and alloca instructions").
+
+ // TODO: this currently underapproximates. Use Arm SVE EOR in the else
+ // branch as needed instead.
+ Sa1 = getCleanShadow(getShadowTy(I.getType()));
} else {
// Sa = select Sb, [ (c^d) | Sc | Sd ], [ b ? Sc : Sd ]
// If Sb (condition is poisoned), look for bits in c and d that are equal
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 4acc3f2..d347ced 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -614,6 +614,16 @@ static Decomposition decompose(Value *V,
return {V, IsKnownNonNegative};
}
+ if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() &&
+ canUseSExt(CI)) {
+ Preconditions.emplace_back(
+ CmpInst::ICMP_UGE, Op0,
+ ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1));
+ if (auto Decomp = MergeResults(Op0, CI, true))
+ return *Decomp;
+ return {V, IsKnownNonNegative};
+ }
+
if (match(V, m_NSWAdd(m_Value(Op0), m_Value(Op1)))) {
if (!isKnownNonNegative(Op0, DL))
Preconditions.emplace_back(CmpInst::ICMP_SGE, Op0,
@@ -627,16 +637,6 @@ static Decomposition decompose(Value *V,
return {V, IsKnownNonNegative};
}
- if (match(V, m_Add(m_Value(Op0), m_ConstantInt(CI))) && CI->isNegative() &&
- canUseSExt(CI)) {
- Preconditions.emplace_back(
- CmpInst::ICMP_UGE, Op0,
- ConstantInt::get(Op0->getType(), CI->getSExtValue() * -1));
- if (auto Decomp = MergeResults(Op0, CI, true))
- return *Decomp;
- return {V, IsKnownNonNegative};
- }
-
// Decompose or as an add if there are no common bits between the operands.
if (match(V, m_DisjointOr(m_Value(Op0), m_ConstantInt(CI)))) {
if (auto Decomp = MergeResults(Op0, CI, IsSigned))
diff --git a/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
index a83cbd17a7..f273e9d 100644
--- a/llvm/lib/Transforms/Scalar/MergeICmps.cpp
+++ b/llvm/lib/Transforms/Scalar/MergeICmps.cpp
@@ -64,10 +64,10 @@
using namespace llvm;
-namespace {
-
#define DEBUG_TYPE "mergeicmps"
+namespace {
+
// A BCE atom "Binary Compare Expression Atom" represents an integer load
// that is a constant offset from a base value, e.g. `a` or `o.c` in the example
// at the top.
@@ -128,11 +128,12 @@ private:
unsigned Order = 1;
DenseMap<const Value*, int> BaseToIndex;
};
+} // namespace
// If this value is a load from a constant offset w.r.t. a base address, and
// there are no other users of the load or address, returns the base address and
// the offset.
-BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
+static BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
auto *const LoadI = dyn_cast<LoadInst>(Val);
if (!LoadI)
return {};
@@ -175,6 +176,7 @@ BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset);
}
+namespace {
// A comparison between two BCE atoms, e.g. `a == o.a` in the example at the
// top.
// Note: the terminology is misleading: the comparison is symmetric, so there
@@ -239,6 +241,7 @@ class BCECmpBlock {
private:
BCECmp Cmp;
};
+} // namespace
bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst,
AliasAnalysis &AA) const {
@@ -302,9 +305,9 @@ bool BCECmpBlock::doesOtherWork() const {
// Visit the given comparison. If this is a comparison between two valid
// BCE atoms, returns the comparison.
-std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
- const ICmpInst::Predicate ExpectedPredicate,
- BaseIdentifier &BaseId) {
+static std::optional<BCECmp>
+visitICmp(const ICmpInst *const CmpI,
+ const ICmpInst::Predicate ExpectedPredicate, BaseIdentifier &BaseId) {
// The comparison can only be used once:
// - For intermediate blocks, as a branch condition.
// - For the final block, as an incoming value for the Phi.
@@ -332,10 +335,9 @@ std::optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
// Visit the given comparison block. If this is a comparison between two valid
// BCE atoms, returns the comparison.
-std::optional<BCECmpBlock> visitCmpBlock(Value *const Val,
- BasicBlock *const Block,
- const BasicBlock *const PhiBlock,
- BaseIdentifier &BaseId) {
+static std::optional<BCECmpBlock>
+visitCmpBlock(Value *const Val, BasicBlock *const Block,
+ const BasicBlock *const PhiBlock, BaseIdentifier &BaseId) {
if (Block->empty())
return std::nullopt;
auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
@@ -397,6 +399,7 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
Comparisons.push_back(std::move(Comparison));
}
+namespace {
// A chain of comparisons.
class BCECmpChain {
public:
@@ -420,6 +423,7 @@ private:
// The original entry block (before sorting);
BasicBlock *EntryBlock_;
};
+} // namespace
static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
return First.Lhs().BaseId == Second.Lhs().BaseId &&
@@ -742,9 +746,8 @@ bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
return true;
}
-std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
- BasicBlock *const LastBlock,
- int NumBlocks) {
+static std::vector<BasicBlock *>
+getOrderedBlocks(PHINode &Phi, BasicBlock *const LastBlock, int NumBlocks) {
// Walk up from the last block to find other blocks.
std::vector<BasicBlock *> Blocks(NumBlocks);
assert(LastBlock && "invalid last block");
@@ -777,8 +780,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
return Blocks;
}
-bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA,
- DomTreeUpdater &DTU) {
+static bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI,
+ AliasAnalysis &AA, DomTreeUpdater &DTU) {
LLVM_DEBUG(dbgs() << "processPhi()\n");
if (Phi.getNumIncomingValues() <= 1) {
LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n");
@@ -874,6 +877,7 @@ static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
return MadeChange;
}
+namespace {
class MergeICmpsLegacyPass : public FunctionPass {
public:
static char ID;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index d2c100c9..3356516 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7231,6 +7231,9 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
return DenseMap<const SCEV *, Value *>();
}
+ VPlanTransforms::narrowInterleaveGroups(
+ BestVPlan, BestVF,
+ TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector));
VPlanTransforms::removeDeadRecipes(BestVPlan);
VPlanTransforms::convertToConcreteRecipes(BestVPlan);
@@ -8199,10 +8202,6 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
if (CM.foldTailWithEVL())
VPlanTransforms::runPass(VPlanTransforms::addExplicitVectorLength,
*Plan, CM.getMaxSafeElements());
-
- if (auto P = VPlanTransforms::narrowInterleaveGroups(*Plan, TTI))
- VPlans.push_back(std::move(P));
-
assert(verifyVPlanIsValid(*Plan) && "VPlan is invalid");
VPlans.push_back(std::move(Plan));
}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index c95c887..428a8f4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1191,7 +1191,6 @@ VPlan *VPlan::duplicate() {
}
Old2NewVPValues[&VectorTripCount] = &NewPlan->VectorTripCount;
Old2NewVPValues[&VF] = &NewPlan->VF;
- Old2NewVPValues[&UF] = &NewPlan->UF;
Old2NewVPValues[&VFxUF] = &NewPlan->VFxUF;
if (BackedgeTakenCount) {
NewPlan->BackedgeTakenCount = new VPValue();
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 167ba55..2591df8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2712,7 +2712,8 @@ public:
static inline bool classof(const VPRecipeBase *R) {
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
- R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
+ R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
+ R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
}
static inline bool classof(const VPUser *U) {
@@ -2783,7 +2784,10 @@ public:
Opcode(Opcode), VFScaleFactor(ScaleFactor) {
[[maybe_unused]] auto *AccumulatorRecipe =
getChainOp()->getDefiningRecipe();
- assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
+ // When cloning as part of a VPExpressionRecipe the chain op could have
+ // replaced by a temporary VPValue, so it doesn't have a defining recipe.
+ assert((!AccumulatorRecipe ||
+ isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
"Unexpected operand order for partial reduction recipe");
}
@@ -3093,6 +3097,11 @@ public:
/// removed before codegen.
void decompose();
+ unsigned getVFScaleFactor() const {
+ auto *PR = dyn_cast<VPPartialReductionRecipe>(ExpressionRecipes.back());
+ return PR ? PR->getVFScaleFactor() : 1;
+ }
+
/// Method for generating code, must not be called as this recipe is abstract.
void execute(VPTransformState &State) override {
llvm_unreachable("recipe must be removed before execute");
@@ -4152,9 +4161,6 @@ class VPlan {
/// Represents the vectorization factor of the loop.
VPValue VF;
- /// Represents the symbolic unroll factor of the loop.
- VPValue UF;
-
/// Represents the loop-invariant VF * UF of the vector loop region.
VPValue VFxUF;
@@ -4166,11 +4172,6 @@ class VPlan {
/// definitions are VPValues that hold a pointer to their underlying IR.
SmallVector<VPValue *, 16> VPLiveIns;
- /// Mapping from SCEVs to the VPValues representing their expansions.
- /// NOTE: This mapping is temporary and will be removed once all users have
- /// been modeled in VPlan directly.
- DenseMap<const SCEV *, VPValue *> SCEVToExpansion;
-
/// Blocks allocated and owned by the VPlan. They will be deleted once the
/// VPlan is destroyed.
SmallVector<VPBlockBase *> CreatedBlocks;
@@ -4308,9 +4309,6 @@ public:
VPValue &getVF() { return VF; };
const VPValue &getVF() const { return VF; };
- /// Returns the symbolic UF of the vector loop region.
- VPValue &getSymbolicUF() { return UF; };
-
/// Returns VF * UF of the vector loop region.
VPValue &getVFxUF() { return VFxUF; }
@@ -4320,12 +4318,6 @@ public:
void addVF(ElementCount VF) { VFs.insert(VF); }
- /// Remove \p VF from the plan.
- void removeVF(ElementCount VF) {
- assert(hasVF(VF) && "tried to remove VF not present in plan");
- VFs.remove(VF);
- }
-
void setVF(ElementCount VF) {
assert(hasVF(VF) && "Cannot set VF not already in plan");
VFs.clear();
@@ -4427,15 +4419,6 @@ public:
LLVM_DUMP_METHOD void dump() const;
#endif
- VPValue *getSCEVExpansion(const SCEV *S) const {
- return SCEVToExpansion.lookup(S);
- }
-
- void addSCEVExpansion(const SCEV *S, VPValue *V) {
- assert(!SCEVToExpansion.contains(S) && "SCEV already expanded");
- SCEVToExpansion[S] = V;
- }
-
/// Clone the current VPlan, update all VPValues of the new VPlan and cloned
/// recipes to refer to the clones, and return it.
VPlan *duplicate();
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 1f1b42b..931a5b7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -168,6 +168,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
case VPBlendSC:
case VPReductionEVLSC:
+ case VPPartialReductionSC:
case VPReductionSC:
case VPScalarIVStepsSC:
case VPVectorPointerSC:
@@ -300,14 +301,23 @@ InstructionCost
VPPartialReductionRecipe::computeCost(ElementCount VF,
VPCostContext &Ctx) const {
std::optional<unsigned> Opcode;
- VPValue *Op = getOperand(0);
- VPRecipeBase *OpR = Op->getDefiningRecipe();
-
- // If the partial reduction is predicated, a select will be operand 0
- if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) {
- OpR = Op->getDefiningRecipe();
+ VPValue *Op = getVecOp();
+ uint64_t MulConst;
+ // If the partial reduction is predicated, a select will be operand 1.
+ // If it isn't predicated and the mul isn't operating on a constant, then it
+ // should have been turned into a VPExpressionRecipe.
+ // FIXME: Replace the entire function with this once all partial reduction
+ // variants are bundled into VPExpressionRecipe.
+ if (!match(Op, m_Select(m_VPValue(), m_VPValue(Op), m_VPValue())) &&
+ !match(Op, m_Mul(m_VPValue(), m_ConstantInt(MulConst)))) {
+ auto *PhiType = Ctx.Types.inferScalarType(getChainOp());
+ auto *InputType = Ctx.Types.inferScalarType(getVecOp());
+ return Ctx.TTI.getPartialReductionCost(getOpcode(), InputType, InputType,
+ PhiType, VF, TTI::PR_None,
+ TTI::PR_None, {}, Ctx.CostKind);
}
+ VPRecipeBase *OpR = Op->getDefiningRecipe();
Type *InputTypeA = nullptr, *InputTypeB = nullptr;
TTI::PartialReductionExtendKind ExtAType = TTI::PR_None,
ExtBType = TTI::PR_None;
@@ -2856,11 +2866,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind());
switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
- return Ctx.TTI.getExtendedReductionCost(
- Opcode,
- cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
- Instruction::ZExt,
- RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
+ unsigned Opcode = RecurrenceDescriptor::getOpcode(
+ cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind());
+ auto *ExtR = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+ return isa<VPPartialReductionRecipe>(ExpressionRecipes.back())
+ ? Ctx.TTI.getPartialReductionCost(
+ Opcode, Ctx.Types.inferScalarType(getOperand(0)), nullptr,
+ RedTy, VF,
+ TargetTransformInfo::getPartialReductionExtendKind(
+ ExtR->getOpcode()),
+ TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind)
+ : Ctx.TTI.getExtendedReductionCost(
+ Opcode, ExtR->getOpcode() == Instruction::ZExt, RedTy,
+ SrcVecTy, std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
@@ -2871,6 +2889,19 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
Opcode = Instruction::Sub;
[[fallthrough]];
case ExpressionTypes::ExtMulAccReduction: {
+ if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
+ auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
+ auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
+ auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
+ return Ctx.TTI.getPartialReductionCost(
+ Opcode, Ctx.Types.inferScalarType(getOperand(0)),
+ Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF,
+ TargetTransformInfo::getPartialReductionExtendKind(
+ Ext0R->getOpcode()),
+ TargetTransformInfo::getPartialReductionExtendKind(
+ Ext1R->getOpcode()),
+ Mul->getOpcode(), Ctx.CostKind);
+ }
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
@@ -2910,12 +2941,13 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
O << " = ";
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back());
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
+ bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
getOperand(1)->printAsOperand(O, SlotTracker);
- O << " +";
- O << " reduce." << Instruction::getOpcodeName(Opcode) << " (";
+ O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
+ O << Instruction::getOpcodeName(Opcode) << " (";
getOperand(0)->printAsOperand(O, SlotTracker);
Red->printFlags(O);
@@ -2931,8 +2963,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
}
case ExpressionTypes::ExtNegatedMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
- O << " + reduce."
- << Instruction::getOpcodeName(
+ O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
+ O << Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
<< " (sub (0, mul";
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
@@ -2956,9 +2988,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
case ExpressionTypes::MulAccReduction:
case ExpressionTypes::ExtMulAccReduction: {
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
- O << " + ";
- O << "reduce."
- << Instruction::getOpcodeName(
+ O << " + " << (IsPartialReduction ? "partial." : "") << "reduce.";
+ O << Instruction::getOpcodeName(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
<< " (";
O << "mul";
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 48cf763..84817d7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -943,12 +943,40 @@ static void recursivelyDeleteDeadRecipes(VPValue *V) {
}
}
+/// Get any instruction opcode or intrinsic ID data embedded in recipe \p R.
+/// Returns an optional pair, where the first element indicates whether it is
+/// an intrinsic ID.
+static std::optional<std::pair<bool, unsigned>>
+getOpcodeOrIntrinsicID(const VPSingleDefRecipe *R) {
+ return TypeSwitch<const VPSingleDefRecipe *,
+ std::optional<std::pair<bool, unsigned>>>(R)
+ .Case<VPInstruction, VPWidenRecipe, VPWidenCastRecipe,
+ VPWidenSelectRecipe, VPWidenGEPRecipe, VPReplicateRecipe>(
+ [](auto *I) { return std::make_pair(false, I->getOpcode()); })
+ .Case<VPWidenIntrinsicRecipe>([](auto *I) {
+ return std::make_pair(true, I->getVectorIntrinsicID());
+ })
+ .Case<VPVectorPointerRecipe, VPPredInstPHIRecipe>([](auto *I) {
+ // For recipes that do not directly map to LLVM IR instructions,
+ // assign opcodes after the last VPInstruction opcode (which is also
+ // after the last IR Instruction opcode), based on the VPDefID.
+ return std::make_pair(false,
+ VPInstruction::OpsEnd + 1 + I->getVPDefID());
+ })
+ .Default([](auto *) { return std::nullopt; });
+}
+
/// Try to fold \p R using InstSimplifyFolder. Will succeed and return a
-/// non-nullptr Value for a handled \p Opcode if corresponding \p Operands are
-/// foldable live-ins.
-static Value *tryToFoldLiveIns(const VPRecipeBase &R, unsigned Opcode,
- ArrayRef<VPValue *> Operands,
- const DataLayout &DL, VPTypeAnalysis &TypeInfo) {
+/// non-nullptr VPValue for a handled opcode or intrinsic ID if corresponding \p
+/// Operands are foldable live-ins.
+static VPValue *tryToFoldLiveIns(VPSingleDefRecipe &R,
+ ArrayRef<VPValue *> Operands,
+ const DataLayout &DL,
+ VPTypeAnalysis &TypeInfo) {
+ auto OpcodeOrIID = getOpcodeOrIntrinsicID(&R);
+ if (!OpcodeOrIID)
+ return nullptr;
+
SmallVector<Value *, 4> Ops;
for (VPValue *Op : Operands) {
if (!Op->isLiveIn() || !Op->getLiveInIRValue())
@@ -956,43 +984,57 @@ static Value *tryToFoldLiveIns(const VPRecipeBase &R, unsigned Opcode,
Ops.push_back(Op->getLiveInIRValue());
}
- InstSimplifyFolder Folder(DL);
- if (Instruction::isBinaryOp(Opcode))
- return Folder.FoldBinOp(static_cast<Instruction::BinaryOps>(Opcode), Ops[0],
+ auto FoldToIRValue = [&]() -> Value * {
+ InstSimplifyFolder Folder(DL);
+ if (OpcodeOrIID->first) {
+ if (R.getNumOperands() != 2)
+ return nullptr;
+ unsigned ID = OpcodeOrIID->second;
+ return Folder.FoldBinaryIntrinsic(ID, Ops[0], Ops[1],
+ TypeInfo.inferScalarType(&R));
+ }
+ unsigned Opcode = OpcodeOrIID->second;
+ if (Instruction::isBinaryOp(Opcode))
+ return Folder.FoldBinOp(static_cast<Instruction::BinaryOps>(Opcode),
+ Ops[0], Ops[1]);
+ if (Instruction::isCast(Opcode))
+ return Folder.FoldCast(static_cast<Instruction::CastOps>(Opcode), Ops[0],
+ TypeInfo.inferScalarType(R.getVPSingleValue()));
+ switch (Opcode) {
+ case VPInstruction::LogicalAnd:
+ return Folder.FoldSelect(Ops[0], Ops[1],
+ ConstantInt::getNullValue(Ops[1]->getType()));
+ case VPInstruction::Not:
+ return Folder.FoldBinOp(Instruction::BinaryOps::Xor, Ops[0],
+ Constant::getAllOnesValue(Ops[0]->getType()));
+ case Instruction::Select:
+ return Folder.FoldSelect(Ops[0], Ops[1], Ops[2]);
+ case Instruction::ICmp:
+ case Instruction::FCmp:
+ return Folder.FoldCmp(cast<VPRecipeWithIRFlags>(R).getPredicate(), Ops[0],
Ops[1]);
- if (Instruction::isCast(Opcode))
- return Folder.FoldCast(static_cast<Instruction::CastOps>(Opcode), Ops[0],
- TypeInfo.inferScalarType(R.getVPSingleValue()));
- switch (Opcode) {
- case VPInstruction::LogicalAnd:
- return Folder.FoldSelect(Ops[0], Ops[1],
- ConstantInt::getNullValue(Ops[1]->getType()));
- case VPInstruction::Not:
- return Folder.FoldBinOp(Instruction::BinaryOps::Xor, Ops[0],
- Constant::getAllOnesValue(Ops[0]->getType()));
- case Instruction::Select:
- return Folder.FoldSelect(Ops[0], Ops[1], Ops[2]);
- case Instruction::ICmp:
- case Instruction::FCmp:
- return Folder.FoldCmp(cast<VPRecipeWithIRFlags>(R).getPredicate(), Ops[0],
- Ops[1]);
- case Instruction::GetElementPtr: {
- auto &RFlags = cast<VPRecipeWithIRFlags>(R);
- auto *GEP = cast<GetElementPtrInst>(RFlags.getUnderlyingInstr());
- return Folder.FoldGEP(GEP->getSourceElementType(), Ops[0], drop_begin(Ops),
- RFlags.getGEPNoWrapFlags());
- }
- case VPInstruction::PtrAdd:
- case VPInstruction::WidePtrAdd:
- return Folder.FoldGEP(IntegerType::getInt8Ty(TypeInfo.getContext()), Ops[0],
- Ops[1],
- cast<VPRecipeWithIRFlags>(R).getGEPNoWrapFlags());
- // An extract of a live-in is an extract of a broadcast, so return the
- // broadcasted element.
- case Instruction::ExtractElement:
- assert(!Ops[0]->getType()->isVectorTy() && "Live-ins should be scalar");
- return Ops[0];
- }
+ case Instruction::GetElementPtr: {
+ auto &RFlags = cast<VPRecipeWithIRFlags>(R);
+ auto *GEP = cast<GetElementPtrInst>(RFlags.getUnderlyingInstr());
+ return Folder.FoldGEP(GEP->getSourceElementType(), Ops[0],
+ drop_begin(Ops), RFlags.getGEPNoWrapFlags());
+ }
+ case VPInstruction::PtrAdd:
+ case VPInstruction::WidePtrAdd:
+ return Folder.FoldGEP(IntegerType::getInt8Ty(TypeInfo.getContext()),
+ Ops[0], Ops[1],
+ cast<VPRecipeWithIRFlags>(R).getGEPNoWrapFlags());
+ // An extract of a live-in is an extract of a broadcast, so return the
+ // broadcasted element.
+ case Instruction::ExtractElement:
+ assert(!Ops[0]->getType()->isVectorTy() && "Live-ins should be scalar");
+ return Ops[0];
+ }
+ return nullptr;
+ };
+
+ if (Value *V = FoldToIRValue())
+ return R.getParent()->getPlan()->getOrAddLiveIn(V);
return nullptr;
}
@@ -1006,19 +1048,10 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
// Simplification of live-in IR values for SingleDef recipes using
// InstSimplifyFolder.
- if (TypeSwitch<VPRecipeBase *, bool>(&R)
- .Case<VPInstruction, VPWidenRecipe, VPWidenCastRecipe,
- VPReplicateRecipe, VPWidenSelectRecipe>([&](auto *I) {
- const DataLayout &DL =
- Plan->getScalarHeader()->getIRBasicBlock()->getDataLayout();
- Value *V = tryToFoldLiveIns(*I, I->getOpcode(), I->operands(), DL,
- TypeInfo);
- if (V)
- I->replaceAllUsesWith(Plan->getOrAddLiveIn(V));
- return V;
- })
- .Default([](auto *) { return false; }))
- return;
+ const DataLayout &DL =
+ Plan->getScalarHeader()->getIRBasicBlock()->getDataLayout();
+ if (VPValue *V = tryToFoldLiveIns(*Def, Def->operands(), DL, TypeInfo))
+ return Def->replaceAllUsesWith(V);
// Fold PredPHI LiveIn -> LiveIn.
if (auto *PredPHI = dyn_cast<VPPredInstPHIRecipe>(&R)) {
@@ -1996,29 +2029,6 @@ struct VPCSEDenseMapInfo : public DenseMapInfo<VPSingleDefRecipe *> {
return Def == getEmptyKey() || Def == getTombstoneKey();
}
- /// Get any instruction opcode or intrinsic ID data embedded in recipe \p R.
- /// Returns an optional pair, where the first element indicates whether it is
- /// an intrinsic ID.
- static std::optional<std::pair<bool, unsigned>>
- getOpcodeOrIntrinsicID(const VPSingleDefRecipe *R) {
- return TypeSwitch<const VPSingleDefRecipe *,
- std::optional<std::pair<bool, unsigned>>>(R)
- .Case<VPInstruction, VPWidenRecipe, VPWidenCastRecipe,
- VPWidenSelectRecipe, VPWidenGEPRecipe, VPReplicateRecipe>(
- [](auto *I) { return std::make_pair(false, I->getOpcode()); })
- .Case<VPWidenIntrinsicRecipe>([](auto *I) {
- return std::make_pair(true, I->getVectorIntrinsicID());
- })
- .Case<VPVectorPointerRecipe, VPPredInstPHIRecipe>([](auto *I) {
- // For recipes that do not directly map to LLVM IR instructions,
- // assign opcodes after the last VPInstruction opcode (which is also
- // after the last IR Instruction opcode), based on the VPDefID.
- return std::make_pair(false,
- VPInstruction::OpsEnd + 1 + I->getVPDefID());
- })
- .Default([](auto *) { return std::nullopt; });
- }
-
/// If recipe \p R will lower to a GEP with a non-i8 source element type,
/// return that source element type.
static Type *getGEPSourceElementType(const VPSingleDefRecipe *R) {
@@ -3519,18 +3529,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
VPValue *VecOp = Red->getVecOp();
// Clamp the range if using extended-reduction is profitable.
- auto IsExtendedRedValidAndClampRange = [&](unsigned Opcode, bool isZExt,
- Type *SrcTy) -> bool {
+ auto IsExtendedRedValidAndClampRange =
+ [&](unsigned Opcode, Instruction::CastOps ExtOpc, Type *SrcTy) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
- InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
- Opcode, isZExt, RedTy, SrcVecTy, Red->getFastMathFlags(),
- CostKind);
+
+ InstructionCost ExtRedCost;
InstructionCost ExtCost =
cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
+
+ if (isa<VPPartialReductionRecipe>(Red)) {
+ TargetTransformInfo::PartialReductionExtendKind ExtKind =
+ TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
+ // FIXME: Move partial reduction creation, costing and clamping
+ // here from LoopVectorize.cpp.
+ ExtRedCost = Ctx.TTI.getPartialReductionCost(
+ Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
+ llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
+ } else {
+ ExtRedCost = Ctx.TTI.getExtendedReductionCost(
+ Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
+ Red->getFastMathFlags(), CostKind);
+ }
return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
},
Range);
@@ -3541,8 +3564,7 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
if (match(VecOp, m_ZExtOrSExt(m_VPValue(A))) &&
IsExtendedRedValidAndClampRange(
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()),
- cast<VPWidenCastRecipe>(VecOp)->getOpcode() ==
- Instruction::CastOps::ZExt,
+ cast<VPWidenCastRecipe>(VecOp)->getOpcode(),
Ctx.Types.inferScalarType(A)))
return new VPExpressionRecipe(cast<VPWidenCastRecipe>(VecOp), Red);
@@ -3560,6 +3582,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
static VPExpressionRecipe *
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
VPCostContext &Ctx, VFRange &Range) {
+ bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
+
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
return nullptr;
@@ -3568,16 +3592,41 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
- [&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
- VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
+ [&](VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
+ VPWidenCastRecipe *OuterExt) -> bool {
return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
- auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
- InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
- isZExt, Opcode, RedTy, SrcVecTy, CostKind);
+ InstructionCost MulAccCost;
+
+ if (IsPartialReduction) {
+ Type *SrcTy2 =
+ Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
+ // FIXME: Move partial reduction creation, costing and clamping
+ // here from LoopVectorize.cpp.
+ MulAccCost = Ctx.TTI.getPartialReductionCost(
+ Opcode, SrcTy, SrcTy2, RedTy, VF,
+ Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
+ Ext0->getOpcode())
+ : TargetTransformInfo::PR_None,
+ Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
+ Ext1->getOpcode())
+ : TargetTransformInfo::PR_None,
+ Mul->getOpcode(), CostKind);
+ } else {
+ // Only partial reductions support mixed extends at the moment.
+ if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
+ return false;
+
+ bool IsZExt =
+ !Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
+ auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
+ MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy,
+ SrcVecTy, CostKind);
+ }
+
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
@@ -3611,14 +3660,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
- // Match reduce.add(mul(ext, ext)).
- if (RecipeA && RecipeB &&
- (RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
- match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
+ // Match reduce.add/sub(mul(ext, ext)).
+ if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
- IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
- Instruction::CastOps::ZExt,
- Mul, RecipeA, RecipeB, nullptr)) {
+ IsMulAccValidAndClampRange(Mul, RecipeA, RecipeB, nullptr)) {
if (Sub)
return new VPExpressionRecipe(RecipeA, RecipeB, Mul,
cast<VPWidenRecipe>(Sub), Red);
@@ -3626,8 +3671,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
}
// Match reduce.add(mul).
// TODO: Add an expression type for this variant with a negated mul
- if (!Sub &&
- IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
+ if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
return new VPExpressionRecipe(Mul, Red);
}
// TODO: Add an expression type for negated versions of other expression
@@ -3647,9 +3691,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
Ext0->getOpcode() == Ext1->getOpcode() &&
- IsMulAccValidAndClampRange(Ext0->getOpcode() ==
- Instruction::CastOps::ZExt,
- Mul, Ext0, Ext1, Ext)) {
+ IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) {
auto *NewExt0 = new VPWidenCastRecipe(
Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0,
*Ext0, Ext0->getDebugLoc());
@@ -3956,9 +3998,6 @@ void VPlanTransforms::materializeVFAndVFxUF(VPlan &Plan, VPBasicBlock *VectorPH,
// used.
// TODO: Assert that they aren't used.
- VPValue *UF = Plan.getOrAddLiveIn(ConstantInt::get(TCTy, Plan.getUF()));
- Plan.getSymbolicUF().replaceAllUsesWith(UF);
-
// If there are no users of the runtime VF, compute VFxUF by constant folding
// the multiplication of VF and UF.
if (VF.getNumUsers() == 0) {
@@ -3978,6 +4017,7 @@ void VPlanTransforms::materializeVFAndVFxUF(VPlan &Plan, VPBasicBlock *VectorPH,
}
VF.replaceAllUsesWith(RuntimeVF);
+ VPValue *UF = Plan.getOrAddLiveIn(ConstantInt::get(TCTy, Plan.getUF()));
VPValue *MulByUF = Builder.createNaryOp(Instruction::Mul, {RuntimeVF, UF});
VFxUF.replaceAllUsesWith(MulByUF);
}
@@ -4045,14 +4085,14 @@ static bool canNarrowLoad(VPWidenRecipe *WideMember0, unsigned OpIdx,
return false;
}
-/// Returns VF from \p VFs if \p IR is a full interleave group with factor and
-/// number of members both equal to VF. The interleave group must also access
-/// the full vector width.
-static std::optional<ElementCount> isConsecutiveInterleaveGroup(
- VPInterleaveRecipe *InterleaveR, ArrayRef<ElementCount> VFs,
- VPTypeAnalysis &TypeInfo, const TargetTransformInfo &TTI) {
+/// Returns true if \p IR is a full interleave group with factor and number of
+/// members both equal to \p VF. The interleave group must also access the full
+/// vector width \p VectorRegWidth.
+static bool isConsecutiveInterleaveGroup(VPInterleaveRecipe *InterleaveR,
+ unsigned VF, VPTypeAnalysis &TypeInfo,
+ unsigned VectorRegWidth) {
if (!InterleaveR || InterleaveR->getMask())
- return std::nullopt;
+ return false;
Type *GroupElementTy = nullptr;
if (InterleaveR->getStoredValues().empty()) {
@@ -4061,7 +4101,7 @@ static std::optional<ElementCount> isConsecutiveInterleaveGroup(
[&TypeInfo, GroupElementTy](VPValue *Op) {
return TypeInfo.inferScalarType(Op) == GroupElementTy;
}))
- return std::nullopt;
+ return false;
} else {
GroupElementTy =
TypeInfo.inferScalarType(InterleaveR->getStoredValues()[0]);
@@ -4069,27 +4109,13 @@ static std::optional<ElementCount> isConsecutiveInterleaveGroup(
[&TypeInfo, GroupElementTy](VPValue *Op) {
return TypeInfo.inferScalarType(Op) == GroupElementTy;
}))
- return std::nullopt;
+ return false;
}
- auto GetVectorWidthForVF = [&TTI](ElementCount VF) {
- TypeSize Size = TTI.getRegisterBitWidth(
- VF.isFixed() ? TargetTransformInfo::RGK_FixedWidthVector
- : TargetTransformInfo::RGK_ScalableVector);
- assert(Size.isScalable() == VF.isScalable() &&
- "if Size is scalable, VF must to and vice versa");
- return Size.getKnownMinValue();
- };
-
- for (ElementCount VF : VFs) {
- unsigned MinVal = VF.getKnownMinValue();
- unsigned GroupSize = GroupElementTy->getScalarSizeInBits() * MinVal;
- auto IG = InterleaveR->getInterleaveGroup();
- if (IG->getFactor() == MinVal && IG->getNumMembers() == MinVal &&
- GroupSize == GetVectorWidthForVF(VF))
- return {VF};
- }
- return std::nullopt;
+ unsigned GroupSize = GroupElementTy->getScalarSizeInBits() * VF;
+ auto IG = InterleaveR->getInterleaveGroup();
+ return IG->getFactor() == VF && IG->getNumMembers() == VF &&
+ GroupSize == VectorRegWidth;
}
/// Returns true if \p VPValue is a narrow VPValue.
@@ -4100,18 +4126,16 @@ static bool isAlreadyNarrow(VPValue *VPV) {
return RepR && RepR->isSingleScalar();
}
-std::unique_ptr<VPlan>
-VPlanTransforms::narrowInterleaveGroups(VPlan &Plan,
- const TargetTransformInfo &TTI) {
- using namespace llvm::VPlanPatternMatch;
+void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
+ unsigned VectorRegWidth) {
VPRegionBlock *VectorLoop = Plan.getVectorLoopRegion();
-
- if (!VectorLoop)
- return nullptr;
+ if (!VectorLoop || VectorLoop->getEntry()->getNumSuccessors() != 0)
+ return;
VPTypeAnalysis TypeInfo(Plan);
+
+ unsigned VFMinVal = VF.getKnownMinValue();
SmallVector<VPInterleaveRecipe *> StoreGroups;
- std::optional<ElementCount> VFToOptimize;
for (auto &R : *VectorLoop->getEntryBasicBlock()) {
if (isa<VPCanonicalIVPHIRecipe>(&R) || match(&R, m_BranchOnCount()))
continue;
@@ -4125,33 +4149,30 @@ VPlanTransforms::narrowInterleaveGroups(VPlan &Plan,
// * recipes writing to memory except interleave groups
// Only support plans with a canonical induction phi.
if (R.isPhi())
- return nullptr;
+ return;
auto *InterleaveR = dyn_cast<VPInterleaveRecipe>(&R);
if (R.mayWriteToMemory() && !InterleaveR)
- return nullptr;
+ return;
+
+ // Do not narrow interleave groups if there are VectorPointer recipes and
+ // the plan was unrolled. The recipe implicitly uses VF from
+ // VPTransformState.
+ // TODO: Remove restriction once the VF for the VectorPointer offset is
+ // modeled explicitly as operand.
+ if (isa<VPVectorPointerRecipe>(&R) && Plan.getUF() > 1)
+ return;
// All other ops are allowed, but we reject uses that cannot be converted
// when checking all allowed consumers (store interleave groups) below.
if (!InterleaveR)
continue;
- // Try to find a single VF, where all interleave groups are consecutive and
- // saturate the full vector width. If we already have a candidate VF, check
- // if it is applicable for the current InterleaveR, otherwise look for a
- // suitable VF across the Plans VFs.
- //
- if (VFToOptimize) {
- if (!isConsecutiveInterleaveGroup(InterleaveR, {*VFToOptimize}, TypeInfo,
- TTI))
- return nullptr;
- } else {
- if (auto VF = isConsecutiveInterleaveGroup(
- InterleaveR, to_vector(Plan.vectorFactors()), TypeInfo, TTI))
- VFToOptimize = *VF;
- else
- return nullptr;
- }
+ // Bail out on non-consecutive interleave groups.
+ if (!isConsecutiveInterleaveGroup(InterleaveR, VFMinVal, TypeInfo,
+ VectorRegWidth))
+ return;
+
// Skip read interleave groups.
if (InterleaveR->getStoredValues().empty())
continue;
@@ -4185,34 +4206,24 @@ VPlanTransforms::narrowInterleaveGroups(VPlan &Plan,
auto *WideMember0 = dyn_cast_or_null<VPWidenRecipe>(
InterleaveR->getStoredValues()[0]->getDefiningRecipe());
if (!WideMember0)
- return nullptr;
+ return;
for (const auto &[I, V] : enumerate(InterleaveR->getStoredValues())) {
auto *R = dyn_cast_or_null<VPWidenRecipe>(V->getDefiningRecipe());
if (!R || R->getOpcode() != WideMember0->getOpcode() ||
R->getNumOperands() > 2)
- return nullptr;
+ return;
if (any_of(enumerate(R->operands()),
[WideMember0, Idx = I](const auto &P) {
const auto &[OpIdx, OpV] = P;
return !canNarrowLoad(WideMember0, OpIdx, OpV, Idx);
}))
- return nullptr;
+ return;
}
StoreGroups.push_back(InterleaveR);
}
if (StoreGroups.empty())
- return nullptr;
-
- // All interleave groups in Plan can be narrowed for VFToOptimize. Split the
- // original Plan into 2: a) a new clone which contains all VFs of Plan, except
- // VFToOptimize, and b) the original Plan with VFToOptimize as single VF.
- std::unique_ptr<VPlan> NewPlan;
- if (size(Plan.vectorFactors()) != 1) {
- NewPlan = std::unique_ptr<VPlan>(Plan.duplicate());
- Plan.setVF(*VFToOptimize);
- NewPlan->removeVF(*VFToOptimize);
- }
+ return;
// Convert InterleaveGroup \p R to a single VPWidenLoadRecipe.
SmallPtrSet<VPValue *, 4> NarrowedOps;
@@ -4283,8 +4294,9 @@ VPlanTransforms::narrowInterleaveGroups(VPlan &Plan,
auto *Inc = cast<VPInstruction>(CanIV->getBackedgeValue());
VPBuilder PHBuilder(Plan.getVectorPreheader());
- VPValue *UF = &Plan.getSymbolicUF();
- if (VFToOptimize->isScalable()) {
+ VPValue *UF = Plan.getOrAddLiveIn(
+ ConstantInt::get(CanIV->getScalarType(), 1 * Plan.getUF()));
+ if (VF.isScalable()) {
VPValue *VScale = PHBuilder.createElementCount(
CanIV->getScalarType(), ElementCount::getScalable(1));
VPValue *VScaleUF = PHBuilder.createNaryOp(Instruction::Mul, {VScale, UF});
@@ -4296,10 +4308,6 @@ VPlanTransforms::narrowInterleaveGroups(VPlan &Plan,
Plan.getOrAddLiveIn(ConstantInt::get(CanIV->getScalarType(), 1)));
}
removeDeadRecipes(Plan);
- assert(none_of(*VectorLoop->getEntryBasicBlock(),
- IsaPred<VPVectorPointerRecipe>) &&
- "All VPVectorPointerRecipes should have been removed");
- return NewPlan;
}
/// Add branch weight metadata, if the \p Plan's middle block is terminated by a
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index ca8d956..b28559b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -341,20 +341,14 @@ struct VPlanTransforms {
static DenseMap<const SCEV *, Value *> expandSCEVs(VPlan &Plan,
ScalarEvolution &SE);
- /// Try to find a single VF among \p Plan's VFs for which all interleave
- /// groups (with known minimum VF elements) can be replaced by wide loads and
- /// stores processing VF elements, if all transformed interleave groups access
- /// the full vector width (checked via the maximum vector register width). If
- /// the transformation can be applied, the original \p Plan will be split in
- /// 2:
- /// 1. The original Plan with the single VF containing the optimized recipes
- /// using wide loads instead of interleave groups.
- /// 2. A new clone which contains all VFs of Plan except the optimized VF.
- ///
- /// This effectively is a very simple form of loop-aware SLP, where we use
- /// interleave groups to identify candidates.
- static std::unique_ptr<VPlan>
- narrowInterleaveGroups(VPlan &Plan, const TargetTransformInfo &TTI);
+ /// Try to convert a plan with interleave groups with VF elements to a plan
+ /// with the interleave groups replaced by wide loads and stores processing VF
+ /// elements, if all transformed interleave groups access the full vector
+ /// width (checked via \o VectorRegWidth). This effectively is a very simple
+ /// form of loop-aware SLP, where we use interleave groups to identify
+ /// candidates.
+ static void narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
+ unsigned VectorRegWidth);
/// Predicate and linearize the control-flow in the only loop region of
/// \p Plan. If \p FoldTail is true, create a mask guarding the loop
diff --git a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
index 32e4b88..fe66f13 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanUtils.cpp
@@ -32,8 +32,6 @@ bool vputils::onlyScalarValuesUsed(const VPValue *Def) {
}
VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr) {
- if (auto *Expanded = Plan.getSCEVExpansion(Expr))
- return Expanded;
VPValue *Expanded = nullptr;
if (auto *E = dyn_cast<SCEVConstant>(Expr))
Expanded = Plan.getOrAddLiveIn(E->getValue());
@@ -50,7 +48,6 @@ VPValue *vputils::getOrCreateVPValueForSCEVExpr(VPlan &Plan, const SCEV *Expr) {
Plan.getEntry()->appendRecipe(Expanded->getDefiningRecipe());
}
}
- Plan.addSCEVExpansion(Expr, Expanded);
return Expanded;
}
@@ -151,6 +148,8 @@ unsigned vputils::getVFScaleFactor(VPRecipeBase *R) {
return RR->getVFScaleFactor();
if (auto *RR = dyn_cast<VPPartialReductionRecipe>(R))
return RR->getVFScaleFactor();
+ if (auto *ER = dyn_cast<VPExpressionRecipe>(R))
+ return ER->getVFScaleFactor();
assert(
(!isa<VPInstruction>(R) || cast<VPInstruction>(R)->getOpcode() !=
VPInstruction::ReductionStartVector) &&