about summary refs log tree commit diff
path: root/llvm/test/CodeGen/NVPTX/wmma.py
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/test/CodeGen/NVPTX/wmma.py')
-rw-r--r--  llvm/test/CodeGen/NVPTX/wmma.py  115
1 file changed, 104 insertions, 11 deletions
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 6d73bce..8427ae4 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -90,6 +90,21 @@ class MMAFrag:
"m16n8k32:b:s8": 2,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
+ # e4m3/e5m2/e3m2/e2m3/e2m1 -> f16/f32 @ m16n8k16/m16n8k32
+ "m16n8k16:a:e4m3": 2,
+ "m16n8k16:a:e5m2": 2,
+ "m16n8k32:a:e4m3": 4,
+ "m16n8k32:a:e5m2": 4,
+ "m16n8k32:a:e3m2": 4,
+ "m16n8k32:a:e2m3": 4,
+ "m16n8k32:a:e2m1": 4,
+ "m16n8k16:b:e4m3": 1,
+ "m16n8k16:b:e5m2": 1,
+ "m16n8k32:b:e4m3": 2,
+ "m16n8k32:b:e5m2": 2,
+ "m16n8k32:b:e3m2": 2,
+ "m16n8k32:b:e2m3": 2,
+ "m16n8k32:b:e2m1": 2,
# mma sp
"m16n8k32:a:bf16": 4,
"m16n8k32:a:f16": 4,
@@ -182,6 +197,18 @@ class MMAFrag:
"m8n8k4:b:f64": 1,
"m8n8k4:c:f64": 2,
"m8n8k4:d:f64": 2,
+ "m16n8k4:a:f64": 2,
+ "m16n8k4:b:f64": 1,
+ "m16n8k4:c:f64": 4,
+ "m16n8k4:d:f64": 4,
+ "m16n8k8:a:f64": 4,
+ "m16n8k8:b:f64": 2,
+ "m16n8k8:c:f64": 4,
+ "m16n8k8:d:f64": 4,
+ "m16n8k16:a:f64": 8,
+ "m16n8k16:b:f64": 4,
+ "m16n8k16:c:f64": 4,
+ "m16n8k16:d:f64": 4,
# tf32 -> s32 @ m16n16k8
"m16n16k8:a:tf32": 4,
"m16n16k8:b:tf32": 4,
@@ -324,7 +351,9 @@ def get_wmma_ops():
def get_mma_ops():
return (
- make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
+ make_mma_ops(
+ ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"], ["f64"], [], ["f64"], []
+ )
+ make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
+ make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
+ make_mma_ops(
@@ -341,6 +370,20 @@ def get_mma_ops():
["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
)
+ make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
+ + make_mma_ops(
+ ["m16n8k16"],
+ ["e4m3", "e5m2"],
+ ["e4m3", "e5m2"],
+ ["f16", "f32"],
+ ["f16", "f32"],
+ )
+ + make_mma_ops(
+ ["m16n8k32"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"],
+ ["f16", "f32"],
+ )
)
@@ -492,7 +535,7 @@ def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
return True
-def is_mma_variant_supported(op, layout_a, layout_b, satf):
+def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
if not (
is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
):
@@ -516,13 +559,53 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf):
):
return False
+ if (
+ op.a.geom != "m8n8k4"
+ and op.a.mma_type.ptx_type == "f64"
+ and (ptx_version < 78 or gpu_arch < 90)
+ ):
+ return False
+
# C and D type must be the same
- if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
+ if (
+ op.a.geom in ["m16n8k16", "m16n8k32"]
+ and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
+ ):
+ return False
+
+ if (
+ op.a.geom in ["m16n8k16", "m16n8k32"]
+ and any(
+ x in ["e4m3", "e5m2"]
+ for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+ )
+ and ptx_version < 87
+ ):
+ return False
+
+ if kind != "" and not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+ return False
+
+ if kind != "" and (
+ op.a.geom != "m16n8k32"
+ or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+ ):
+ return False
+
+ if (
+ kind == ""
+ and op.a.geom in ["m16n8k16", "m16n8k32"]
+ and any(
+ x in ["e3m2", "e2m3", "e2m1"]
+ for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+ )
+ ):
return False
# Require row/col layout for all MMA except m8n8k4 on FP16
if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
return layout_a == "row" and layout_b == "col"
+
return True
@@ -937,7 +1020,12 @@ define ${ret_ty} @test_${function}(
"""
test_params = params
- test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["intrinsic"] = (
+ Template(intrinsic_template)
+ .substitute(params)
+ .replace("::", ".")
+ .replace("_", ".")
+ )
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
@@ -1002,16 +1090,20 @@ def gen_wmma_mma_tests():
def gen_mma_tests():
- mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
- mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
+ mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+ mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
generated_items = []
- for op, alayout, blayout, satf in product(
- get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
+ for op, alayout, blayout, kind, satf in product(
+ get_mma_ops(),
+ ["row", "col"],
+ ["row", "col"],
+ ["", ".kind::f8f6f4"],
+ [".satfinite", ""],
):
- if not is_mma_variant_supported(op, alayout, blayout, satf):
+ if not is_mma_variant_supported(op, alayout, blayout, kind, satf):
continue
for b1op in get_b1_ops(op.a.mma_type.ptx_type):
@@ -1024,6 +1116,7 @@ def gen_mma_tests():
"satf": satf,
"geom": op.a.geom,
"b1op": b1op,
+ "kind": kind,
}
intrinsic_template = mma_intrinsic_template
@@ -1105,9 +1198,9 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
):
return False
- # C and D type must be the same for m16n8k16/m16n8k32
+ # C and D type must be the same for m16n8k16/m16n8k32/m16n8k64
if (
- op.a.geom in ["m16n8k16", "m16n8k32"]
+ op.a.geom in ["m16n8k16", "m16n8k32", "m16n8k64"]
and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
):
return False