diff options
Diffstat (limited to 'llvm/test/CodeGen/NVPTX/wmma.py')
-rw-r--r-- | llvm/test/CodeGen/NVPTX/wmma.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 2ee4896..2eb3c3d 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -10,6 +10,7 @@ import argparse from itertools import product from string import Template + class MMAType: def __init__(self, ptx_type): self.ptx_type = ptx_type @@ -176,6 +177,13 @@ class MMAFrag: "m8n16:x1:b8x16.b4x16_p64": 1, "m8n16:x2:b8x16.b4x16_p64": 2, "m8n16:x4:b8x16.b4x16_p64": 4, + # stmatrix + "m8n8:x1:b16": 1, + "m8n8:x2:b16": 2, + "m8n8:x4:b16": 4, + "m16n8:x1:b8": 1, + "m16n8:x2:b8": 2, + "m16n8:x4:b8": 4, }.get( "%s:%s:%s" % (geom, frag, ptx_elt_type), { @@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types): ] +def make_stmatrix_ops(geoms, frags, types): + return [ + MMAFrag(geom, frag, ptx_type) + for (geom, frag, ptx_type) in product(geoms, frags, types) + ] + + def get_wmma_ops(): return ( make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) @@ -315,6 +330,12 @@ def get_ldmatrix_ops(): ) +def get_stmatrix_ops(): + return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops( + ["m16n8"], ["x1", "x2", "x4"], ["b8"] + ) + + def is_wmma_geom_supported(geom): # geometries for FP and ints. if geom in ["m8n32k16", "m32n8k16"]: @@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom): assert False # Unexpected geometry. +def is_stmatrix_geom_supported(geom): + if geom in ["m8n8"]: + return ptx_version >= 78 and gpu_arch >= 90 + elif geom in ["m16n8"]: + return ptx_version >= 86 and gpu_arch >= 100 and aa + assert False # Unexpected geometry. + + def is_ldmatrix_trans_supported(geom, trans): if geom in ["m8n8"]: return True @@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans): return trans == "" assert False # Unexpected geometry. + +def is_stmatrix_trans_supported(geom, trans): + if geom in ["m8n8"]: + return True + elif geom in ["m16n8"]: + return trans == ".trans" + assert False # Unexpected geometry. + + def is_type_supported(ptx_type): if ptx_type in ["s8", "u8", "s32"]: return ptx_version >= 63 and gpu_arch >= 72 @@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans): return frag.frag in ["x1", "x2", "x4"] +def is_stmatrix_variant_supported(frag, trans): + if not ( + is_type_supported(frag.mma_type.ptx_type) + and is_stmatrix_geom_supported(frag.geom) + and is_stmatrix_trans_supported(frag.geom, trans) + ): + return False + return frag.frag in ["x1", "x2", "x4"] + + def make_wmma_slice_ty(frag): return [frag.mma_type.llvm_type] * frag.nregs @@ -717,6 +765,65 @@ define ${ret_ty} @test_${function}_o(i8 ${as}* %src) { return generated_items +def gen_stmatrix_tests(): + stmatrix_template = """ +declare void @${intrinsic}(i8 ${as}* %dst, ${args}); + +; CHECK-LABEL: .func {{.*}}test_${function}( +define void @test_${function}(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}] +; CHECK: {${check_args}} + call void @${intrinsic}(i8${as}* %dst, ${args}); + ret void +} + +; CHECK-LABEL: .func{{.*}}test_${function}_o( +define void @test_${function}_o(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128], +; CHECK: {${check_args}} + %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128; + call void @${intrinsic}(i8 ${as}* %dst1, ${args}); + ret void +} +""" + intrinsic_template = ( + "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" + ) + instruction_template = ( + "stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" + ) + generated_items = [] + + for frag, space, trans in product( + get_stmatrix_ops(), + ["", ".shared"], + ["", ".trans"], + ): + if not is_stmatrix_variant_supported(frag, trans): + continue + + params = { + "frag": frag.frag, + "space": space, + "trans": trans, + "itype": frag.mma_type.ptx_type, + "pspace": get_pspace(space), + "as": "addrspace(%d)" % get_aspace(space), + "geom": frag.geom, + } + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".", "_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["args"] = make_wmma_slice_args(frag) + test_params["check_args"] = check_pattern(frag) + + print(Template(stmatrix_template).substitute(test_params)) + generated_items.append((test_params["intrinsic"], test_params["instruction"])) + + return generated_items + def mma_signature(op): if op.a.mma_type.ptx_type == "f16": # FP16 ops identified by accumulator & result type. @@ -893,6 +1000,7 @@ def gen_check_unsupported_ops(items): ; NOALTFLOAT-NOT: .{{bf16|tf32}} ; NODOUBLE-NOT: .f64 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned +; NOSTMATRIX-NOT: stmatrix.sync.aligned ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -994,6 +1102,26 @@ def gen_check_unsupported_ops(items): ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 + +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 + ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 @@ -1039,6 +1167,7 @@ def gen_tests(): items = gen_wmma_load_tests() items += gen_wmma_store_tests() items += gen_ldmatrix_tests() + items += gen_stmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() gen_check_unsupported_ops(items) |