aboutsummaryrefslogtreecommitdiff
path: root/llvm/include/llvm/IR/IntrinsicsNVVM.td
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/include/llvm/IR/IntrinsicsNVVM.td')
-rw-r--r--llvm/include/llvm/IR/IntrinsicsNVVM.td65
1 files changed, 65 insertions, 0 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 5ddc144..967d166 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+ // stmatrix b8 -> s32 @ m16n8
+ !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
);
}
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
!subst("llvm.", "int_", intr));
}
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+ string intr = "llvm.nvvm.stmatrix.sync.aligned"
+ # "." # Frag.geom
+ # "." # Frag.frag
+ # !if(Trans, ".trans", "")
+ # "." # Frag.ptx_elt_type
+ ;
+ string record = !subst(".", "_",
+ !subst("llvm.", "int_", intr));
+}
+
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
// Geom: list of supported geometries.
// TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<string> ops = !foreach(x, ret, x.gft);
}
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+ list<WMMA_REGS> ret =
+ !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+ !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+ !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+ [WMMA_REGS<geom, frag, type>]))))));
+ // Debugging aid for readable representation of the list above.
+ list<string> ops = !foreach(x, ret, x.gft);
+}
+
// Creates list of valid combinations of fragments. This is the main list that
// drives generation of corresponding intrinsics and instructions.
class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+ list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+ ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+ list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
ldmatrix_geom_m16n16_ops,
ldmatrix_geom_m8n16_ops);
+
+ list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+ stmatrix_b8_ops);
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
);
}
+// Returns true if the fragment is valid for stmatrix ops is supported;
+// false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+ string g = frag.geom;
+ string t = frag.ptx_elt_type;
+
+ bit ret = !cond(
+ !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+ !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+ true: false
+ );
+}
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
}
}
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+ : Intrinsic<[],
+ !listconcat([llvm_anyptr_ty], Frag.regs),
+ [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+ WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+ STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+ def STMATRIX_NAME<frag, transposed>.record
+ : NVVM_STMATRIX<frag, transposed>;
+ }
+ }
+}
+
// MAPA
let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
def int_nvvm_mapa