diff options
Diffstat (limited to 'llvm/include/llvm/IR/IntrinsicsNVVM.td')
-rw-r--r-- | llvm/include/llvm/IR/IntrinsicsNVVM.td | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 5ddc144..967d166 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> { !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2), !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4), + // stmatrix b8 -> s32 @ m16n8 + !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1), + !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2), + !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4), + ); } @@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> { !subst("llvm.", "int_", intr)); } +class STMATRIX_NAME<WMMA_REGS Frag, int Trans> { + string intr = "llvm.nvvm.stmatrix.sync.aligned" + # "." # Frag.geom + # "." # Frag.frag + # !if(Trans, ".trans", "") + # "." # Frag.ptx_elt_type + ; + string record = !subst(".", "_", + !subst("llvm.", "int_", intr)); +} + // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. // Geom: list of supported geometries. // TypeN: PTX type of the corresponding fragment's element. @@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> { list<string> ops = !foreach(x, ret, x.gft); } +class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> { + list<WMMA_REGS> ret = + !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1, + !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2, + !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3, + [WMMA_REGS<geom, frag, type>])))))); + // Debugging aid for readable representation of the list above. + list<string> ops = !foreach(x, ret, x.gft); +} + // Creates list of valid combinations of fragments. This is the main list that // drives generation of corresponding intrinsics and instructions. class NVVM_MMA_OPS { @@ -537,9 +563,18 @@ class NVVM_MMA_OPS { list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS< ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret; + list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS< + ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret; + + list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS< + ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret; + list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops, ldmatrix_geom_m16n16_ops, ldmatrix_geom_m8n16_ops); + + list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops, + stmatrix_b8_ops); } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> { ); } +// Returns true if the fragment is valid for stmatrix ops is supported; +// false otherwise. +class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> { + string g = frag.geom; + string t = frag.ptx_elt_type; + + bit ret = !cond( + !and(!eq(g, "m8n8"), !eq(t, "b16")): true, + !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true, + true: false + ); +} + class SHFL_INFO<bit sync, string mode, string type, bit return_pred> { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in { } } +// STMATRIX +class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed> + : Intrinsic<[], + !listconcat([llvm_anyptr_ty], Frag.regs), + [IntrWriteMem, IntrArgMemOnly, IntrNoCallback, + WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>], + STMATRIX_NAME<Frag, Transposed>.intr>; + +foreach transposed = [0, 1] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in { + if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then { + def STMATRIX_NAME<frag, transposed>.record + : NVVM_STMATRIX<frag, transposed>; + } + } +} + // MAPA let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in { def int_nvvm_mapa |