aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp60
1 files changed, 57 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 78d756a..7896d55 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -908,6 +908,15 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
}
}
+Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal,
+ RecurKind RK, Value *Left, Value *Right) {
+ if (auto VTy = dyn_cast<VectorType>(Left->getType()))
+ StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal);
+ Value *Cmp =
+ Builder.CreateCmp(CmpInst::ICMP_NE, Left, StartVal, "rdx.select.cmp");
+ return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
+}
+
Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
Value *Right) {
CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
@@ -988,6 +997,46 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
}
+Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder,
+ const TargetTransformInfo *TTI,
+ Value *Src,
+ const RecurrenceDescriptor &Desc,
+ PHINode *OrigPhi) {
+ assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind(
+ Desc.getRecurrenceKind()) &&
+ "Unexpected reduction kind");
+ Value *InitVal = Desc.getRecurrenceStartValue();
+ Value *NewVal = nullptr;
+
+ // First use the original phi to determine the new value we're trying to
+ // select from in the loop.
+ SelectInst *SI = nullptr;
+ for (auto *U : OrigPhi->users()) {
+ if ((SI = dyn_cast<SelectInst>(U)))
+ break;
+ }
+ assert(SI && "One user of the original phi should be a select");
+
+ if (SI->getTrueValue() == OrigPhi)
+ NewVal = SI->getFalseValue();
+ else {
+ assert(SI->getFalseValue() == OrigPhi &&
+ "At least one input to the select should be the original Phi");
+ NewVal = SI->getTrueValue();
+ }
+
+ // Create a splat vector with the new value and compare this to the vector
+ // we want to reduce.
+ ElementCount EC = cast<VectorType>(Src->getType())->getElementCount();
+ Value *Right = Builder.CreateVectorSplat(EC, InitVal);
+ Value *Cmp =
+ Builder.CreateCmp(CmpInst::ICMP_NE, Src, Right, "rdx.select.cmp");
+
+ // If any predicate is true it means that we want to select the new value.
+ Cmp = Builder.CreateOrReduce(Cmp);
+ return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
+}
+
Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
const TargetTransformInfo *TTI,
Value *Src, RecurKind RdxKind,
@@ -1028,14 +1077,19 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
Value *llvm::createTargetReduction(IRBuilderBase &B,
const TargetTransformInfo *TTI,
- const RecurrenceDescriptor &Desc,
- Value *Src) {
+ const RecurrenceDescriptor &Desc, Value *Src,
+ PHINode *OrigPhi) {
// TODO: Support in-order reductions based on the recurrence descriptor.
// All ops in the reduction inherit fast-math-flags from the recurrence
// descriptor.
IRBuilderBase::FastMathFlagGuard FMFGuard(B);
B.setFastMathFlags(Desc.getFastMathFlags());
- return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind());
+
+ RecurKind RK = Desc.getRecurrenceKind();
+ if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK))
+ return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi);
+
+ return createSimpleTargetReduction(B, TTI, Src, RK);
}
Value *llvm::createOrderedReduction(IRBuilderBase &B,