diff options
| author | Sam Parker <sam.parker@arm.com> | 2025-10-13 16:50:53 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-10-13 16:50:53 +0100 |
| commit | 1820102167a5ace14a5c1f79d11d5eb4cce93001 (patch) | |
| tree | 4e5f2e07f3cb292261291a1714ce2f2a2e34f500 /llvm/lib | |
| parent | 095cad6add16df3f6273f5b24293e48a08e3230e (diff) | |
| download | llvm-1820102167a5ace14a5c1f79d11d5eb4cce93001.zip llvm-1820102167a5ace14a5c1f79d11d5eb4cce93001.tar.gz llvm-1820102167a5ace14a5c1f79d11d5eb4cce93001.tar.bz2 | |
Wasm fmuladd relaxed (#163177)
Reland #161355, after fixing up the cross-projects-tests for the wasm
simd intrinsics.
Original commit message:
Lower v4f32 and v2f64 fmuladd calls to relaxed_madd instructions.
If we have FP16, then lower v8f16 fmuladds to FMA.
I've introduced an ISD node for fmuladd to maintain the rounding
ambiguity through legalization / combine / isel.
Diffstat (limited to 'llvm/lib')
9 files changed, 94 insertions, 11 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index b1accdd..e153842 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -509,6 +509,7 @@ namespace { SDValue visitFMUL(SDNode *N); template <class MatchContextClass> SDValue visitFMA(SDNode *N); SDValue visitFMAD(SDNode *N); + SDValue visitFMULADD(SDNode *N); SDValue visitFDIV(SDNode *N); SDValue visitFREM(SDNode *N); SDValue visitFSQRT(SDNode *N); @@ -1991,6 +1992,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::FMUL: return visitFMUL(N); case ISD::FMA: return visitFMA<EmptyMatchContext>(N); case ISD::FMAD: return visitFMAD(N); + case ISD::FMULADD: return visitFMULADD(N); case ISD::FDIV: return visitFDIV(N); case ISD::FREM: return visitFREM(N); case ISD::FSQRT: return visitFSQRT(N); @@ -18444,6 +18446,21 @@ SDValue DAGCombiner::visitFMAD(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFMULADD(SDNode *N) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + SDValue N2 = N->getOperand(2); + EVT VT = N->getValueType(0); + SDLoc DL(N); + + // Constant fold FMULADD. + if (SDValue C = + DAG.FoldConstantArithmetic(ISD::FMULADD, DL, VT, {N0, N1, N2})) + return C; + + return SDValue(); +} + // Combine multiple FDIVs with the same divisor into multiple FMULs by the // reciprocal. // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip) diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 2b8dd60..4512c5c 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -5786,6 +5786,7 @@ bool SelectionDAG::canCreateUndefOrPoison(SDValue Op, const APInt &DemandedElts, case ISD::FCOPYSIGN: case ISD::FMA: case ISD::FMAD: + case ISD::FMULADD: case ISD::FP_EXTEND: case ISD::FP_TO_SINT_SAT: case ISD::FP_TO_UINT_SAT: @@ -5904,6 +5905,7 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, const APInt &DemandedElts, case ISD::FCOSH: case ISD::FTANH: case ISD::FMA: + case ISD::FMULADD: case ISD::FMAD: { if (SNaN) return true; @@ -7231,7 +7233,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL, } // Handle fma/fmad special cases. - if (Opcode == ISD::FMA || Opcode == ISD::FMAD) { + if (Opcode == ISD::FMA || Opcode == ISD::FMAD || Opcode == ISD::FMULADD) { assert(VT.isFloatingPoint() && "This operator only applies to FP types!"); assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT && Ops[2].getValueType() == VT && "FMA types must match!"); @@ -7242,7 +7244,7 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL, APFloat V1 = C1->getValueAPF(); const APFloat &V2 = C2->getValueAPF(); const APFloat &V3 = C3->getValueAPF(); - if (Opcode == ISD::FMAD) { + if (Opcode == ISD::FMAD || Opcode == ISD::FMULADD) { V1.multiply(V2, APFloat::rmNearestTiesToEven); V1.add(V3, APFloat::rmNearestTiesToEven); } else diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index c21890a..0f2b518 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -6996,6 +6996,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, getValue(I.getArgOperand(0)), getValue(I.getArgOperand(1)), getValue(I.getArgOperand(2)), Flags)); + } else if (TLI.isOperationLegalOrCustom(ISD::FMULADD, VT)) { + // TODO: Support splitting the vector. + setValue(&I, DAG.getNode(ISD::FMULADD, sdl, + getValue(I.getArgOperand(0)).getValueType(), + getValue(I.getArgOperand(0)), + getValue(I.getArgOperand(1)), + getValue(I.getArgOperand(2)), Flags)); } else { // TODO: Intrinsic calls should have fast-math-flags. SDValue Mul = DAG.getNode( diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index fcfbfe6..39cbfad 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -310,6 +310,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const { case ISD::FMA: return "fma"; case ISD::STRICT_FMA: return "strict_fma"; case ISD::FMAD: return "fmad"; + case ISD::FMULADD: return "fmuladd"; case ISD::FREM: return "frem"; case ISD::STRICT_FREM: return "strict_frem"; case ISD::FCOPYSIGN: return "fcopysign"; diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index cc503d3..920dff9 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -7676,6 +7676,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, break; } case ISD::FMA: + case ISD::FMULADD: case ISD::FMAD: { if (!Flags.hasNoSignedZeros()) break; diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index c23281a..060b1dd 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -815,7 +815,8 @@ void TargetLoweringBase::initActions() { ISD::FTAN, ISD::FACOS, ISD::FASIN, ISD::FATAN, ISD::FCOSH, ISD::FSINH, - ISD::FTANH, ISD::FATAN2}, + ISD::FTANH, ISD::FATAN2, + ISD::FMULADD}, VT, Expand); // Overflow operations default to expand diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp index 6472334..47c24fc 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -317,6 +317,15 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering( setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom); } + if (Subtarget->hasFP16()) { + setOperationAction(ISD::FMA, MVT::v8f16, Legal); + } + + if (Subtarget->hasRelaxedSIMD()) { + setOperationAction(ISD::FMULADD, MVT::v4f32, Legal); + setOperationAction(ISD::FMULADD, MVT::v2f64, Legal); + } + // Partial MLA reductions. for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) { setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal); @@ -1120,6 +1129,18 @@ WebAssemblyTargetLowering::getPreferredVectorAction(MVT VT) const { return TargetLoweringBase::getPreferredVectorAction(VT); } +bool WebAssemblyTargetLowering::isFMAFasterThanFMulAndFAdd( + const MachineFunction &MF, EVT VT) const { + if (!Subtarget->hasFP16() || !VT.isVector()) + return false; + + EVT ScalarVT = VT.getScalarType(); + if (!ScalarVT.isSimple()) + return false; + + return ScalarVT.getSimpleVT().SimpleTy == MVT::f16; +} + bool WebAssemblyTargetLowering::shouldSimplifyDemandedVectorElts( SDValue Op, const TargetLoweringOpt &TLO) const { // ISel process runs DAGCombiner after legalization; this step is called diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h index b33a853..472ec67 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h @@ -81,6 +81,8 @@ private: TargetLoweringBase::LegalizeTypeAction getPreferredVectorAction(MVT VT) const override; + bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF, + EVT VT) const override; SDValue LowerCall(CallLoweringInfo &CLI, SmallVectorImpl<SDValue> &InVals) const override; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index 49af78b..fe6fa07 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1626,7 +1626,8 @@ defm "" : RelaxedConvert<I32x4, F64x2, int_wasm_relaxed_trunc_unsigned_zero, // Relaxed (Negative) Multiply-Add (madd/nmadd) //===----------------------------------------------------------------------===// -multiclass SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, list<Predicate> reqs> { +multiclass RELAXED_SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, + list<Predicate> reqs> { defm MADD_#vec : SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins), [(set (vec.vt V128:$dst), (int_wasm_relaxed_madd @@ -1640,16 +1641,46 @@ multiclass SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, list<Predicate> vec.prefix#".relaxed_nmadd\t$dst, $a, $b, $c", vec.prefix#".relaxed_nmadd", simdopS, reqs>; - def : Pat<(fadd_contract (vec.vt V128:$a), (fmul_contract (vec.vt V128:$b), (vec.vt V128:$c))), - (!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<[HasRelaxedSIMD]>; + def : Pat<(fadd_contract (fmul_contract (vec.vt V128:$a), (vec.vt V128:$b)), (vec.vt V128:$c)), + (!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>; + def : Pat<(fmuladd (vec.vt V128:$a), (vec.vt V128:$b), (vec.vt V128:$c)), + (!cast<Instruction>("MADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>; - def : Pat<(fsub_contract (vec.vt V128:$a), (fmul_contract (vec.vt V128:$b), (vec.vt V128:$c))), - (!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<[HasRelaxedSIMD]>; + def : Pat<(fsub_contract (vec.vt V128:$c), (fmul_contract (vec.vt V128:$a), (vec.vt V128:$b))), + (!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>; + def : Pat<(fmuladd (fneg (vec.vt V128:$a)), (vec.vt V128:$b), (vec.vt V128:$c)), + (!cast<Instruction>("NMADD_"#vec) V128:$a, V128:$b, V128:$c)>, Requires<reqs>; } -defm "" : SIMDMADD<F32x4, 0x105, 0x106, [HasRelaxedSIMD]>; -defm "" : SIMDMADD<F64x2, 0x107, 0x108, [HasRelaxedSIMD]>; -defm "" : SIMDMADD<F16x8, 0x14e, 0x14f, [HasFP16]>; +defm "" : RELAXED_SIMDMADD<F32x4, 0x105, 0x106, [HasRelaxedSIMD]>; +defm "" : RELAXED_SIMDMADD<F64x2, 0x107, 0x108, [HasRelaxedSIMD]>; + +//===----------------------------------------------------------------------===// +// FP16 (Negative) Multiply-Add (madd/nmadd) +//===----------------------------------------------------------------------===// + +multiclass HALF_PRECISION_SIMDMADD<Vec vec, bits<32> simdopA, bits<32> simdopS, + list<Predicate> reqs> { + defm MADD_#vec : + SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins), + [(set (vec.vt V128:$dst), (fma + (vec.vt V128:$a), (vec.vt V128:$b), (vec.vt V128:$c)))], + vec.prefix#".madd\t$dst, $a, $b, $c", + vec.prefix#".madd", simdopA, reqs>; + defm NMADD_#vec : + SIMD_I<(outs V128:$dst), (ins V128:$a, V128:$b, V128:$c), (outs), (ins), + [(set (vec.vt V128:$dst), (fma + (fneg (vec.vt V128:$a)), (vec.vt V128:$b), (vec.vt V128:$c)))], + vec.prefix#".nmadd\t$dst, $a, $b, $c", + vec.prefix#".nmadd", simdopS, reqs>; +} +defm "" : HALF_PRECISION_SIMDMADD<F16x8, 0x14e, 0x14f, [HasFP16]>; + +// TODO: I think separate intrinsics should be introduced for these FP16 operations. +def : Pat<(v8f16 (int_wasm_relaxed_madd (v8f16 V128:$a), (v8f16 V128:$b), (v8f16 V128:$c))), + (MADD_F16x8 V128:$a, V128:$b, V128:$c)>; +def : Pat<(v8f16 (int_wasm_relaxed_nmadd (v8f16 V128:$a), (v8f16 V128:$b), (v8f16 V128:$c))), + (NMADD_F16x8 V128:$a, V128:$b, V128:$c)>; //===----------------------------------------------------------------------===// // Laneselect |
