Diffstat (limited to 'llvm/include/llvm/IR/IntrinsicsNVVM.td')
-rw-r--r-- | llvm/include/llvm/IR/IntrinsicsNVVM.td | 143
1 file changed, 134 insertions, 9 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0375f29..967d166 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
     !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+    // stmatrix b8 -> s32 @ m16n8
+    !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
                   !subst("llvm.", "int_", intr));
 }
 
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+  string intr = "llvm.nvvm.stmatrix.sync.aligned"
+                # "." # Frag.geom
+                # "." # Frag.frag
+                # !if(Trans, ".trans", "")
+                # "." # Frag.ptx_elt_type
+                ;
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", intr));
+}
+
 // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
 // Geom: list of supported geometries.
 // TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
   list<string> ops = !foreach(x, ret, x.gft);
 }
 
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+  list<WMMA_REGS> ret =
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+  // Debugging aid for readable representation of the list above.
+  list<string> ops = !foreach(x, ret, x.gft);
+}
+
 // Creates list of valid combinations of fragments. This is the main list that
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
     ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
 
+  list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+    ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+  list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+    ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
   list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
                                                  ldmatrix_geom_m16n16_ops,
                                                  ldmatrix_geom_m8n16_ops);
+
+  list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+                                                 stmatrix_b8_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   );
 }
 
+// Returns true if the fragment is valid for stmatrix ops;
+// false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+  string g = frag.geom;
+  string t = frag.ptx_elt_type;
+
+  bit ret = !cond(
+    !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+    true: false
+  );
+}
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
   }
 }
 
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+  : Intrinsic<[],
+              !listconcat([llvm_anyptr_ty], Frag.regs),
+              [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+               WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+              STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+  foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+    if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+      def STMATRIX_NAME<frag, transposed>.record
+        : NVVM_STMATRIX<frag, transposed>;
+    }
+  }
+}
+
 // MAPA
 let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
   def int_nvvm_mapa
@@ -2024,9 +2089,7 @@ foreach dim = 1...5 in {
                       tensor_dim_args,          // actual tensor dims
                       [llvm_i64_ty]),           // cache_hint
           [llvm_i1_ty],                         // Flag for cache_hint
-          [IntrConvergent,
-           ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
-           NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
+          [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>]>;
 
   // Intrinsics for TMA Copy with reduction
   foreach red_op = ["add", "min", "max", "inc", "dec", "and", "or", "xor"] in
@@ -2037,18 +2100,31 @@ foreach dim = 1...5 in {
                         tensor_dim_args,        // actual tensor dims
                         [llvm_i64_ty]),         // cache_hint
             [llvm_i1_ty],                       // Flag for cache_hint
-            [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
-             NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
+            [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>]>;
   }
 }
 
+// TMA S2G tile::scatter4
+def int_nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d
+  : DefaultAttrsIntrinsicFlags<[],
+      !listconcat([llvm_shared_ptr_ty,          // src_smem_ptr
+                   llvm_ptr_ty],                // tensormap_ptr
+                  !listsplat(llvm_i32_ty, 5),   // dims
+                  [llvm_i64_ty]),               // cache_hint
+      [llvm_i1_ty],                             // Flag for cache_hint
+      [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>]>;
+
 // TMA Tensor Copy Intrinsics: G2S -> From Global to Shared memory variants
 foreach dim = 1...5 in {
   defvar tensor_dim_args = !listsplat(llvm_i32_ty, dim);
 
-  foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
+  foreach mode = !if(!ge(dim, 3), ["tile", "im2col", "im2col_w", "im2col_w_128"], ["tile"]) in {
     defvar is_im2col = !eq(mode, "im2col");
-    defvar num_im2col_offsets = !if(is_im2col, !add(dim, -2), 0);
+    defvar is_im2colw = !or(!eq(mode, "im2col_w"), !eq(mode, "im2col_w_128"));
+
+    // For im2col_w/w128 modes, the num_offsets is always 2.
+    // For im2col mode, the num_offsets is (dim - 2).
+    defvar num_im2col_offsets = !if(is_im2colw, 2, !if(is_im2col, !add(dim, -2), 0));
     defvar im2col_offsets_args = !listsplat(llvm_i16_ty, num_im2col_offsets);
 
     defvar g2s_params = !listconcat(
@@ -2079,11 +2155,60 @@ foreach dim = 1...5 in {
                         im2col_offsets_args,    // im2col offsets
                         [llvm_i64_ty]),         // cache_hint
             [llvm_i1_ty],                       // Flag for cache_hint
-            [IntrConvergent,
-             ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>]>;
+            [IntrConvergent, ReadOnly<ArgIndex<0>>]>;
+
+    def int_nvvm_cp_async_bulk_tensor_g2s_cta_ # mode # _ # dim # d :
+      DefaultAttrsIntrinsicFlags<[],
+          !listconcat([llvm_shared_ptr_ty,      // dst_ptr
+                       llvm_shared_ptr_ty,      // mbarrier_ptr
+                       llvm_ptr_ty],            // tensormap_ptr
+                      tensor_dim_args,          // actual tensor dims
+                      im2col_offsets_args,      // im2col offsets
+                      [llvm_i64_ty]),           // cache_hint
+          [llvm_i1_ty],                         // Flag for cache_hint
+          [IntrConvergent, WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>]>;
   }
 }
 
+// TMA copy for tile::gather4
+def int_nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d
+  : DefaultAttrsIntrinsicFlags<[],
+      !listconcat(
+        [llvm_shared_cluster_ptr_ty,    // dst_shared_cluster_ptr
+         llvm_shared_ptr_ty,            // mbarrier_ptr
+         llvm_ptr_ty],                  // tensormap_ptr
+        !listsplat(llvm_i32_ty, 5),     // co-ordinates
+        [llvm_i16_ty,                   // cta_mask
+         llvm_i64_ty]),                 // cache_hint
+      [llvm_i1_ty,                      // Flag for cta_mask
+       llvm_i1_ty,                      // Flag for cache_hint
+       llvm_i32_ty],                    // Flag for cta_group
+      [IntrConvergent,
+       WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
+       // Allowed values for cta_group are {0,1,2} i.e [0, 3).
+       Range<ArgIndex<12>, 0, 3>]>;
+
+def int_nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d
+  : DefaultAttrsIntrinsicFlags<[],
+      !listconcat(
+        [llvm_shared_ptr_ty,            // dst_shared_ptr
+         llvm_shared_ptr_ty,            // mbarrier_ptr
+         llvm_ptr_ty],                  // tensormap_ptr
+        !listsplat(llvm_i32_ty, 5),     // co-ordinates
+        [llvm_i64_ty]),                 // cache_hint
+      [llvm_i1_ty],                     // Flag for cache_hint
+      [IntrConvergent,
+       WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>]>;
+
+// TMA prefetch for tile::gather4
+def int_nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d
+  : DefaultAttrsIntrinsicFlags<[],
+      !listconcat([llvm_ptr_ty],        // tensormap_ptr
+                  !listsplat(llvm_i32_ty, 5), // co-ordinates
+                  [llvm_i64_ty]),       // cache_hint
+      [llvm_i1_ty],                     // Flag for cache_hint
+      [IntrConvergent, ReadOnly<ArgIndex<0>>]>;
+
 // Intrinsics for Prefetch and Prefetchu
 let IntrProperties = [IntrArgMemOnly, ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>] in {
   foreach level = ["L1", "L2"] in {
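
For reference, the STMATRIX records above expand to intrinsics named per STMATRIX_NAME (llvm.nvvm.stmatrix.sync.aligned.<geom>.<frag>[.trans].<type>), taking an overloaded destination pointer followed by the i32 fragment registers from WMMA_REGS; per NVVM_STMATRIX_SUPPORTED, the m8n8 b16 forms are generated with and without .trans, while the m16n8 b8 forms are only generated with .trans. A minimal LLVM IR sketch of the non-transposed m8n8.x1.b16 variant, assuming the usual .p3 mangling suffix for a shared-memory (addrspace(3)) pointer operand:

declare void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3), i32)

define void @store_one_tile(ptr addrspace(3) %smem, i32 %frag) {
  ; Warp-collective store: each lane supplies a shared-memory address and one
  ; 32-bit register holding its slice of the 8x8 b16 tile (per-lane layout as
  ; described for PTX stmatrix).
  call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %smem, i32 %frag)
  ret void
}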
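
The new tile::gather4 G2S copy is declared with concrete pointer types, so its IR signature mirrors the parameter list directly: shared::cluster destination, mbarrier pointer, tensormap pointer, five i32 co-ordinates, cta_mask (i16), cache_hint (i64), and three trailing immediate flags (cta_mask enable, cache_hint enable, cta_group, the last restricted to [0, 3) by the Range attribute on ArgIndex<12>). A sketch assuming the NVPTX address-space convention of addrspace(3) for shared and addrspace(7) for shared::cluster:

declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.gather4.2d(
    ptr addrspace(7), ptr addrspace(3), ptr,
    i32, i32, i32, i32, i32,   ; five co-ordinates
    i16, i64,                  ; cta_mask, cache_hint
    i1, i1, i32)               ; flags: cta_mask, cache_hint, cta_group

define void @gather4_g2s(ptr addrspace(7) %dst, ptr addrspace(3) %mbar, ptr %tmap,
                         i32 %c0, i32 %c1, i32 %c2, i32 %c3, i32 %c4) {
  ; Flags are immediates; here: no cta_mask, no cache_hint, cta_group = 0.
  call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.gather4.2d(
      ptr addrspace(7) %dst, ptr addrspace(3) %mbar, ptr %tmap,
      i32 %c0, i32 %c1, i32 %c2, i32 %c3, i32 %c4,
      i16 0, i64 0, i1 0, i1 0, i32 0)
  ret void
}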
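
The matching prefetch variant for tile::gather4 needs only the tensormap pointer, the five co-ordinates, the cache_hint value, and its enable flag, with no address-space assumptions beyond the generic tensormap pointer; a sketch under the same naming convention:

declare void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(
    ptr, i32, i32, i32, i32, i32, i64, i1)

define void @gather4_prefetch(ptr %tmap, i32 %c0, i32 %c1, i32 %c2, i32 %c3, i32 %c4) {
  ; The i64 cache_hint is only meaningful when the trailing i1 flag is set,
  ; following the convention of the other cp.async.bulk.tensor intrinsics.
  call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(
      ptr %tmap, i32 %c0, i32 %c1, i32 %c2, i32 %c3, i32 %c4, i64 0, i1 0)
  ret void
}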