Diffstat (limited to 'llvm/lib')
-rw-r--r--  llvm/lib/Target/AArch64/AArch64ISelLowering.cpp  50
-rw-r--r--  llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td    2
-rw-r--r--  llvm/lib/Target/AArch64/SVEInstrFormats.td        6
3 files changed, 53 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b11ac81..4166d9b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1664,6 +1664,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BITCAST, VT, Custom);
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::FP_EXTEND, VT, Custom);
+ setOperationAction(ISD::FP_ROUND, VT, Custom);
setOperationAction(ISD::MLOAD, VT, Custom);
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4334,14 +4335,57 @@ SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
- if (VT.isScalableVector())
- return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
-
bool IsStrict = Op->isStrictFPOpcode();
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
EVT SrcVT = SrcVal.getValueType();
bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
+ if (VT.isScalableVector()) {
+ if (VT.getScalarType() != MVT::bf16)
+ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+ SDLoc DL(Op);
+ constexpr EVT I32 = MVT::nxv4i32;
+ auto ImmV = [&](int I) -> SDValue { return DAG.getConstant(I, DL, I32); };
+
+ SDValue NaN;
+ SDValue Narrow;
+
+ if (SrcVT == MVT::nxv2f32 || SrcVT == MVT::nxv4f32) {
+ if (Subtarget->hasBF16())
+ return LowerToPredicatedOp(Op, DAG,
+ AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
+
+ Narrow = getSVESafeBitCast(I32, SrcVal, DAG);
+
+ // Set the quiet bit.
+ if (!DAG.isKnownNeverSNaN(SrcVal))
+ NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000));
+ } else
+ return SDValue();
+
+ if (!Trunc) {
+ SDValue Lsb = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+ Lsb = DAG.getNode(ISD::AND, DL, I32, Lsb, ImmV(1));
+ SDValue RoundingBias = DAG.getNode(ISD::ADD, DL, I32, Lsb, ImmV(0x7fff));
+ Narrow = DAG.getNode(ISD::ADD, DL, I32, Narrow, RoundingBias);
+ }
+
+ // Don't round if we had a NaN; we don't want to turn 0x7fffffff into
+ // 0x80000000.
+ if (NaN) {
+ EVT I1 = I32.changeElementType(MVT::i1);
+ EVT CondVT = VT.changeElementType(MVT::i1);
+ SDValue IsNaN = DAG.getSetCC(DL, CondVT, SrcVal, SrcVal, ISD::SETUO);
+ IsNaN = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, I1, IsNaN);
+ Narrow = DAG.getSelect(DL, I32, IsNaN, NaN, Narrow);
+ }
+
+ // Now that we have rounded, shift the bits into position.
+ Narrow = DAG.getNode(ISD::SRL, DL, I32, Narrow, ImmV(16));
+ return getSVESafeBitCast(VT, Narrow, DAG);
+ }
+
if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
return LowerFixedLengthFPRoundToSVE(Op, DAG);
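
For reference, a minimal scalar sketch of the bf16 rounding the hunk above implements, assuming the !Trunc (round-to-nearest-even) path; FloatToBF16RNE is a hypothetical helper name and is not part of the patch:

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    // Hypothetical scalar model of the lowering above (round-to-nearest-even,
    // i.e. the !Trunc path). Not part of the patch.
    static uint16_t FloatToBF16RNE(float F) {
      uint32_t Bits;
      std::memcpy(&Bits, &F, sizeof(Bits)); // roughly the getSVESafeBitCast step

      // Set the quiet bit so a signalling NaN survives truncation as a NaN.
      uint32_t NaNBits = Bits | 0x400000;

      // Round to nearest even: add 0x7fff plus the bit that becomes the bf16
      // LSB, so exact ties round towards the even mantissa.
      uint32_t Lsb = (Bits >> 16) & 1;
      uint32_t Rounded = Bits + Lsb + 0x7fff;

      // Skip the bias for NaNs: adding it could carry into the sign bit
      // (e.g. 0x7fffffff becomes 0x8000xxxx, which truncates to -0.0).
      if (std::isnan(F))
        Rounded = NaNBits;

      // The top 16 bits are the bf16 encoding.
      return static_cast<uint16_t>(Rounded >> 16);
    }

The Trunc form of ISD::FP_ROUND corresponds to skipping the rounding bias entirely, which is what the if (!Trunc) guard in the patch does.
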
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1f3d63a2..7240f6a 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2425,7 +2425,7 @@ let Predicates = [HasBF16, HasSVEorSME] in {
defm BFMLALT_ZZZ : sve2_fp_mla_long<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt>;
defm BFMLALB_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b100, "bfmlalb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalb_lane_v2>;
defm BFMLALT_ZZZI : sve2_fp_mla_long_by_indexed_elem<0b101, "bfmlalt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlalt_lane_v2>;
- defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32>;
+ defm BFCVT_ZPmZ : sve_bfloat_convert<0b1, "bfcvt", int_aarch64_sve_fcvt_bf16f32, AArch64fcvtr_mt>;
defm BFCVTNT_ZPmZ : sve_bfloat_convert<0b0, "bfcvtnt", int_aarch64_sve_fcvtnt_bf16f32>;
} // End HasBF16, HasSVEorSME
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 8119198..0bfac64 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -8807,9 +8807,13 @@ class sve_bfloat_convert<bit N, string asm>
let mayRaiseFPException = 1;
}
-multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op> {
+multiclass sve_bfloat_convert<bit N, string asm, SDPatternOperator op,
+ SDPatternOperator ir_op = null_frag> {
def NAME : sve_bfloat_convert<N, asm>;
+
def : SVE_3_Op_Pat<nxv8bf16, op, nxv8bf16, nxv8i1, nxv4f32, !cast<Instruction>(NAME)>;
+ def : SVE_1_Op_Passthru_Round_Pat<nxv4bf16, ir_op, nxv4i1, nxv4f32, !cast<Instruction>(NAME)>;
+ def : SVE_1_Op_Passthru_Round_Pat<nxv2bf16, ir_op, nxv2i1, nxv2f32, !cast<Instruction>(NAME)>;
}
//===----------------------------------------------------------------------===//
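
As a quick standalone check of the Lsb + 0x7fff bias used in the lowering (an assumed snippet, not from this patch), exact ties land on the even bf16 mantissa:

    #include <cassert>
    #include <cstdint>

    // RoundBitsToBF16 is a hypothetical name; it applies only the
    // rounding-bias step from the lowering above to a raw f32 bit pattern.
    static uint16_t RoundBitsToBF16(uint32_t Bits) {
      uint32_t Lsb = (Bits >> 16) & 1;
      return static_cast<uint16_t>((Bits + Lsb + 0x7fff) >> 16);
    }

    int main() {
      assert(RoundBitsToBF16(0x3f808000u) == 0x3f80); // tie, LSB 0: stays even
      assert(RoundBitsToBF16(0x3f818000u) == 0x3f82); // tie, LSB 1: rounds up to even
      assert(RoundBitsToBF16(0x3f807fffu) == 0x3f80); // below the tie: rounds down
      assert(RoundBitsToBF16(0x3f808001u) == 0x3f81); // above the tie: rounds up
      return 0;
    }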