aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
author David Sherwood <david.sherwood@arm.com> 2021-08-04 08:10:51 +0100
committer David Sherwood <david.sherwood@arm.com> 2021-10-11 09:41:38 +0100
commit 26b7d9d62275e782da190d1717849c49588a4b0c (patch)
tree 958c101302490a696099d74569aa41c0f636d41f /llvm/lib/Transforms/Utils/LoopUtils.cpp
parent cd1bd95d8707371da0e4f75cd01669c427466931 (diff)
download llvm-26b7d9d62275e782da190d1717849c49588a4b0c.zip
download llvm-26b7d9d62275e782da190d1717849c49588a4b0c.tar.gz
download llvm-26b7d9d62275e782da190d1717849c49588a4b0c.tar.bz2
[LoopVectorize] Permit vectorisation of more select(cmp(), X, Y) reduction patterns
This patch adds further support for vectorisation of loops that involve
selecting an integer value based on a previous comparison. Consider the
following C++ loop:

  int r = a;
  for (int i = 0; i < n; i++) {
    if (src[i] > 3) {
      r = b;
    }
    src[i] += 2;
  }

We should be able to vectorise this loop because all we are doing is
selecting between two states - 'a' and 'b' - both of which are loop
invariant. This just involves building a vector of values that contain
either 'a' or 'b', where the final reduced value will be 'b' if any lane
contains 'b'.

The IR generated by clang typically looks like this:

  %phi = phi i32 [ %a, %entry ], [ %phi.update, %for.body ]
  ...
  %pred = icmp ugt i32 %val, 3
  %phi.update = select i1 %pred, i32 %b, i32 %phi

We already detect min/max patterns, which also involve a select + cmp.
However, with the min/max patterns we are selecting loaded values (and
hence loop variant) in the loop. In addition we only support certain cmp
predicates.

This patch adds a new pattern matching function (isSelectCmpPattern) and
new RecurKind enums - SelectICmp & SelectFCmp. We only support selecting
values that are integer and loop invariant, however we can support any
kind of compare - integer or float.

Tests have been added here:

  Transforms/LoopVectorize/AArch64/sve-select-cmp.ll
  Transforms/LoopVectorize/select-cmp-predicated.ll
  Transforms/LoopVectorize/select-cmp.ll

Differential Revision: https://reviews.llvm.org/D108136
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- llvm/lib/Transforms/Utils/LoopUtils.cpp 60
1 files changed, 57 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 78d756a..7896d55 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -908,6 +908,15 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
}
}
/// Emit the compare + select pair for a select-cmp reduction step: the result
/// is \p Left where \p Left differs from \p StartVal, and \p Right otherwise.
/// \p RK is accepted but not referenced by the body shown in this hunk.
+Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal,
+ RecurKind RK, Value *Left, Value *Right) {
// When Left is a vector, broadcast StartVal to the same element count so that
// both operands of the compare below are vectors.
+ if (auto VTy = dyn_cast<VectorType>(Left->getType()))
+ StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal);
// "Left != StartVal" is the select condition.
+ Value *Cmp =
+ Builder.CreateCmp(CmpInst::ICMP_NE, Left, StartVal, "rdx.select.cmp");
+ return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
+}
+
Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
Value *Right) {
CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
@@ -988,6 +997,46 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
}
/// Reduce the vector \p Src of a select-cmp recurrence to a single value.
/// The result is the recurrence's "new" value if any lane of \p Src differs
/// from the recurrence start value taken from \p Desc, and the start value
/// otherwise. The new value is recovered from a select that uses \p OrigPhi.
/// \p TTI is accepted but not referenced by the body shown in this hunk.
+Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder,
+ const TargetTransformInfo *TTI,
+ Value *Src,
+ const RecurrenceDescriptor &Desc,
+ PHINode *OrigPhi) {
+ assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind(
+ Desc.getRecurrenceKind()) &&
+ "Unexpected reduction kind");
+ Value *InitVal = Desc.getRecurrenceStartValue();
+ Value *NewVal = nullptr;
+
+ // First use the original phi to determine the new value we're trying to
+ // select from in the loop.
// The search stops at the first SelectInst found among the phi's users.
+ SelectInst *SI = nullptr;
+ for (auto *U : OrigPhi->users()) {
+ if ((SI = dyn_cast<SelectInst>(U)))
+ break;
+ }
+ assert(SI && "One user of the original phi should be a select");
+
// The new value is whichever select operand is not the original phi.
+ if (SI->getTrueValue() == OrigPhi)
+ NewVal = SI->getFalseValue();
+ else {
+ assert(SI->getFalseValue() == OrigPhi &&
+ "At least one input to the select should be the original Phi");
+ NewVal = SI->getTrueValue();
+ }
+
+ // Create a splat vector with the start value and compare this to the vector
+ // we want to reduce.
+ ElementCount EC = cast<VectorType>(Src->getType())->getElementCount();
+ Value *Right = Builder.CreateVectorSplat(EC, InitVal);
+ Value *Cmp =
+ Builder.CreateCmp(CmpInst::ICMP_NE, Src, Right, "rdx.select.cmp");
+
+ // If any predicate is true it means that we want to select the new value.
+ Cmp = Builder.CreateOrReduce(Cmp);
+ return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
+}
+
Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
const TargetTransformInfo *TTI,
Value *Src, RecurKind RdxKind,
@@ -1028,14 +1077,19 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
/// Create the reduction of \p Src described by \p Desc. With this change,
/// select-cmp recurrence kinds are routed to createSelectCmpTargetReduction
/// (which also receives \p OrigPhi); every other kind still goes through
/// createSimpleTargetReduction.
Value *llvm::createTargetReduction(IRBuilderBase &B,
const TargetTransformInfo *TTI,
- const RecurrenceDescriptor &Desc,
- Value *Src) {
+ const RecurrenceDescriptor &Desc, Value *Src,
+ PHINode *OrigPhi) {
// TODO: Support in-order reductions based on the recurrence descriptor.
// All ops in the reduction inherit fast-math-flags from the recurrence
// descriptor.
IRBuilderBase::FastMathFlagGuard FMFGuard(B);
B.setFastMathFlags(Desc.getFastMathFlags());
- return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind());
+
// Dispatch on the recurrence kind; OrigPhi is only passed on in the
// select-cmp case.
+ RecurKind RK = Desc.getRecurrenceKind();
+ if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK))
+ return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi);
+
+ return createSimpleTargetReduction(B, TTI, Src, RK);
}
Value *llvm::createOrderedReduction(IRBuilderBase &B,