diff options
author | Philip Reames <preames@rivosinc.com> | 2024-08-30 17:13:51 -0700 |
---|---|---|
committer | Philip Reames <listmail@philipreames.com> | 2024-09-04 08:23:21 -0700 |
commit | 3d9abfc9f841b13825e3d03cfba272f5eeab9a3b (patch) | |
tree | e8b7f6074af959d0f2591be744f39d63bbccf89f /llvm/lib/Transforms/Utils/LoopUtils.cpp | |
parent | c1a8283fcc735b1567c49bb6cd485f9e71a12cc4 (diff) | |
download | llvm-3d9abfc9f841b13825e3d03cfba272f5eeab9a3b.zip llvm-3d9abfc9f841b13825e3d03cfba272f5eeab9a3b.tar.gz llvm-3d9abfc9f841b13825e3d03cfba272f5eeab9a3b.tar.bz2 |
Consolidate all IR logic for getting the identity value of a reduction [nfc]
This change merges the three different places (at the IR layer) for
finding the identity value of a reduction into a single copy. This
depends on several prior commits which fix omissions and bugs in
the distinct copies, but this patch itself should be fully
non-functional.
As the new comments and naming try to make clear, the identity value
is a property of the @llvm.vector.reduce.* intrinsic, not of e.g.
the recurrence descriptor. (We still provide an interface for
clients using recurrence descriptors, but the implementation simply
translates each recurrence kind to its corresponding intrinsic.)
As a note, the getIntrinsicIdentity API does not support fminnum/fmaxnum
or fminimum/fmaximum, which is why we still need manual logic (but at
least only one copy of it) for those cases.
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 59 |
1 files changed, 53 insertions, 6 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 5591294..9a4289e 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -1207,14 +1207,62 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src, return Builder.CreateSelect(AnyOf, NewVal, InitVal, "rdx.select"); } +Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty, + FastMathFlags Flags) { + bool Negative = false; + switch (RdxID) { + default: + llvm_unreachable("Expecting a reduction intrinsic"); + case Intrinsic::vector_reduce_add: + case Intrinsic::vector_reduce_mul: + case Intrinsic::vector_reduce_or: + case Intrinsic::vector_reduce_xor: + case Intrinsic::vector_reduce_and: + case Intrinsic::vector_reduce_fadd: + case Intrinsic::vector_reduce_fmul: { + unsigned Opc = getArithmeticReductionInstruction(RdxID); + return ConstantExpr::getBinOpIdentity(Opc, Ty, false, + Flags.noSignedZeros()); + } + case Intrinsic::vector_reduce_umax: + case Intrinsic::vector_reduce_umin: + case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: { + Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID); + return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty); + } + case Intrinsic::vector_reduce_fmax: + case Intrinsic::vector_reduce_fmaximum: + Negative = true; + [[fallthrough]]; + case Intrinsic::vector_reduce_fmin: + case Intrinsic::vector_reduce_fminimum: { + bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum || + RdxID == Intrinsic::vector_reduce_fmaximum; + const fltSemantics &Semantics = Ty->getFltSemantics(); + return (!Flags.noNaNs() && !PropagatesNaN) + ? ConstantFP::getQNaN(Ty, Negative) + : !Flags.noInfs() + ? 
ConstantFP::getInfinity(Ty, Negative) + : ConstantFP::get(Ty, APFloat::getLargest(Semantics, Negative)); + } + } +} + +Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) { + assert((!(K == RecurKind::FMin || K == RecurKind::FMax) || + (FMF.noNaNs() && FMF.noSignedZeros())) && + "nnan, nsz is expected to be set for FP min/max reduction."); + Intrinsic::ID RdxID = getReductionIntrinsicID(K); + return getReductionIdentity(RdxID, Tp, FMF); +} + Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src, RecurKind RdxKind) { auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); auto getIdentity = [&]() { - Intrinsic::ID ID = getReductionIntrinsicID(RdxKind); - unsigned Opc = getArithmeticReductionInstruction(ID); - bool NSZ = Builder.getFastMathFlags().noSignedZeros(); - return ConstantExpr::getBinOpIdentity(Opc, SrcVecEltTy, false, NSZ); + return getRecurrenceIdentity(RdxKind, SrcVecEltTy, + Builder.getFastMathFlags()); }; switch (RdxKind) { case RecurKind::Add: @@ -1249,8 +1297,7 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src, Intrinsic::ID Id = getReductionIntrinsicID(Kind); auto *SrcTy = cast<VectorType>(Src->getType()); Type *SrcEltTy = SrcTy->getElementType(); - Value *Iden = - Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags()); + Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags()); Value *Ops[] = {Iden, Src}; return VBuilder.createSimpleReduction(Id, SrcTy, Ops); } |