aboutsummaryrefslogtreecommitdiff
path: root/llvm/test/CodeGen/NVPTX/wmma.py
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/test/CodeGen/NVPTX/wmma.py')
-rw-r--r--llvm/test/CodeGen/NVPTX/wmma.py129
1 files changed, 129 insertions, 0 deletions
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee4896..2eb3c3d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@ import argparse
from itertools import product
from string import Template
+
class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ class MMAFrag:
"m8n16:x1:b8x16.b4x16_p64": 1,
"m8n16:x2:b8x16.b4x16_p64": 2,
"m8n16:x4:b8x16.b4x16_p64": 4,
+ # stmatrix
+ "m8n8:x1:b16": 1,
+ "m8n8:x2:b16": 2,
+ "m8n8:x4:b16": 4,
+ "m16n8:x1:b8": 1,
+ "m16n8:x2:b8": 2,
+ "m16n8:x4:b8": 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
]
+def make_stmatrix_ops(geoms, frags, types):
+ return [
+ MMAFrag(geom, frag, ptx_type)
+ for (geom, frag, ptx_type) in product(geoms, frags, types)
+ ]
+
+
def get_wmma_ops():
return (
make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
)
+def get_stmatrix_ops():
+ return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]
+ )
+
+
def is_wmma_geom_supported(geom):
# geometries for FP and ints.
if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
assert False # Unexpected geometry.
+def is_stmatrix_geom_supported(geom):
+ if geom in ["m8n8"]:
+ return ptx_version >= 78 and gpu_arch >= 90
+ elif geom in ["m16n8"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
+ assert False # Unexpected geometry.
+
+
def is_ldmatrix_trans_supported(geom, trans):
if geom in ["m8n8"]:
return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
return trans == ""
assert False # Unexpected geometry.
+
+def is_stmatrix_trans_supported(geom, trans):
+ if geom in ["m8n8"]:
+ return True
+ elif geom in ["m16n8"]:
+ return trans == ".trans"
+ assert False # Unexpected geometry.
+
+
def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
return frag.frag in ["x1", "x2", "x4"]
+def is_stmatrix_variant_supported(frag, trans):
+ if not (
+ is_type_supported(frag.mma_type.ptx_type)
+ and is_stmatrix_geom_supported(frag.geom)
+ and is_stmatrix_trans_supported(frag.geom, trans)
+ ):
+ return False
+ return frag.frag in ["x1", "x2", "x4"]
+
+
def make_wmma_slice_ty(frag):
return [frag.mma_type.llvm_type] * frag.nregs
@@ -717,6 +765,65 @@ define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
return generated_items
+def gen_stmatrix_tests():
+ stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
+; CHECK: {${check_args}}
+ call void @${intrinsic}(i8${as}* %dst, ${args});
+ ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
+; CHECK: {${check_args}}
+ %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+ call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+ ret void
+}
+"""
+ intrinsic_template = (
+ "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+ )
+ instruction_template = (
+ "stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+ )
+ generated_items = []
+
+ for frag, space, trans in product(
+ get_stmatrix_ops(),
+ ["", ".shared"],
+ ["", ".trans"],
+ ):
+ if not is_stmatrix_variant_supported(frag, trans):
+ continue
+
+ params = {
+ "frag": frag.frag,
+ "space": space,
+ "trans": trans,
+ "itype": frag.mma_type.ptx_type,
+ "pspace": get_pspace(space),
+ "as": "addrspace(%d)" % get_aspace(space),
+ "geom": frag.geom,
+ }
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["args"] = make_wmma_slice_args(frag)
+ test_params["check_args"] = check_pattern(frag)
+
+ print(Template(stmatrix_template).substitute(test_params))
+ generated_items.append((test_params["intrinsic"], test_params["instruction"]))
+
+ return generated_items
+
def mma_signature(op):
if op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
@@ -893,6 +1000,7 @@ def gen_check_unsupported_ops(items):
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1102,26 @@ def gen_check_unsupported_ops(items):
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1167,7 @@ def gen_tests():
items = gen_wmma_load_tests()
items += gen_wmma_store_tests()
items += gen_ldmatrix_tests()
+ items += gen_stmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
gen_check_unsupported_ops(items)