diff options
Diffstat (limited to 'llvm/lib/CodeGen/SelectionDAG')
7 files changed, 75 insertions, 25 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 10daca5..f144f17 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: return visitPARTIAL_REDUCE_MLA(N); case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N); case ISD::LIFETIME_END: return visitLIFETIME_END(N); @@ -13006,6 +13007,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) { // // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) +// +// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0)) +// -> partial_reduce_fmla(acc, a, b) SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDLoc DL(N); auto *Context = DAG.getContext(); @@ -13014,7 +13018,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { SDValue Op2 = N->getOperand(2); unsigned Opc = Op1->getOpcode(); - if (Opc != ISD::MUL && Opc != ISD::SHL) + if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL) return SDValue(); SDValue LHS = Op1->getOperand(0); @@ -13033,13 +13037,16 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { Opc = ISD::MUL; } - APInt C; - if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) || - !C.isOne()) + if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) && + !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2))) return SDValue(); + auto IsIntOrFPExtOpcode = [](unsigned int Opcode) { + return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND); + }; + unsigned LHSOpcode = LHS->getOpcode(); - if (!ISD::isExtOpcode(LHSOpcode)) + if (!IsIntOrFPExtOpcode(LHSOpcode)) return SDValue(); SDValue LHSExtOp = LHS->getOperand(0); @@ -13047,6 +13054,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1)) // -> partial_reduce_*mla(acc, x, C) + APInt C; if (ISD::isConstantSplatVector(RHS.getNode(), C)) { // TODO: Make use of partial_reduce_sumla here APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits()); @@ -13071,7 +13079,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { } unsigned RHSOpcode = RHS->getOpcode(); - if (!ISD::isExtOpcode(RHSOpcode)) + if (!IsIntOrFPExtOpcode(RHSOpcode)) return SDValue(); SDValue RHSExtOp = RHS->getOperand(0); @@ -13088,6 +13096,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) { NewOpc = ISD::PARTIAL_REDUCE_SUMLA; std::swap(LHSExtOp, RHSExtOp); + } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) { + NewOpc = ISD::PARTIAL_REDUCE_FMLA; } else return SDValue(); // For a 2-stage extend the signedness of both of the extends must match @@ -13115,30 +13125,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) { // -> partial.reduce.smla(acc, op, splat(trunc(1))) // partial.reduce.sumla(acc, sext(op), splat(1)) // -> partial.reduce.smla(acc, op, splat(trunc(1))) +// partial.reduce.fmla(acc, fpext(op), splat(1.0)) +// -> partial.reduce.fmla(acc, op, splat(1.0)) SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { SDLoc DL(N); SDValue Acc = N->getOperand(0); SDValue Op1 = N->getOperand(1); SDValue Op2 = N->getOperand(2); - APInt ConstantOne; - if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) || - !ConstantOne.isOne()) + if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2)) return SDValue(); unsigned Op1Opcode = Op1.getOpcode(); - if (!ISD::isExtOpcode(Op1Opcode)) + if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND) return SDValue(); - bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND; + bool Op1IsSigned = + Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND; bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA; EVT AccElemVT = Acc.getValueType().getVectorElementType(); if (Op1IsSigned != NodeIsSigned && Op1.getValueType().getVectorElementType() != AccElemVT) return SDValue(); - unsigned NewOpcode = - Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA; + unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA + ? ISD::PARTIAL_REDUCE_FMLA + : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA + : ISD::PARTIAL_REDUCE_UMLA; SDValue UnextOp1 = Op1.getOperand(0); EVT UnextOp1VT = UnextOp1.getValueType(); @@ -13148,8 +13161,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) { TLI.getTypeToTransformTo(*Context, UnextOp1VT))) return SDValue(); + SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA + ? DAG.getConstantFP(1, DL, UnextOp1VT) + : DAG.getConstant(1, DL, UnextOp1VT); + return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1, - DAG.getConstant(1, DL, UnextOp1VT)); + Constant); } SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) { diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp index 8e423c4..94751be5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: Action = TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0), Node->getOperand(1).getValueType()); @@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: Results.push_back(TLI.expandPartialReduceMLA(Node, DAG)); return; case ISD::VECREDUCE_SEQ_FADD: diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index bb4a8d9..dd5c011 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1474,6 +1474,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi); break; case ISD::GET_ACTIVE_LANE_MASK: @@ -3689,6 +3690,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) { case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: Res = SplitVecOp_PARTIAL_REDUCE_MLA(N); break; } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 0a06752..bbc1d73 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -8404,7 +8404,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT, } case ISD::PARTIAL_REDUCE_UMLA: case ISD::PARTIAL_REDUCE_SMLA: - case ISD::PARTIAL_REDUCE_SUMLA: { + case ISD::PARTIAL_REDUCE_SUMLA: + case ISD::PARTIAL_REDUCE_FMLA: { [[maybe_unused]] EVT AccVT = N1.getValueType(); [[maybe_unused]] EVT Input1VT = N2.getValueType(); [[maybe_unused]] EVT Input2VT = N3.getValueType(); @@ -13064,6 +13065,11 @@ bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) { return C && C->isOne(); } +bool llvm::isOneOrOneSplatFP(SDValue N, bool AllowUndefs) { + ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs); + return C && C->isExactlyValue(1.0); +} + bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) { N = peekThroughBitcasts(N); unsigned BitWidth = N.getScalarValueSizeInBits(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 2f598b2..88b0809 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -8187,6 +8187,14 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, Input, DAG.getConstant(1, sdl, Input.getValueType()))); return; } + case Intrinsic::vector_partial_reduce_fadd: { + SDValue Acc = getValue(I.getOperand(0)); + SDValue Input = getValue(I.getOperand(1)); + setValue(&I, DAG.getNode( + ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc, + Input, DAG.getConstantFP(1.0, sdl, Input.getValueType()))); + return; + } case Intrinsic::experimental_cttz_elts: { auto DL = getCurSDLoc(); SDValue Op = getValue(I.getOperand(0)); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index d3e1628..ec5edd5 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -590,6 +590,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const { return "partial_reduce_smla"; case ISD::PARTIAL_REDUCE_SUMLA: return "partial_reduce_sumla"; + case ISD::PARTIAL_REDUCE_FMLA: + return "partial_reduce_fmla"; case ISD::LOOP_DEPENDENCE_WAR_MASK: return "loop_dep_war"; case ISD::LOOP_DEPENDENCE_RAW_MASK: diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index 9bdf822..b51d664 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -12074,22 +12074,32 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N, EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(), MulOpVT.getVectorElementCount()); - unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA - ? ISD::ZERO_EXTEND - : ISD::SIGN_EXTEND; - unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA - ? ISD::SIGN_EXTEND - : ISD::ZERO_EXTEND; + unsigned ExtOpcLHS, ExtOpcRHS; + switch (N->getOpcode()) { + default: + llvm_unreachable("Unexpected opcode"); + case ISD::PARTIAL_REDUCE_UMLA: + ExtOpcLHS = ExtOpcRHS = ISD::ZERO_EXTEND; + break; + case ISD::PARTIAL_REDUCE_SMLA: + ExtOpcLHS = ExtOpcRHS = ISD::SIGN_EXTEND; + break; + case ISD::PARTIAL_REDUCE_FMLA: + ExtOpcLHS = ExtOpcRHS = ISD::FP_EXTEND; + break; + } if (ExtMulOpVT != MulOpVT) { MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS); MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS); } SDValue Input = MulLHS; - APInt ConstantOne; - if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) || - !ConstantOne.isOne()) + if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) { + if (!llvm::isOneOrOneSplatFP(MulRHS)) + Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS); + } else if (!llvm::isOneOrOneSplat(MulRHS)) { Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS); + } unsigned Stride = AccVT.getVectorMinNumElements(); unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride; @@ -12099,10 +12109,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N, for (unsigned I = 0; I < ScaleFactor; I++) Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride)); + unsigned FlatNode = + N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD; + // Flatten the subvector tree while (Subvectors.size() > 1) { Subvectors.push_back( - DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]})); + DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]})); Subvectors.pop_front(); Subvectors.pop_front(); } |
