diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 59 |
1 files changed, 53 insertions, 6 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 5591294..9a4289e 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1207,14 +1207,62 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src, return Builder.CreateSelect(AnyOf, NewVal, InitVal, "rdx.select"); } +Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty, + FastMathFlags Flags) { + bool Negative = false; + switch (RdxID) { + default: + llvm_unreachable("Expecting a reduction intrinsic"); + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_fadd: + case Intrinsic::vector_reduce_fmul: { + unsigned Opc = getArithmeticReductionInstruction(RdxID); + return ConstantExpr::getBinOpIdentity(Opc, Ty, false, + Flags.noSignedZeros()); + } + case Intrinsic::vector_reduce_umax: + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: { + Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID); + return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty); + } + case Intrinsic::vector_reduce_fmax: + case Intrinsic::vector_reduce_fmaximum: + Negative = true; + [[fallthrough]]; + case Intrinsic::vector_reduce_fmin: + case Intrinsic::vector_reduce_fminimum: { + bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum || + RdxID == Intrinsic::vector_reduce_fmaximum; + const fltSemantics &Semantics = Ty->getFltSemantics(); + return (!Flags.noNaNs() && !PropagatesNaN) + ? ConstantFP::getQNaN(Ty, Negative) + : !Flags.noInfs() + ? ConstantFP::getInfinity(Ty, Negative) + : ConstantFP::get(Ty, APFloat::getLargest(Semantics, Negative)); + } + } +} + +Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) { + assert((!(K == RecurKind::FMin || K == RecurKind::FMax) || + (FMF.noNaNs() && FMF.noSignedZeros())) && + "nnan, nsz is expected to be set for FP min/max reduction."); + Intrinsic::ID RdxID = getReductionIntrinsicID(K); + return getReductionIdentity(RdxID, Tp, FMF); +} + Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src, RecurKind RdxKind) { auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); auto getIdentity = [&]() { - Intrinsic::ID ID = getReductionIntrinsicID(RdxKind); - unsigned Opc = getArithmeticReductionInstruction(ID); - bool NSZ = Builder.getFastMathFlags().noSignedZeros(); - return ConstantExpr::getBinOpIdentity(Opc, SrcVecEltTy, false, NSZ); + return getRecurrenceIdentity(RdxKind, SrcVecEltTy, + Builder.getFastMathFlags()); }; switch (RdxKind) { case RecurKind::Add: @@ -1249,8 +1297,7 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src, Intrinsic::ID Id = getReductionIntrinsicID(Kind); auto *SrcTy = cast<VectorType>(Src->getType()); Type *SrcEltTy = SrcTy->getElementType(); - Value *Iden = - Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags()); + Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags()); Value *Ops[] = {Iden, Src}; return VBuilder.createSimpleReduction(Id, SrcTy, Ops); } |