aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
authorPhilip Reames <preames@rivosinc.com>2024-08-30 17:13:51 -0700
committerPhilip Reames <listmail@philipreames.com>2024-09-04 08:23:21 -0700
commit3d9abfc9f841b13825e3d03cfba272f5eeab9a3b (patch)
treee8b7f6074af959d0f2591be744f39d63bbccf89f /llvm/lib/Transforms/Utils/LoopUtils.cpp
parentc1a8283fcc735b1567c49bb6cd485f9e71a12cc4 (diff)
downloadllvm-3d9abfc9f841b13825e3d03cfba272f5eeab9a3b.zip
llvm-3d9abfc9f841b13825e3d03cfba272f5eeab9a3b.tar.gz
llvm-3d9abfc9f841b13825e3d03cfba272f5eeab9a3b.tar.bz2
Consolidate all IR logic for getting the identity value of a reduction [nfc]
This change merges the three different places (at the IR layer) for finding the identity value of a reduction into a single copy. This depends on several prior commits which fix omissions and bugs in the distinct copies, but this patch itself should be a purely non-functional change (NFC). As the new comments and naming try to make clear, the identity value is a property of the @llvm.vector.reduce.* intrinsic, not of e.g. the recurrence descriptor. (We still provide an interface for clients using recurrence descriptors, but the implementation simply translates each one to the intrinsic it corresponds to.) As a note, the getIntrinsicIdentity API does not support fminnum/fmaxnum or fminimum/fmaximum, which is why we still need manual logic (but at least only one copy of manual logic) for those cases.
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp59
1 files changed, 53 insertions, 6 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 5591294..9a4289e 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -1207,14 +1207,62 @@ Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateSelect(AnyOf, NewVal, InitVal, "rdx.select");
}
+Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty, // Returns the identity value of type Ty for the vector-reduce intrinsic RdxID.
+ FastMathFlags Flags) { // Flags: nsz, nnan and ninf are consulted below to pick the FP identity.
+ bool Negative = false; // Sign requested for the FP min/max identity; only set for fmax/fmaximum.
+ switch (RdxID) {
+ default:
+ llvm_unreachable("Expecting a reduction intrinsic"); // Any ID not listed below is treated as a caller bug.
+ case Intrinsic::vector_reduce_add: // Arithmetic/bitwise reductions: defer to the identity of the matching binary opcode.
+ case Intrinsic::vector_reduce_mul:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_xor:
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_fadd:
+ case Intrinsic::vector_reduce_fmul: {
+ unsigned Opc = getArithmeticReductionInstruction(RdxID); // Map the reduction intrinsic to its instruction opcode.
+ return ConstantExpr::getBinOpIdentity(Opc, Ty, false,
+ Flags.noSignedZeros()); // The nsz flag is forwarded to getBinOpIdentity.
+ }
+ case Intrinsic::vector_reduce_umax: // Integer min/max reductions: defer to the identity of the scalar min/max intrinsic.
+ case Intrinsic::vector_reduce_umin:
+ case Intrinsic::vector_reduce_smin:
+ case Intrinsic::vector_reduce_smax: {
+ Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID); // Map the reduction to its scalar min/max intrinsic ID.
+ return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty);
+ }
+ case Intrinsic::vector_reduce_fmax: // FP min/max is handled manually here rather than through getIntrinsicIdentity.
+ case Intrinsic::vector_reduce_fmaximum:
+ Negative = true; // Max-style reductions request the negative-signed constant.
+ [[fallthrough]]; // The min and max variants share the selection logic below.
+ case Intrinsic::vector_reduce_fmin:
+ case Intrinsic::vector_reduce_fminimum: {
+ bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum || // True only for the fminimum/fmaximum reductions.
+ RdxID == Intrinsic::vector_reduce_fmaximum;
+ const fltSemantics &Semantics = Ty->getFltSemantics(); // Only used by the getLargest fallback below.
+ return (!Flags.noNaNs() && !PropagatesNaN) // 1) nnan not set and not a PropagatesNaN reduction: quiet NaN.
+ ? ConstantFP::getQNaN(Ty, Negative)
+ : !Flags.noInfs() // 2) Otherwise, if ninf is not set: infinity with the requested sign.
+ ? ConstantFP::getInfinity(Ty, Negative)
+ : ConstantFP::get(Ty, APFloat::getLargest(Semantics, Negative)); // 3) Otherwise: APFloat::getLargest with the requested sign.
+ }
+ }
+}
+
+Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) { // Identity for recurrence kind K, obtained via its reduction intrinsic.
+ assert((!(K == RecurKind::FMin || K == RecurKind::FMax) || // Precondition: FMin/FMax kinds must carry both nnan and nsz in FMF.
+ (FMF.noNaNs() && FMF.noSignedZeros())) &&
+ "nnan, nsz is expected to be set for FP min/max reduction.");
+ Intrinsic::ID RdxID = getReductionIntrinsicID(K); // Translate the recurrence kind to its reduction intrinsic ID.
+ return getReductionIdentity(RdxID, Tp, FMF); // Delegate; Tp and FMF are passed through unchanged.
+}
+
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
RecurKind RdxKind) {
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
auto getIdentity = [&]() {
- Intrinsic::ID ID = getReductionIntrinsicID(RdxKind);
- unsigned Opc = getArithmeticReductionInstruction(ID);
- bool NSZ = Builder.getFastMathFlags().noSignedZeros();
- return ConstantExpr::getBinOpIdentity(Opc, SrcVecEltTy, false, NSZ);
+ return getRecurrenceIdentity(RdxKind, SrcVecEltTy,
+ Builder.getFastMathFlags());
};
switch (RdxKind) {
case RecurKind::Add:
@@ -1249,8 +1297,7 @@ Value *llvm::createSimpleReduction(VectorBuilder &VBuilder, Value *Src,
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
- Value *Iden =
- Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
+ Value *Iden = getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
Value *Ops[] = {Iden, Src};
return VBuilder.createSimpleReduction(Id, SrcTy, Ops);
}