diff options
Diffstat (limited to 'llvm/test/CodeGen/NVPTX/wmma.py')
-rw-r--r-- | llvm/test/CodeGen/NVPTX/wmma.py | 115 |
1 files changed, 104 insertions, 11 deletions
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 6d73bce..8427ae4 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -90,6 +90,21 @@ class MMAFrag: "m16n8k32:b:s8": 2, "m16n8k32:c:s32": 4, "m16n8k32:d:s32": 4, + # e4m3/e5m2/e3m2/e2m3/e2m1 -> f16/f32 @ m16n8k16/m16n8k32 + "m16n8k16:a:e4m3": 2, + "m16n8k16:a:e5m2": 2, + "m16n8k32:a:e4m3": 4, + "m16n8k32:a:e5m2": 4, + "m16n8k32:a:e3m2": 4, + "m16n8k32:a:e2m3": 4, + "m16n8k32:a:e2m1": 4, + "m16n8k16:b:e4m3": 1, + "m16n8k16:b:e5m2": 1, + "m16n8k32:b:e4m3": 2, + "m16n8k32:b:e5m2": 2, + "m16n8k32:b:e3m2": 2, + "m16n8k32:b:e2m3": 2, + "m16n8k32:b:e2m1": 2, # mma sp "m16n8k32:a:bf16": 4, "m16n8k32:a:f16": 4, @@ -182,6 +197,18 @@ class MMAFrag: "m8n8k4:b:f64": 1, "m8n8k4:c:f64": 2, "m8n8k4:d:f64": 2, + "m16n8k4:a:f64": 2, + "m16n8k4:b:f64": 1, + "m16n8k4:c:f64": 4, + "m16n8k4:d:f64": 4, + "m16n8k8:a:f64": 4, + "m16n8k8:b:f64": 2, + "m16n8k8:c:f64": 4, + "m16n8k8:d:f64": 4, + "m16n8k16:a:f64": 8, + "m16n8k16:b:f64": 4, + "m16n8k16:c:f64": 4, + "m16n8k16:d:f64": 4, # tf32 -> s32 @ m16n16k8 "m16n16k8:a:tf32": 4, "m16n16k8:b:tf32": 4, @@ -324,7 +351,9 @@ def get_wmma_ops(): def get_mma_ops(): return ( - make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], []) + make_mma_ops( + ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"], ["f64"], [], ["f64"], [] + ) + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], []) + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], []) + make_mma_ops( @@ -341,6 +370,20 @@ def get_mma_ops(): ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], [] ) + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], []) + + make_mma_ops( + ["m16n8k16"], + ["e4m3", "e5m2"], + ["e4m3", "e5m2"], + ["f16", "f32"], + ["f16", "f32"], + ) + + make_mma_ops( + ["m16n8k32"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["f16", "f32"], + ["f16", "f32"], + ) ) @@ -492,7 +535,7 @@ def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf): return True -def is_mma_variant_supported(op, layout_a, layout_b, satf): +def is_mma_variant_supported(op, layout_a, layout_b, kind, satf): if not ( is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom) ): @@ -516,13 +559,53 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf): ): return False + if ( + op.a.geom != "m8n8k4" + and op.a.mma_type.ptx_type == "f64" + and (ptx_version < 78 or gpu_arch < 90) + ): + return False + # C and D type must be the same - if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type: + if ( + op.a.geom in ["m16n8k16", "m16n8k32"] + and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type + ): + return False + + if ( + op.a.geom in ["m16n8k16", "m16n8k32"] + and any( + x in ["e4m3", "e5m2"] + for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type) + ) + and ptx_version < 87 + ): + return False + + if kind != "" and not (ptx_version >= 87 and gpu_arch >= 120 and aa): + return False + + if kind != "" and ( + op.a.geom != "m16n8k32" + or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"] + ): + return False + + if ( + kind == "" + and op.a.geom in ["m16n8k16", "m16n8k32"] + and any( + x in ["e3m2", "e2m3", "e2m1"] + for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type) + ) + ): return False # Require row/col layout for all MMA except m8n8k4 on FP16 if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"): return layout_a == "row" and layout_b == "col" + return True @@ -937,7 +1020,12 @@ define ${ret_ty} @test_${function}( """ test_params = params - test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["intrinsic"] = ( + Template(intrinsic_template) + .substitute(params) + .replace("::", ".") + .replace("_", ".") + ) test_params["function"] = test_params["intrinsic"].replace(".", "_") test_params["instruction"] = Template(instruction_template).substitute(params) test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d) @@ -1002,16 +1090,20 @@ def gen_wmma_mma_tests(): def gen_mma_tests(): - mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}" - mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}" + mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}" + mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}" generated_items = [] - for op, alayout, blayout, satf in product( - get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""] + for op, alayout, blayout, kind, satf in product( + get_mma_ops(), + ["row", "col"], + ["row", "col"], + ["", ".kind::f8f6f4"], + [".satfinite", ""], ): - if not is_mma_variant_supported(op, alayout, blayout, satf): + if not is_mma_variant_supported(op, alayout, blayout, kind, satf): continue for b1op in get_b1_ops(op.a.mma_type.ptx_type): @@ -1024,6 +1116,7 @@ def gen_mma_tests(): "satf": satf, "geom": op.a.geom, "b1op": b1op, + "kind": kind, } intrinsic_template = mma_intrinsic_template @@ -1105,9 +1198,9 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf): ): return False - # C and D type must be the same for m16n8k16/m16n8k32 + # C and D type must be the same for m16n8k16/m16n8k32/m16n8k64 if ( - op.a.geom in ["m16n8k16", "m16n8k32"] + op.a.geom in ["m16n8k16", "m16n8k32", "m16n8k64"] and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type ): return False |