Diffstat (limited to 'llvm/include/llvm/IR/IntrinsicsNVVM.td')
-rw-r--r--  llvm/include/llvm/IR/IntrinsicsNVVM.td | 143
1 file changed, 134 insertions(+), 9 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0375f29..967d166 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+ // stmatrix b8 -> s32 @ m16n8
+ !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
);
}
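
These additions mirror the existing ldmatrix entries: each b8 fragment at the m16n8 geometry occupies one i32 register per x-factor, so x1/x2/x4 map to one, two, and four registers. Hand-evaluating one entry (standard !listsplat semantics, shown for illustration only):

    !listsplat(llvm_i32_ty, 4)
        // = [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]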
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
!subst("llvm.", "int_", intr));
}
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+ string intr = "llvm.nvvm.stmatrix.sync.aligned"
+ # "." # Frag.geom
+ # "." # Frag.frag
+ # !if(Trans, ".trans", "")
+ # "." # Frag.ptx_elt_type
+ ;
+ string record = !subst(".", "_",
+ !subst("llvm.", "int_", intr));
+}
+
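
Worked through by hand for one fragment (illustrative only; both strings follow mechanically from the concatenation and !subst calls above), WMMA_REGS<"m8n8", "x2", "b16"> with Trans = 1 yields:

    STMATRIX_NAME<frag, 1>.intr
        // = "llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16"
    STMATRIX_NAME<frag, 1>.record
        // = "int_nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16"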
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
// Geom: list of supported geometries.
// TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<string> ops = !foreach(x, ret, x.gft);
}
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+ list<WMMA_REGS> ret =
+ !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+ !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+ !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+ [WMMA_REGS<geom, frag, type>]))))));
+ // Debugging aid for readable representation of the list above.
+ list<string> ops = !foreach(x, ret, x.gft);
+}
+
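
The nested !foldl/!listconcat pairs compute a plain cross product of Geom x Frags x Types. For the b16 instantiation used further below, the .ops debugging field would therefore evaluate to (hand-evaluated sketch, assuming WMMA_REGS.gft is the usual geom:frag:type string):

    STMATRIX_OPS<["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ops
        // = ["m8n8:x1:b16", "m8n8:x2:b16", "m8n8:x4:b16"]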
// Creates list of valid combinations of fragments. This is the main list that
// drives generation of corresponding intrinsics and instructions.
class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+ list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+ ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+ list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
ldmatrix_geom_m16n16_ops,
ldmatrix_geom_m8n16_ops);
+
+ list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+ stmatrix_b8_ops);
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
);
}
+// Returns true if the given fragment is supported by stmatrix ops;
+// false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+ string g = frag.geom;
+ string t = frag.ptx_elt_type;
+
+ bit ret = !cond(
+ !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+ !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+ true: false
+ );
+}
+
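
Applied over both transpose settings, this predicate admits m8n8.b16 in either layout but m16n8.b8 only when transposed, so the generator loop further below emits nine intrinsic records in total (a hand-derived enumeration, not TableGen output):

    llvm.nvvm.stmatrix.sync.aligned.m8n8.{x1,x2,x4}.b16
    llvm.nvvm.stmatrix.sync.aligned.m8n8.{x1,x2,x4}.trans.b16
    llvm.nvvm.stmatrix.sync.aligned.m16n8.{x1,x2,x4}.trans.b8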
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
}
}
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+ : Intrinsic<[],
+ !listconcat([llvm_anyptr_ty], Frag.regs),
+ [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+ WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+ STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+ def STMATRIX_NAME<frag, transposed>.record
+ : NVVM_STMATRIX<frag, transposed>;
+ }
+ }
+}
+
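
At the IR level, each generated record becomes a void, store-like intrinsic taking the destination pointer followed by the fragment registers (two i32s for an x2 b16 fragment, per the register table at the top of this patch). A minimal usage sketch, assuming the usual .p3 overload suffix for a ptr addrspace(3) destination:

    declare void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(
        ptr addrspace(3), i32, i32)

    define void @store_two_fragments(ptr addrspace(3) %dst, i32 %r0, i32 %r1) {
      call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(
          ptr addrspace(3) %dst, i32 %r0, i32 %r1)
      ret void
    }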
// MAPA
let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
def int_nvvm_mapa
@@ -2024,9 +2089,7 @@ foreach dim = 1...5 in {
tensor_dim_args, // actual tensor dims
[llvm_i64_ty]), // cache_hint
[llvm_i1_ty], // Flag for cache_hint
- [IntrConvergent,
- ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
- NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
+ [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>]>;
// Intrinsics for TMA Copy with reduction
foreach red_op = ["add", "min", "max", "inc", "dec", "and", "or", "xor"] in
@@ -2037,18 +2100,31 @@ foreach dim = 1...5 in {
tensor_dim_args, // actual tensor dims
[llvm_i64_ty]), // cache_hint
[llvm_i1_ty], // Flag for cache_hint
- [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
- NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
+ [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>]>;
}
}
+// TMA S2G tile::scatter4
+def int_nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d
+ : DefaultAttrsIntrinsicFlags<[],
+ !listconcat([llvm_shared_ptr_ty, // src_smem_ptr
+ llvm_ptr_ty], // tensormap_ptr
+ !listsplat(llvm_i32_ty, 5), // dims
+ [llvm_i64_ty]), // cache_hint
+ [llvm_i1_ty], // Flag for cache_hint
+ [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>]>;
+
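
At the IR level this definition corresponds to a declaration along the following lines (a sketch: the name assumes the usual int_-record-to-IR-name convention, and the trailing i1 is the compile-time cache-hint flag that DefaultAttrsIntrinsicFlags marks immarg):

    declare void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.scatter4.2d(
        ptr addrspace(3),          ; src_smem_ptr
        ptr,                       ; tensormap_ptr
        i32, i32, i32, i32, i32,   ; tensor coordinates
        i64,                       ; cache_hint
        i1 immarg)                 ; flag for cache_hint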
// TMA Tensor Copy Intrinsics: G2S -> From Global to Shared memory variants
foreach dim = 1...5 in {
defvar tensor_dim_args = !listsplat(llvm_i32_ty, dim);
- foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
+ foreach mode = !if(!ge(dim, 3), ["tile", "im2col", "im2col_w", "im2col_w_128"], ["tile"]) in {
defvar is_im2col = !eq(mode, "im2col");
- defvar num_im2col_offsets = !if(is_im2col, !add(dim, -2), 0);
+ defvar is_im2colw = !or(!eq(mode, "im2col_w"), !eq(mode, "im2col_w_128"));
+
+  // For im2col_w/im2col_w_128 modes, num_offsets is always 2.
+  // For im2col mode, num_offsets is (dim - 2).
+ defvar num_im2col_offsets = !if(is_im2colw, 2, !if(is_im2col, !add(dim, -2), 0));
defvar im2col_offsets_args = !listsplat(llvm_i16_ty, num_im2col_offsets);
defvar g2s_params = !listconcat(
@@ -2079,11 +2155,60 @@ foreach dim = 1...5 in {
im2col_offsets_args, // im2col offsets
[llvm_i64_ty]), // cache_hint
[llvm_i1_ty], // Flag for cache_hint
- [IntrConvergent,
- ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>]>;
+ [IntrConvergent, ReadOnly<ArgIndex<0>>]>;
+
+ def int_nvvm_cp_async_bulk_tensor_g2s_cta_ # mode # _ # dim # d :
+ DefaultAttrsIntrinsicFlags<[],
+ !listconcat([llvm_shared_ptr_ty, // dst_ptr
+ llvm_shared_ptr_ty, // mbarrier_ptr
+ llvm_ptr_ty], // tensormap_ptr
+ tensor_dim_args, // actual tensor dims
+ im2col_offsets_args, // im2col offsets
+ [llvm_i64_ty]), // cache_hint
+ [llvm_i1_ty], // Flag for cache_hint
+ [IntrConvergent, WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>]>;
}
}
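
The widened mode list changes the operand count per mode: for dim = 5, tile carries no im2col offsets, im2col carries dim - 2 = 3, and im2col_w/im2col_w_128 always carry 2. As a sketch of one resulting CTA-scope intrinsic (the IR name assumes the usual underscore-to-dot mangling of the record name):

    declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.cta.im2col.w.5d(
        ptr addrspace(3),          ; dst_ptr
        ptr addrspace(3),          ; mbarrier_ptr
        ptr,                       ; tensormap_ptr
        i32, i32, i32, i32, i32,   ; tensor dims
        i16, i16,                  ; im2col offsets (always 2 in im2col_w modes)
        i64,                       ; cache_hint
        i1 immarg)                 ; flag for cache_hint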
+// TMA copy for tile::gather4
+def int_nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d
+ : DefaultAttrsIntrinsicFlags<[],
+ !listconcat(
+ [llvm_shared_cluster_ptr_ty, // dst_shared_cluster_ptr
+ llvm_shared_ptr_ty, // mbarrier_ptr
+ llvm_ptr_ty], // tensormap_ptr
+ !listsplat(llvm_i32_ty, 5), // co-ordinates
+ [llvm_i16_ty, // cta_mask
+ llvm_i64_ty]), // cache_hint
+ [llvm_i1_ty, // Flag for cta_mask
+ llvm_i1_ty, // Flag for cache_hint
+ llvm_i32_ty], // Flag for cta_group
+ [IntrConvergent,
+ WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>,
+     // Allowed values for cta_group are {0,1,2}, i.e., [0, 3).
+ Range<ArgIndex<12>, 0, 3>]>;
+
+def int_nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d
+ : DefaultAttrsIntrinsicFlags<[],
+ !listconcat(
+ [llvm_shared_ptr_ty, // dst_shared_ptr
+ llvm_shared_ptr_ty, // mbarrier_ptr
+ llvm_ptr_ty], // tensormap_ptr
+ !listsplat(llvm_i32_ty, 5), // co-ordinates
+ [llvm_i64_ty]), // cache_hint
+ [llvm_i1_ty], // Flag for cache_hint
+ [IntrConvergent,
+ WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<2>>]>;
+
+// TMA prefetch for tile::gather4
+def int_nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d
+ : DefaultAttrsIntrinsicFlags<[],
+ !listconcat([llvm_ptr_ty], // tensormap_ptr
+ !listsplat(llvm_i32_ty, 5), // co-ordinates
+ [llvm_i64_ty]), // cache_hint
+ [llvm_i1_ty], // Flag for cache_hint
+ [IntrConvergent, ReadOnly<ArgIndex<0>>]>;
+
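
Tying the three gather4 definitions together at the IR level: the G2S form carries the full flag set, the CTA form drops cta_mask and cta_group, and the prefetch form keeps only the tensormap pointer, co-ordinates, and cache hint. A declaration sketch for the G2S variant, assuming shared_cluster pointers map to addrspace(7) and that the trailing flags are immarg as usual for DefaultAttrsIntrinsicFlags:

    declare void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.gather4.2d(
        ptr addrspace(7),          ; dst_shared_cluster_ptr
        ptr addrspace(3),          ; mbarrier_ptr
        ptr,                       ; tensormap_ptr
        i32, i32, i32, i32, i32,   ; co-ordinates
        i16,                       ; cta_mask
        i64,                       ; cache_hint
        i1 immarg,                 ; flag for cta_mask
        i1 immarg,                 ; flag for cache_hint
        i32 immarg)                ; cta_group, Range-constrained to [0, 3)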
// Intrinsics for Prefetch and Prefetchu
let IntrProperties = [IntrArgMemOnly, ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>] in {
foreach level = ["L1", "L2"] in {