diff options
Diffstat (limited to 'llvm')
-rw-r--r-- | llvm/include/llvm/IR/IntrinsicsNVVM.td | 11 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 17 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 14 | ||||
-rw-r--r-- | llvm/test/CodeGen/NVPTX/convert-sm100a.ll | 82 |
4 files changed, 123 insertions, 1 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 5be1a91..0b26bb9 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1293,10 +1293,19 @@ let TargetPrefix = "nvvm" in { } } + // FP4 conversions. + foreach relu = ["", "_relu"] in { + def int_nvvm_ff_to_e2m1x2_rn # relu # _satfinite : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + + def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin, + DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>; + } + // UE8M0x2 conversions. foreach rmode = ["_rz", "_rp"] in { foreach satmode = ["", "_satfinite"] in { - defvar suffix = !strconcat(rmode, satmode); + defvar suffix = rmode # satmode; def int_nvvm_ff_to_ue8m0x2 # suffix : NVVMBuiltin, DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index b6104a5..2c65ee6 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -714,6 +714,23 @@ let hasSideEffects = false in { # type # " \t$dst, $src;", []>; } + // FP4 conversions. + def CVT_e2m1x2_f32_sf : NVPTXInst<(outs Int16Regs:$dst), + (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode), + !strconcat("{{ \n\t", + ".reg .b8 \t%e2m1x2_out; \n\t", + "cvt${mode:base}.satfinite${mode:relu}.e2m1x2.f32 \t%e2m1x2_out, $src1, $src2; \n\t", + "cvt.u16.u8 \t$dst, %e2m1x2_out; \n\t", + "}}"), []>; + + def CVT_f16x2_e2m1x2 : NVPTXInst<(outs Int32Regs:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("{{ \n\t", + ".reg .b8 \t%e2m1x2_in; \n\t", + "cvt.u8.u16 \t%e2m1x2_in, $src; \n\t", + "cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t", + "}}"), []>; + // UE8M0x2 conversions. class CVT_f32_to_ue8m0x2<string sat = ""> : NVPTXInst<(outs Int16Regs:$dst), diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 8110ba1..d3cfce7 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -2003,6 +2003,20 @@ def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn i16:$a), def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn_relu i16:$a), (CVT_f16x2_e3m2x2 $a, CvtRN_RELU)>, Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; + +def : Pat<(int_nvvm_ff_to_e2m1x2_rn_satfinite f32:$a, f32:$b), + (CVT_e2m1x2_f32_sf $a, $b, CvtRN)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; +def : Pat<(int_nvvm_ff_to_e2m1x2_rn_relu_satfinite f32:$a, f32:$b), + (CVT_e2m1x2_f32_sf $a, $b, CvtRN_RELU)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; + +def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn Int16Regs:$a), + (CVT_f16x2_e2m1x2 $a, CvtRN)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; +def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn_relu Int16Regs:$a), + (CVT_f16x2_e2m1x2 $a, CvtRN_RELU)>, + Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>; def : Pat<(int_nvvm_ff_to_ue8m0x2_rz f32:$a, f32:$b), (CVT_ue8m0x2_f32 $a, $b, CvtRZ)>, diff --git a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll index def2575..9acbb79 100644 --- a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll +++ b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll @@ -288,3 +288,85 @@ define <2 x bfloat> @cvt_bf16x2_ue8m0x2(i16 %in) { %val = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %in) ret <2 x bfloat> %val } + +define i16 @cvt_rn_sf_e2m1x2_f32(float %f1, float %f2) { +; CHECK-LABEL: cvt_rn_sf_e2m1x2_f32( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_sf_e2m1x2_f32_param_0]; +; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_sf_e2m1x2_f32_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_out; +; CHECK-NEXT: cvt.rn.satfinite.e2m1x2.f32 %e2m1x2_out, %f1, %f2; +; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out; +; CHECK-NEXT: } +; CHECK-NEXT: cvt.u32.u16 %r1, %rs1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %f1, float %f2) + ret i16 %val +} + +define i16 @cvt_rn_relu_sf_e2m1x2_f32(float %f1, float %f2) { +; CHECK-LABEL: cvt_rn_relu_sf_e2m1x2_f32( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b32 %f<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_relu_sf_e2m1x2_f32_param_0]; +; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_relu_sf_e2m1x2_f32_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_out; +; CHECK-NEXT: cvt.rn.satfinite.relu.e2m1x2.f32 %e2m1x2_out, %f1, %f2; +; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out; +; CHECK-NEXT: } +; CHECK-NEXT: cvt.u32.u16 %r1, %rs1; +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %f1, float %f2) + ret i16 %val +} + +define <2 x half> @cvt_rn_f16x2_e2m1x2(i16 %in) { +; CHECK-LABEL: cvt_rn_f16x2_e2m1x2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_f16x2_e2m1x2_param_0]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_in; +; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1; +; CHECK-NEXT: cvt.rn.f16x2.e2m1x2 %r1, %e2m1x2_in; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %in) + ret <2 x half> %val +} + +define <2 x half> @cvt_rn_relu_f16x2_e2m1x2(i16 %in) { +; CHECK-LABEL: cvt_rn_relu_f16x2_e2m1x2( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_relu_f16x2_e2m1x2_param_0]; +; CHECK-NEXT: { +; CHECK-NEXT: .reg .b8 %e2m1x2_in; +; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1; +; CHECK-NEXT: cvt.rn.relu.f16x2.e2m1x2 %r1, %e2m1x2_in; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.b32 [func_retval0], %r1; +; CHECK-NEXT: ret; + %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %in) + ret <2 x half> %val +} |