diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 55 |
1 files changed, 23 insertions, 32 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 3245f5f..f2b94d9 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -979,9 +979,9 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, const TargetTransformInfo *TTI, - unsigned Opcode, Value *Src, - RecurKind RdxKind, + Value *Src, RecurKind RdxKind, ArrayRef<Value *> RedOps) { + unsigned Opcode = RecurrenceDescriptor::getOpcode(RdxKind); TargetTransformInfo::ReductionFlags RdxFlags; RdxFlags.IsMaxOp = RdxKind == RecurKind::SMax || RdxKind == RecurKind::UMax || RdxKind == RecurKind::FMax; @@ -991,42 +991,34 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, return getShuffleReduction(Builder, Src, Opcode, RdxKind, RedOps); auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); - switch (Opcode) { - case Instruction::Add: + switch (RdxKind) { + case RecurKind::Add: return Builder.CreateAddReduce(Src); - case Instruction::Mul: + case RecurKind::Mul: return Builder.CreateMulReduce(Src); - case Instruction::And: + case RecurKind::And: return Builder.CreateAndReduce(Src); - case Instruction::Or: + case RecurKind::Or: return Builder.CreateOrReduce(Src); - case Instruction::Xor: + case RecurKind::Xor: return Builder.CreateXorReduce(Src); - case Instruction::FAdd: + case RecurKind::FAdd: return Builder.CreateFAddReduce(ConstantFP::getNegativeZero(SrcVecEltTy), Src); - case Instruction::FMul: + case RecurKind::FMul: return Builder.CreateFMulReduce(ConstantFP::get(SrcVecEltTy, 1.0), Src); - case Instruction::ICmp: - switch (RdxKind) { - case RecurKind::SMax: - return Builder.CreateIntMaxReduce(Src, true); - case RecurKind::SMin: - return Builder.CreateIntMinReduce(Src, true); - case RecurKind::UMax: - return Builder.CreateIntMaxReduce(Src, false); - case RecurKind::UMin: - return Builder.CreateIntMinReduce(Src, false); - default: - llvm_unreachable("Unexpected min/max reduction type"); - } - case Instruction::FCmp: - assert((RdxKind == RecurKind::FMax || RdxKind == RecurKind::FMin) && - "Unexpected min/max reduction type"); - if (RdxKind == RecurKind::FMax) - return Builder.CreateFPMaxReduce(Src); - else - return Builder.CreateFPMinReduce(Src); + case RecurKind::SMax: + return Builder.CreateIntMaxReduce(Src, true); + case RecurKind::SMin: + return Builder.CreateIntMinReduce(Src, true); + case RecurKind::UMax: + return Builder.CreateIntMaxReduce(Src, false); + case RecurKind::UMin: + return Builder.CreateIntMinReduce(Src, false); + case RecurKind::FMax: + return Builder.CreateFPMaxReduce(Src); + case RecurKind::FMin: + return Builder.CreateFPMinReduce(Src); default: llvm_unreachable("Unhandled opcode"); } @@ -1040,8 +1032,7 @@ Value *llvm::createTargetReduction(IRBuilderBase &B, // descriptor. IRBuilderBase::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(Desc.getFastMathFlags()); - return createSimpleTargetReduction(B, TTI, Desc.getRecurrenceBinOp(), Src, - Desc.getRecurrenceKind()); + return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind()); } void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) { |