diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 50 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/AArch64/SVEInstrFormats.td | 6 |
3 files changed, 53 insertions, 5 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index b11ac81..4166d9b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1664,6 +1664,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::BITCAST, VT, Custom); setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::FP_EXTEND, VT, Custom); + setOperationAction(ISD::FP_ROUND, VT, Custom); setOperationAction(ISD::MLOAD, VT, Custom); setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom); setOperationAction(ISD::SPLAT_VECTOR, VT, Legal); @@ -4334,14 +4335,57 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op, SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); - if (VT.isScalableVector()) - return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU); - bool IsStrict = Op->isStrictFPOpcode(); SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0); EVT SrcVT = SrcVal.getValueType(); bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1; + if (VT.isScalableVector()) { + if (VT.getScalarType() != MVT::bf16) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU); + + SDLoc DL(Op); + constexpr EVT I32 = MVT::nxv4i32; + auto ImmV = [&](int I) -> SDValue { return DAG.getConstant(I, DL, I32); }; + + SDValue NaN; + SDValue Narrow; + + if (SrcVT == MVT::nxv2f32 || SrcVT == MVT::nxv4f32) { + if (Subtarget->hasBF16()) + return LowerToPredicatedOp(Op, DAG, + AArch64ISD::FP_ROUND_MERGE_PASSTHRU); + + Narrow = getSVESafeBitCast(I32, SrcVal, DAG); + + // Set the quiet bit. + if (!DAG.isKnownNeverSNaN(SrcVal)) + NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000)); + } else + return SDValue(); + + if (!Trunc) { + SDValue Lsb = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16)); + Lsb = DAG.getNode(ISD::AND, DL, I32, Lsb, ImmV(1)); + SDValue RoundingBias = DAG.getNode(ISD::ADD, DL, I32, Lsb, ImmV(0x7fff)); + Narrow = DAG.getNode(ISD::ADD, DL, I32, Narrow, RoundingBias); + } + + // Don't round if we had a NaN, we don't want to turn 0x7fffffff into + // 0x80000000. + if (NaN) { + EVT I1 = I32.changeElementType(MVT::i1); + EVT CondVT = VT.changeElementType(MVT::i1); + SDValue IsNaN = DAG.getSetCC(DL, CondVT, SrcVal, SrcVal, ISD::SETUO); + IsNaN = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, I1, IsNaN); + Narrow = DAG.getSelect(DL, I32, IsNaN, NaN, Narrow); + } + + // Now that we have rounded, shift the bits into position. + Narrow = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16)); + return getSVESafeBitCast(VT, Narrow, DAG); + } + if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable())) return LowerFixedLengthFPRoundToSVE(Op, DAG); diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 1f3d63a2..7240f6a 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -2425,7 +2425,7 @@ let Predicates = [HasBF16, HasSVEorSME] in { defm BFMLALT_ZZZ : sve2_fp_mla_long<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt>; defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>; defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>; - defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32>; + defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32, AArch64fcvtr_mt>; defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32>; } // End HasBF16, HasSVEorSME diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 8119198..0bfac64 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -8807,9 +8807,13 @@ class sve_bfloat_convert<bit N, string asm> let mayRaiseFPException = 1; } -multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op> { +multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op, + SDPatternOperator ir_op = null_frag> { def NAME : sve_bfloat_convert<N, asm>; + def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8i1, nxv4f32, !cast<Instruction>(NAME)>; + def : SVE_1_Op_Passthru_Round_Pat<nxv4bf16, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>; + def : SVE_1_Op_Passthru_Round_Pat<nxv2bf16, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME)>; } //===----------------------------------------------------------------------===// |
