aboutsummaryrefslogtreecommitdiff
path: root/llvm
diff options
context:
space:
mode:
Diffstat (limited to 'llvm')
-rw-r--r--llvm/include/llvm/IR/IntrinsicsNVVM.td11
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td17
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td14
-rw-r--r--llvm/test/CodeGen/NVPTX/convert-sm100a.ll82
4 files changed, 123 insertions, 1 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 5be1a91..0b26bb9 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1293,10 +1293,19 @@ let TargetPrefix = "nvvm" in {
}
}
+ // FP4 conversions.
+ foreach relu = ["", "_relu"] in {
+ def int_nvvm_ff_to_e2m1x2_rn # relu # _satfinite : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+
+ def int_nvvm_e2m1x2_to_f16x2_rn # relu : NVVMBuiltin,
+ DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+ }
+
// UE8M0x2 conversions.
foreach rmode = ["_rz", "_rp"] in {
foreach satmode = ["", "_satfinite"] in {
- defvar suffix = !strconcat(rmode, satmode);
+ defvar suffix = rmode # satmode;
def int_nvvm_ff_to_ue8m0x2 # suffix : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b6104a5..2c65ee6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -714,6 +714,23 @@ let hasSideEffects = false in {
# type # " \t$dst, $src;", []>;
}
+ // FP4 conversions.
+ def CVT_e2m1x2_f32_sf : NVPTXInst<(outs Int16Regs:$dst),
+ (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
+ !strconcat("{{ \n\t",
+ ".reg .b8 \t%e2m1x2_out; \n\t",
+ "cvt${mode:base}.satfinite${mode:relu}.e2m1x2.f32 \t%e2m1x2_out, $src1, $src2; \n\t",
+ "cvt.u16.u8 \t$dst, %e2m1x2_out; \n\t",
+ "}}"), []>;
+
+ def CVT_f16x2_e2m1x2 : NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int16Regs:$src, CvtMode:$mode),
+ !strconcat("{{ \n\t",
+ ".reg .b8 \t%e2m1x2_in; \n\t",
+ "cvt.u8.u16 \t%e2m1x2_in, $src; \n\t",
+ "cvt${mode:base}${mode:relu}.f16x2.e2m1x2 \t$dst, %e2m1x2_in; \n\t",
+ "}}"), []>;
+
// UE8M0x2 conversions.
class CVT_f32_to_ue8m0x2<string sat = ""> :
NVPTXInst<(outs Int16Regs:$dst),
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8110ba1..d3cfce7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -2003,6 +2003,20 @@ def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn i16:$a),
def : Pat<(int_nvvm_e3m2x2_to_f16x2_rn_relu i16:$a),
(CVT_f16x2_e3m2x2 $a, CvtRN_RELU)>,
Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+
+def : Pat<(int_nvvm_ff_to_e2m1x2_rn_satfinite f32:$a, f32:$b),
+ (CVT_e2m1x2_f32_sf $a, $b, CvtRN)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+def : Pat<(int_nvvm_ff_to_e2m1x2_rn_relu_satfinite f32:$a, f32:$b),
+ (CVT_e2m1x2_f32_sf $a, $b, CvtRN_RELU)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+
+def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn Int16Regs:$a),
+ (CVT_f16x2_e2m1x2 $a, CvtRN)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
+def : Pat<(int_nvvm_e2m1x2_to_f16x2_rn_relu Int16Regs:$a),
+ (CVT_f16x2_e2m1x2 $a, CvtRN_RELU)>,
+ Requires<[hasPTX<86>, hasSM<100>, hasArchAccelFeatures]>;
def : Pat<(int_nvvm_ff_to_ue8m0x2_rz f32:$a, f32:$b),
(CVT_ue8m0x2_f32 $a, $b, CvtRZ)>,
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
index def2575..9acbb79 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm100a.ll
@@ -288,3 +288,85 @@ define <2 x bfloat> @cvt_bf16x2_ue8m0x2(i16 %in) {
%val = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %in)
ret <2 x bfloat> %val
}
+
+define i16 @cvt_rn_sf_e2m1x2_f32(float %f1, float %f2) {
+; CHECK-LABEL: cvt_rn_sf_e2m1x2_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_sf_e2m1x2_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_sf_e2m1x2_f32_param_1];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_out;
+; CHECK-NEXT: cvt.rn.satfinite.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
+; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out;
+; CHECK-NEXT: }
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %f1, float %f2)
+ ret i16 %val
+}
+
+define i16 @cvt_rn_relu_sf_e2m1x2_f32(float %f1, float %f2) {
+; CHECK-LABEL: cvt_rn_relu_sf_e2m1x2_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b32 %f<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %f1, [cvt_rn_relu_sf_e2m1x2_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %f2, [cvt_rn_relu_sf_e2m1x2_f32_param_1];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_out;
+; CHECK-NEXT: cvt.rn.satfinite.relu.e2m1x2.f32 %e2m1x2_out, %f1, %f2;
+; CHECK-NEXT: cvt.u16.u8 %rs1, %e2m1x2_out;
+; CHECK-NEXT: }
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %f1, float %f2)
+ ret i16 %val
+}
+
+define <2 x half> @cvt_rn_f16x2_e2m1x2(i16 %in) {
+; CHECK-LABEL: cvt_rn_f16x2_e2m1x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_f16x2_e2m1x2_param_0];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_in;
+; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1;
+; CHECK-NEXT: cvt.rn.f16x2.e2m1x2 %r1, %e2m1x2_in;
+; CHECK-NEXT: }
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %in)
+ ret <2 x half> %val
+}
+
+define <2 x half> @cvt_rn_relu_f16x2_e2m1x2(i16 %in) {
+; CHECK-LABEL: cvt_rn_relu_f16x2_e2m1x2(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b16 %rs1, [cvt_rn_relu_f16x2_e2m1x2_param_0];
+; CHECK-NEXT: {
+; CHECK-NEXT: .reg .b8 %e2m1x2_in;
+; CHECK-NEXT: cvt.u8.u16 %e2m1x2_in, %rs1;
+; CHECK-NEXT: cvt.rn.relu.f16x2.e2m1x2 %r1, %e2m1x2_in;
+; CHECK-NEXT: }
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %in)
+ ret <2 x half> %val
+}