diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 60 |
1 files changed, 57 insertions, 3 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 78d756a..7896d55 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -908,6 +908,15 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { } } +Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal, + RecurKind RK, Value *Left, Value *Right) { + if (auto VTy = dyn_cast<VectorType>(Left->getType())) + StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal); + Value *Cmp = + Builder.CreateCmp(CmpInst::ICMP_NE, Left, StartVal, "rdx.select.cmp"); + return Builder.CreateSelect(Cmp, Left, Right, "rdx.select"); +} + Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, Value *Right) { CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK); @@ -988,6 +997,46 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); } +Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder, + const TargetTransformInfo *TTI, + Value *Src, + const RecurrenceDescriptor &Desc, + PHINode *OrigPhi) { + assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind( + Desc.getRecurrenceKind()) && + "Unexpected reduction kind"); + Value *InitVal = Desc.getRecurrenceStartValue(); + Value *NewVal = nullptr; + + // First use the original phi to determine the new value we're trying to + // select from in the loop. + SelectInst *SI = nullptr; + for (auto *U : OrigPhi->users()) { + if ((SI = dyn_cast<SelectInst>(U))) + break; + } + assert(SI && "One user of the original phi should be a select"); + + if (SI->getTrueValue() == OrigPhi) + NewVal = SI->getFalseValue(); + else { + assert(SI->getFalseValue() == OrigPhi && + "At least one input to the select should be the original Phi"); + NewVal = SI->getTrueValue(); + } + + // Create a splat vector with the new value and compare this to the vector + // we want to reduce. + ElementCount EC = cast<VectorType>(Src->getType())->getElementCount(); + Value *Right = Builder.CreateVectorSplat(EC, InitVal); + Value *Cmp = + Builder.CreateCmp(CmpInst::ICMP_NE, Src, Right, "rdx.select.cmp"); + + // If any predicate is true it means that we want to select the new value. + Cmp = Builder.CreateOrReduce(Cmp); + return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select"); +} + Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, const TargetTransformInfo *TTI, Value *Src, RecurKind RdxKind, @@ -1028,14 +1077,19 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *llvm::createTargetReduction(IRBuilderBase &B, const TargetTransformInfo *TTI, - const RecurrenceDescriptor &Desc, - Value *Src) { + const RecurrenceDescriptor &Desc, Value *Src, + PHINode *OrigPhi) { // TODO: Support in-order reductions based on the recurrence descriptor. // All ops in the reduction inherit fast-math-flags from the recurrence // descriptor. IRBuilderBase::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(Desc.getFastMathFlags()); - return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind()); + + RecurKind RK = Desc.getRecurrenceKind(); + if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) + return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi); + + return createSimpleTargetReduction(B, TTI, Src, RK); } Value *llvm::createOrderedReduction(IRBuilderBase &B, |