aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXIntrinsics.td')
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td842
1 files changed, 481 insertions, 361 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 70150bd..d337192 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -243,63 +243,82 @@ foreach sync = [false, true] in {
}
// vote.{all,any,uni,ballot}
-multiclass VOTE<NVPTXRegClass regclass, string mode, Intrinsic IntOp> {
- def : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred),
- "vote." # mode,
- [(set regclass:$dest, (IntOp i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
-}
+let Predicates = [hasPTX<60>, hasSM<30>] in {
+ multiclass VOTE<string mode, RegTyInfo t, Intrinsic op> {
+ def : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred),
+ "vote." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op i1:$pred))]>;
+ }
-defm VOTE_ALL : VOTE<B1, "all.pred", int_nvvm_vote_all>;
-defm VOTE_ANY : VOTE<B1, "any.pred", int_nvvm_vote_any>;
-defm VOTE_UNI : VOTE<B1, "uni.pred", int_nvvm_vote_uni>;
-defm VOTE_BALLOT : VOTE<B32, "ballot.b32", int_nvvm_vote_ballot>;
+ defm VOTE_ALL : VOTE<"all", I1RT, int_nvvm_vote_all>;
+ defm VOTE_ANY : VOTE<"any", I1RT, int_nvvm_vote_any>;
+ defm VOTE_UNI : VOTE<"uni", I1RT, int_nvvm_vote_uni>;
+ defm VOTE_BALLOT : VOTE<"ballot", I32RT, int_nvvm_vote_ballot>;
+
+ // vote.sync.{all,any,uni,ballot}
+ multiclass VOTE_SYNC<string mode, RegTyInfo t, Intrinsic op> {
+ def i : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, i32imm:$mask),
+ "vote.sync." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op imm:$mask, i1:$pred))]>;
+ def r : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, B32:$mask),
+ "vote.sync." # mode # "." # t.PtxType,
+ [(set t.Ty:$dest, (op i32:$mask, i1:$pred))]>;
+ }
-// vote.sync.{all,any,uni,ballot}
-multiclass VOTE_SYNC<NVPTXRegClass regclass, string mode, Intrinsic IntOp> {
- def i : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, i32imm:$mask),
- "vote.sync." # mode,
- [(set regclass:$dest, (IntOp imm:$mask, i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
- def r : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, B32:$mask),
- "vote.sync." # mode,
- [(set regclass:$dest, (IntOp i32:$mask, i1:$pred))]>,
- Requires<[hasPTX<60>, hasSM<30>]>;
+ defm VOTE_SYNC_ALL : VOTE_SYNC<"all", I1RT, int_nvvm_vote_all_sync>;
+ defm VOTE_SYNC_ANY : VOTE_SYNC<"any", I1RT, int_nvvm_vote_any_sync>;
+ defm VOTE_SYNC_UNI : VOTE_SYNC<"uni", I1RT, int_nvvm_vote_uni_sync>;
+ defm VOTE_SYNC_BALLOT : VOTE_SYNC<"ballot", I32RT, int_nvvm_vote_ballot_sync>;
}
-
-defm VOTE_SYNC_ALL : VOTE_SYNC<B1, "all.pred", int_nvvm_vote_all_sync>;
-defm VOTE_SYNC_ANY : VOTE_SYNC<B1, "any.pred", int_nvvm_vote_any_sync>;
-defm VOTE_SYNC_UNI : VOTE_SYNC<B1, "uni.pred", int_nvvm_vote_uni_sync>;
-defm VOTE_SYNC_BALLOT : VOTE_SYNC<B32, "ballot.b32", int_nvvm_vote_ballot_sync>;
-
// elect.sync
+let Predicates = [hasPTX<80>, hasSM<90>] in {
def INT_ELECT_SYNC_I : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins i32imm:$mask),
"elect.sync",
- [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>;
def INT_ELECT_SYNC_R : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins B32:$mask),
"elect.sync",
- [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>;
+}
+
+let Predicates = [hasPTX<60>, hasSM<70>] in {
+ multiclass MATCH_ANY_SYNC<Intrinsic op, RegTyInfo t> {
+ def ii : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, i32imm:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op imm:$mask, imm:$value))]>;
+ def ir : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, B32:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op i32:$mask, imm:$value))]>;
+ def ri : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, i32imm:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op imm:$mask, t.Ty:$value))]>;
+ def rr : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, B32:$mask),
+ "match.any.sync." # t.PtxType,
+ [(set i32:$dest, (op i32:$mask, t.Ty:$value))]>;
+ }
-multiclass MATCH_ANY_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp,
- Operand ImmOp> {
- def ii : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, i32imm:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp imm:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ir : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, B32:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp i32:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ri : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, i32imm:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp imm:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def rr : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, B32:$mask),
- "match.any.sync." # ptxtype,
- [(set i32:$dest, (IntOp i32:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
+ defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i32, I32RT>;
+ defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i64, I64RT>;
+
+ multiclass MATCH_ALLP_SYNC<RegTyInfo t, Intrinsic op> {
+ def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.Imm:$value, i32imm:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op imm:$mask, imm:$value))]>;
+ def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.Imm:$value, B32:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op i32:$mask, imm:$value))]>;
+ def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.RC:$value, i32imm:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op imm:$mask, t.Ty:$value))]>;
+ def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
+ (ins t.RC:$value, B32:$mask),
+ "match.all.sync." # t.PtxType,
+ [(set i32:$dest, i1:$pred, (op i32:$mask, t.Ty:$value))]>;
+ }
+ defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<I32RT, int_nvvm_match_all_sync_i32p>;
+ defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<I64RT, int_nvvm_match_all_sync_i64p>;
}
// activemask.b32
@@ -308,39 +327,6 @@ def ACTIVEMASK : BasicNVPTXInst<(outs B32:$dest), (ins),
[(set i32:$dest, (int_nvvm_activemask))]>,
Requires<[hasPTX<62>, hasSM<30>]>;
-defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<B32, "b32", int_nvvm_match_any_sync_i32,
- i32imm>;
-defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<B64, "b64", int_nvvm_match_any_sync_i64,
- i64imm>;
-
-multiclass MATCH_ALLP_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp,
- Operand ImmOp> {
- def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins ImmOp:$value, i32imm:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp imm:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins ImmOp:$value, B32:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp i32:$mask, imm:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins regclass:$value, i32imm:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp imm:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
- def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred),
- (ins regclass:$value, B32:$mask),
- "match.all.sync." # ptxtype,
- [(set i32:$dest, i1:$pred, (IntOp i32:$mask, regclass:$value))]>,
- Requires<[hasPTX<60>, hasSM<70>]>;
-}
-defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<B32, "b32", int_nvvm_match_all_sync_i32p,
- i32imm>;
-defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<B64, "b64", int_nvvm_match_all_sync_i64p,
- i64imm>;
-
multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> {
def : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src, B32:$mask),
"redux.sync." # BinOp # "." # PTXType,
@@ -381,24 +367,20 @@ defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">;
//-----------------------------------
// Explicit Memory Fence Functions
//-----------------------------------
-class MEMBAR<string StrOp, Intrinsic IntOP> :
- BasicNVPTXInst<(outs), (ins),
- StrOp, [(IntOP)]>;
+class NullaryInst<string StrOp, Intrinsic IntOP> :
+ BasicNVPTXInst<(outs), (ins), StrOp, [(IntOP)]>;
-def INT_MEMBAR_CTA : MEMBAR<"membar.cta", int_nvvm_membar_cta>;
-def INT_MEMBAR_GL : MEMBAR<"membar.gl", int_nvvm_membar_gl>;
-def INT_MEMBAR_SYS : MEMBAR<"membar.sys", int_nvvm_membar_sys>;
+def INT_MEMBAR_CTA : NullaryInst<"membar.cta", int_nvvm_membar_cta>;
+def INT_MEMBAR_GL : NullaryInst<"membar.gl", int_nvvm_membar_gl>;
+def INT_MEMBAR_SYS : NullaryInst<"membar.sys", int_nvvm_membar_sys>;
def INT_FENCE_SC_CLUSTER:
- MEMBAR<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
+ NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
Requires<[hasPTX<78>, hasSM<90>]>;
// Proxy fence (uni-directional)
-// fence.proxy.tensormap.release variants
-
class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> :
- BasicNVPTXInst<(outs), (ins),
- "fence.proxy.tensormap::generic.release." # Scope, [(Intr)]>,
+ NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>,
Requires<[hasPTX<83>, hasSM<90>]>;
def INT_FENCE_PROXY_TENSORMAP_GENERIC_RELEASE_CTA:
@@ -488,35 +470,31 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16,
int_nvvm_cp_async_cg_shared_global_16_s>;
-def CP_ASYNC_COMMIT_GROUP :
- BasicNVPTXInst<(outs), (ins), "cp.async.commit_group", [(int_nvvm_cp_async_commit_group)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+let Predicates = [hasPTX<70>, hasSM<80>] in {
+ def CP_ASYNC_COMMIT_GROUP :
+ NullaryInst<"cp.async.commit_group", int_nvvm_cp_async_commit_group>;
-def CP_ASYNC_WAIT_GROUP :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group",
- [(int_nvvm_cp_async_wait_group timm:$n)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+ def CP_ASYNC_WAIT_GROUP :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group",
+ [(int_nvvm_cp_async_wait_group timm:$n)]>;
-def CP_ASYNC_WAIT_ALL :
- BasicNVPTXInst<(outs), (ins), "cp.async.wait_all",
- [(int_nvvm_cp_async_wait_all)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+ def CP_ASYNC_WAIT_ALL :
+ NullaryInst<"cp.async.wait_all", int_nvvm_cp_async_wait_all>;
+}
-// cp.async.bulk variants of the commit/wait group
-def CP_ASYNC_BULK_COMMIT_GROUP :
- BasicNVPTXInst<(outs), (ins), "cp.async.bulk.commit_group",
- [(int_nvvm_cp_async_bulk_commit_group)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+let Predicates = [hasPTX<80>, hasSM<90>] in {
+ // cp.async.bulk variants of the commit/wait group
+ def CP_ASYNC_BULK_COMMIT_GROUP :
+ NullaryInst<"cp.async.bulk.commit_group", int_nvvm_cp_async_bulk_commit_group>;
-def CP_ASYNC_BULK_WAIT_GROUP :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group",
- [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def CP_ASYNC_BULK_WAIT_GROUP :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group",
+ [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>;
-def CP_ASYNC_BULK_WAIT_GROUP_READ :
- BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read",
- [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def CP_ASYNC_BULK_WAIT_GROUP_READ :
+ BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read",
+ [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>;
+}
//------------------------------
// TMA Async Bulk Copy Functions
@@ -600,12 +578,23 @@ defm CP_ASYNC_BULK_PREFETCH_CH : CP_ASYNC_BULK_PREFETCH_INTR<has_ch = 1>;
// TMA Async Bulk Tensor Copy Functions
//-------------------------------------
-class TMA_DIMS_UTIL<int dim> {
+class TMA_DIMS_UTIL<int dim, string mode = ""> {
// For example, when 'dim' is 3, this generates:
// an ins_dag: B32:$d0, B32:$d1, B32:$d2
// with base_str: $d0, $d1, $d2
dag ins_dag = !dag(ins, !listsplat(B32, dim), !foreach(i, !range(dim), "d" # i));
string base_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
+
+ // Tile::Gather4/scatter4 actually operate on a 2D tensor,
+ // though they take 5 co-ordinates.
+ //
+ // The scatter-gather happens over 4 rows with a fixed
+ // column-index. The first co-ordinate represents the
+ // col-index followed by four row-indices.
+ int num_dims = !cond(
+ !eq(mode, "tile_scatter4") : 2,
+ !eq(mode, "tile_gather4") : 2,
+ true : dim); // for all other modes
}
class TMA_IM2COL_UTIL<int dim, string mode> {
@@ -692,14 +681,138 @@ foreach dim = [1, 2, 3, 4, 5] in {
}
}
+multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> {
+ defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
+ defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
+ defvar asm_str_base = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
+
+ defvar im2col_dag = TMA_IM2COL_UTIL<dim, mode>.ins_dag;
+ defvar im2col_str = TMA_IM2COL_UTIL<dim, mode>.base_str;
+ defvar asm_str = !if(!empty(im2col_str),
+ asm_str_base,
+ asm_str_base # ", {{" # im2col_str # "}}");
+
+ defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims;
+ defvar inst_name = "cp.async.bulk.tensor"
+ # "." # dim_val # "d"
+ # "." # "shared::cluster.global"
+ # "." # !subst("_", "::", mode)
+ # "." # "mbarrier::complete_tx::bytes";
+ defvar intr = !cast<Intrinsic>(
+ "int_nvvm_cp_async_bulk_tensor_g2s_" # mode # "_" # dim_val # "d");
+
+ defvar ins_dag = !con(
+ (ins ADDR:$dst, ADDR:$mbar, B64:$tmap),
+ dims_dag, im2col_dag,
+ (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg));
+
+ defvar intr_dag_base = !con(
+ (intr addr:$dst, addr:$mbar, B64:$tmap),
+ !setdagop(dims_dag, intr),
+ !setdagop(im2col_dag, intr),
+ (intr B16:$mc, B64:$ch));
+ defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, timm:$cg));
+ defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, timm:$cg));
+ defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, timm:$cg));
+ defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, timm:$cg));
+
+ def "" : NVPTXInst<(outs), ins_dag,
+ inst_name # asm_str # ";",
+ [intr_dag_no_hints]>,
+ Requires<pred>;
+ def _MC : NVPTXInst<(outs), ins_dag,
+ inst_name # ".multicast::cluster" # asm_str # ", $mc;",
+ [intr_dag_with_mc]>,
+ Requires<pred>;
+ def _CH : NVPTXInst<(outs), ins_dag,
+ inst_name # ".L2::cache_hint" # asm_str # ", $ch;",
+ [intr_dag_with_ch]>,
+ Requires<pred>;
+ def _MC_CH : NVPTXInst<(outs), ins_dag,
+ inst_name # ".multicast::cluster.L2::cache_hint" # asm_str # ", $mc, $ch;",
+ [intr_dag_with_mc_ch]>,
+ Requires<pred>;
+}
+foreach dim = 3...5 in {
+ foreach mode = ["im2col_w", "im2col_w_128"] in {
+ defm TMA_G2S_ # !toupper(mode) # "_" # dim # "D"
+ : TMA_TENSOR_G2S_INTR<dim, mode, [hasTMACTAGroupSupport]>;
+ }
+}
+defm TMA_G2S_TILE_GATHER4_2D : TMA_TENSOR_G2S_INTR<5, "tile_gather4",
+ [hasTMACTAGroupSupport]>;
+
+multiclass TMA_TENSOR_G2S_CTA_INTR<int dim, string mode, list<Predicate> pred = []> {
+ defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
+ defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
+ defvar asm_str_base = " [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
+
+ defvar im2col_dag = TMA_IM2COL_UTIL<dim, mode>.ins_dag;
+ defvar im2col_str = TMA_IM2COL_UTIL<dim, mode>.base_str;
+ defvar asm_str = !if(!empty(im2col_str),
+ asm_str_base,
+ asm_str_base # ", {{" # im2col_str # "}}");
+
+ defvar ins_dag = !con(
+ (ins ADDR:$dst, ADDR:$mbar, B64:$tmap),
+ dims_dag, im2col_dag,
+ (ins B64:$ch));
+
+ defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims;
+ defvar intr = !cast<Intrinsic>(
+ "int_nvvm_cp_async_bulk_tensor_g2s_cta_" # mode # "_" # dim_val # "d");
+ defvar intr_dag = !con(
+ (intr addr:$dst, addr:$mbar, B64:$tmap),
+ !setdagop(dims_dag, intr),
+ !setdagop(im2col_dag, intr),
+ (intr B64:$ch, 0));
+ defvar intr_dag_with_ch = !con(
+ (intr addr:$dst, addr:$mbar, B64:$tmap),
+ !setdagop(dims_dag, intr),
+ !setdagop(im2col_dag, intr),
+ (intr B64:$ch, -1));
+ defvar inst_name = "cp.async.bulk.tensor"
+ # "." # dim_val # "d"
+ # "." # "shared::cta.global"
+ # "." # !subst("_", "::", mode)
+ # "." # "mbarrier::complete_tx::bytes";
+
+ def "" : NVPTXInst<(outs), ins_dag,
+ inst_name # asm_str # ";",
+ [intr_dag]>,
+ Requires<pred>;
+ def _CH : NVPTXInst<(outs), ins_dag,
+ inst_name # ".L2::cache_hint" # asm_str # ", $ch;",
+ [intr_dag_with_ch]>,
+ Requires<pred>;
+}
+foreach dim = 1...5 in {
+ defm TMA_G2S_CTA_TILE_ # dim # "D"
+ : TMA_TENSOR_G2S_CTA_INTR<dim, "tile", [hasPTX<86>, hasSM<90>]>;
+}
+foreach dim = 3...5 in {
+ defm TMA_G2S_CTA_IM2COL_ # dim # "D"
+ : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col", [hasPTX<86>, hasSM<90>]>;
+
+ defm TMA_G2S_CTA_IM2COL_W_ # dim # "D"
+ : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w", [hasPTX<86>, hasSM<100>]>;
+
+ defm TMA_G2S_CTA_IM2COL_W_128_ # dim # "D"
+ : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", [hasTMACTAGroupSupport]>;
+}
+defm TMA_G2S_CTA_TILE_GATHER4_2D : TMA_TENSOR_G2S_CTA_INTR<5, "tile_gather4",
+ [hasPTX<86>, hasSM<100>]>;
+
multiclass TMA_TENSOR_S2G_INTR<int dim, string mode,
list<Predicate> pred = [hasPTX<80>, hasSM<90>]> {
defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
defvar asm_str = " [$tmap, {{" # dims_str # "}}], [$src]";
+ defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims;
defvar intr = !cast<Intrinsic>(
- "int_nvvm_cp_async_bulk_tensor_s2g_" # mode # "_" # dim # d);
+ "int_nvvm_cp_async_bulk_tensor_s2g_" # mode # "_" # dim_val # "d");
+
defvar intr_dag = !con((intr addr:$src, B64:$tmap),
!setdagop(dims_dag, intr),
(intr B64:$ch, 0));
@@ -707,11 +820,13 @@ multiclass TMA_TENSOR_S2G_INTR<int dim, string mode,
!setdagop(dims_dag, intr),
(intr B64:$ch, -1));
- // For im2col mode, the actual asm_str is "im2col_no_offs"
- defvar mode_asm_str = !if(!eq(mode, "im2col"),
- "im2col_no_offs", mode);
+ // Fix-up the asm_str when it is im2col/scatter4.
+ defvar mode_asm_str = !cond(
+ !eq(mode, "im2col") : "im2col_no_offs",
+ !eq(mode, "tile_scatter4") : "tile::scatter4",
+ true : mode);
defvar prefix = "cp.async.bulk.tensor"
- # "." # dim # "d"
+ # "." # dim_val # "d"
# ".global.shared::cta"
# "." # mode_asm_str
# ".bulk_group";
@@ -729,10 +844,12 @@ multiclass TMA_TENSOR_S2G_INTR<int dim, string mode,
}
foreach dim = 1...5 in {
foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
- defvar suffix = !toupper(mode) # "_" # dim # D;
+ defvar suffix = !toupper(mode) # "_" # dim # "D";
defm TMA_TENSOR_S2G_ # suffix : TMA_TENSOR_S2G_INTR<dim, mode>;
}
}
+defm TMA_S2G_TILE_SCATTER4_2D : TMA_TENSOR_S2G_INTR<5, "tile_scatter4",
+ [hasTMACTAGroupSupport]>;
def TMAReductionFlags : Operand<i32> {
let PrintMethod = "printTmaReductionMode";
@@ -786,13 +903,14 @@ multiclass TMA_TENSOR_PREFETCH_INTR<int dim, string mode,
asm_str_base,
asm_str_base # ", {{" # im2col_str # "}}");
+ defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims;
defvar inst_name = "cp.async.bulk.prefetch.tensor"
- # "." # dim # "d"
+ # "." # dim_val # "d"
# "." # "L2.global"
- # "." # mode;
+ # "." # !subst("_", "::", mode);
defvar intr = !cast<Intrinsic>(
- "int_nvvm_cp_async_bulk_tensor_prefetch_" # mode # "_" # dim # d);
+ "int_nvvm_cp_async_bulk_tensor_prefetch_" # mode # "_" # dim_val # "d");
defvar ins_dag = !con((ins B64:$tmap),
dims_dag,
@@ -818,40 +936,46 @@ multiclass TMA_TENSOR_PREFETCH_INTR<int dim, string mode,
}
foreach dim = 1...5 in {
foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
- defvar suffix = !toupper(mode) # "_" # dim # D;
+ defvar suffix = !toupper(mode) # "_" # dim # "D";
defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode>;
}
}
+foreach dim = 3...5 in {
+ foreach mode = ["im2col_w", "im2col_w_128"] in {
+ defvar suffix = !toupper(mode) # "_" # dim # "D";
+ defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode,
+ [hasTMACTAGroupSupport]>;
+ }
+}
+defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4",
+ [hasTMACTAGroupSupport]>;
//Prefetch and Prefetchu
-class PREFETCH_INTRS<string InstName> :
- BasicNVPTXInst<(outs), (ins ADDR:$addr),
- InstName,
- [(!cast<Intrinsic>(!strconcat("int_nvvm_",
- !subst(".", "_", InstName))) addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
-
-
-def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">;
-def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">;
-def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">;
-def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">;
-def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">;
-def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">;
+let Predicates = [hasPTX<80>, hasSM<90>] in {
+ class PREFETCH_INTRS<string InstName> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ InstName,
+ [(!cast<Intrinsic>(!strconcat("int_nvvm_",
+ !subst(".", "_", InstName))) addr:$addr)]>;
-def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_normal",
- [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">;
+ def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">;
+ def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">;
+ def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">;
+ def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">;
+ def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">;
-def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_last",
- [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "prefetch.global.L2::evict_normal",
+ [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>;
+ def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "prefetch.global.L2::evict_last",
+ [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>;
-def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">;
+ def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">;
+}
//Applypriority intrinsics
class APPLYPRIORITY_L2_INTRS<string addrspace> :
@@ -882,99 +1006,82 @@ def DISCARD_GLOBAL_L2 : DISCARD_L2_INTRS<"global">;
// MBarrier Functions
//-----------------------------------
-multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count),
- "mbarrier.init" # AddrSpace # ".b64",
- [(Intrin addr:$addr, i32:$count)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>;
-defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared",
- int_nvvm_mbarrier_init_shared>;
-
-multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "mbarrier.inval" # AddrSpace # ".b64",
- [(Intrin addr:$addr)]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>;
-defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared",
- int_nvvm_mbarrier_inval_shared>;
-
-multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
- "mbarrier.arrive" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>;
-defm MBARRIER_ARRIVE_SHARED :
- MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>;
-
-multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state),
- (ins ADDR:$addr, B32:$count),
- "mbarrier.arrive.noComplete" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr, i32:$count))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_NOCOMPLETE :
- MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>;
-defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED :
- MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>;
-
-multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
- "mbarrier.arrive_drop" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_DROP :
- MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>;
-defm MBARRIER_ARRIVE_DROP_SHARED :
- MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>;
-
-multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B64:$state),
- (ins ADDR:$addr, B32:$count),
- "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64",
- [(set i64:$state, (Intrin addr:$addr, i32:$count))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-}
-
-defm MBARRIER_ARRIVE_DROP_NOCOMPLETE :
- MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>;
-defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED :
- MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared",
- int_nvvm_mbarrier_arrive_drop_noComplete_shared>;
-
-multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> {
- def "" : BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state),
- "mbarrier.test_wait" # AddrSpace # ".b64",
- [(set i1:$res, (Intrin addr:$addr, i64:$state))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
+let Predicates = [hasPTX<70>, hasSM<80>] in {
+ class MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count),
+ "mbarrier.init" # AddrSpace # ".b64",
+ [(Intrin addr:$addr, i32:$count)]>;
+
+ def MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>;
+ def MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared",
+ int_nvvm_mbarrier_init_shared>;
+
+ class MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "mbarrier.inval" # AddrSpace # ".b64",
+ [(Intrin addr:$addr)]>;
+
+ def MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>;
+ def MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared",
+ int_nvvm_mbarrier_inval_shared>;
+
+ class MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
+ "mbarrier.arrive" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr))]>;
+
+ def MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>;
+ def MBARRIER_ARRIVE_SHARED :
+ MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>;
+
+ class MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state),
+ (ins ADDR:$addr, B32:$count),
+ "mbarrier.arrive.noComplete" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr, i32:$count))]>;
+
+ def MBARRIER_ARRIVE_NOCOMPLETE :
+ MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>;
+ def MBARRIER_ARRIVE_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>;
+
+ class MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr),
+ "mbarrier.arrive_drop" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr))]>;
+
+ def MBARRIER_ARRIVE_DROP :
+ MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>;
+ def MBARRIER_ARRIVE_DROP_SHARED :
+ MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>;
+
+ class MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B64:$state),
+ (ins ADDR:$addr, B32:$count),
+ "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64",
+ [(set i64:$state, (Intrin addr:$addr, i32:$count))]>;
+
+ def MBARRIER_ARRIVE_DROP_NOCOMPLETE :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>;
+ def MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED :
+ MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared",
+ int_nvvm_mbarrier_arrive_drop_noComplete_shared>;
+
+ class MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> :
+ BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state),
+ "mbarrier.test_wait" # AddrSpace # ".b64",
+ [(set i1:$res, (Intrin addr:$addr, i64:$state))]>;
+
+ def MBARRIER_TEST_WAIT :
+ MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>;
+ def MBARRIER_TEST_WAIT_SHARED :
+ MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>;
+
+ def MBARRIER_PENDING_COUNT :
+ BasicNVPTXInst<(outs B32:$res), (ins B64:$state),
+ "mbarrier.pending_count.b64",
+ [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>;
}
-
-defm MBARRIER_TEST_WAIT :
- MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>;
-defm MBARRIER_TEST_WAIT_SHARED :
- MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>;
-
-class MBARRIER_PENDING_COUNT<Intrinsic Intrin> :
- BasicNVPTXInst<(outs B32:$res), (ins B64:$state),
- "mbarrier.pending_count.b64",
- [(set i32:$res, (Intrin i64:$state))]>,
- Requires<[hasPTX<70>, hasSM<80>]>;
-
-def MBARRIER_PENDING_COUNT :
- MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>;
-
//-----------------------------------
// Math Functions
//-----------------------------------
@@ -1300,15 +1407,11 @@ defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>;
def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>;
-def COPYSIGN_F :
- BasicNVPTXInst<(outs B32:$dst), (ins B32:$src0, B32:$src1),
- "copysign.f32",
- [(set f32:$dst, (fcopysign_nvptx f32:$src1, f32:$src0))]>;
-
-def COPYSIGN_D :
- BasicNVPTXInst<(outs B64:$dst), (ins B64:$src0, B64:$src1),
- "copysign.f64",
- [(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>;
+foreach t = [F32RT, F64RT] in
+ def COPYSIGN_ # t :
+ BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src0, t.RC:$src1),
+ "copysign." # t.PtxType,
+ [(set t.Ty:$dst, (fcopysign_nvptx t.Ty:$src1, t.Ty:$src0))]>;
//
// Neg bf16, bf16x2
@@ -2106,38 +2209,35 @@ defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">;
// Scalar
-class LDU_G<string TyStr, NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$result), (ins ADDR:$src),
- "ldu.global." # TyStr # " \t$result, [$src];", []>;
+class LDU_G<NVPTXRegClass regclass>
+ : NVPTXInst<(outs regclass:$result), (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.b$fromWidth \t$result, [$src];", []>;
-def LDU_GLOBAL_i8 : LDU_G<"b8", B16>;
-def LDU_GLOBAL_i16 : LDU_G<"b16", B16>;
-def LDU_GLOBAL_i32 : LDU_G<"b32", B32>;
-def LDU_GLOBAL_i64 : LDU_G<"b64", B64>;
+def LDU_GLOBAL_i16 : LDU_G<B16>;
+def LDU_GLOBAL_i32 : LDU_G<B32>;
+def LDU_GLOBAL_i64 : LDU_G<B64>;
// vector
// Elementized vector ldu
-class VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass>
+class VLDU_G_ELE_V2<NVPTXRegClass regclass>
: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
- (ins ADDR:$src),
- "ldu.global.v2." # TyStr # " \t{{$dst1, $dst2}}, [$src];", []>;
+ (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.v2.b$fromWidth \t{{$dst1, $dst2}}, [$src];", []>;
-class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins ADDR:$src),
- "ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
+class VLDU_G_ELE_V4<NVPTXRegClass regclass>
+ : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
+ (ins i32imm:$fromWidth, ADDR:$src),
+ "ldu.global.v4.b$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>;
-def LDU_GLOBAL_v2i8 : VLDU_G_ELE_V2<"b8", B16>;
-def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<"b16", B16>;
-def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<"b32", B32>;
-def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<"b64", B64>;
+def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<B16>;
+def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<B32>;
+def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<B64>;
-def LDU_GLOBAL_v4i8 : VLDU_G_ELE_V4<"b8", B16>;
-def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<"b16", B16>;
-def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<"b32", B32>;
+def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<B16>;
+def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
//-----------------------------------
@@ -2178,12 +2278,10 @@ class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
"ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>;
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
-def LD_GLOBAL_NC_v2i8 : VLDG_G_ELE_V2<B16>;
def LD_GLOBAL_NC_v2i16 : VLDG_G_ELE_V2<B16>;
def LD_GLOBAL_NC_v2i32 : VLDG_G_ELE_V2<B32>;
def LD_GLOBAL_NC_v2i64 : VLDG_G_ELE_V2<B64>;
-def LD_GLOBAL_NC_v4i8 : VLDG_G_ELE_V4<B16>;
def LD_GLOBAL_NC_v4i16 : VLDG_G_ELE_V4<B16>;
def LD_GLOBAL_NC_v4i32 : VLDG_G_ELE_V4<B32>;
@@ -2193,19 +2291,19 @@ def LD_GLOBAL_NC_v8i32 : VLDG_G_ELE_V8<B32>;
multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src),
- "cvta." # Str # ".u32", []>, Requires<Preds>;
+ "cvta." # Str # ".u32">, Requires<Preds>;
def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src),
- "cvta." # Str # ".u64", []>, Requires<Preds>;
+ "cvta." # Str # ".u64">, Requires<Preds>;
}
multiclass G_TO_NG<string Str, bit Supports32 = 1, list<Predicate> Preds = []> {
if Supports32 then
def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src),
- "cvta.to." # Str # ".u32", []>, Requires<Preds>;
+ "cvta.to." # Str # ".u32">, Requires<Preds>;
def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src),
- "cvta.to." # Str # ".u64", []>, Requires<Preds>;
+ "cvta.to." # Str # ".u64">, Requires<Preds>;
}
foreach space = ["local", "shared", "global", "const", "param"] in {
@@ -4465,9 +4563,9 @@ def INT_PTX_SREG_LANEMASK_GT :
PTX_READ_SREG_R32<"lanemask_gt", int_nvvm_read_ptx_sreg_lanemask_gt>;
let hasSideEffects = 1 in {
-def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>;
-def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>;
-def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>;
+ def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>;
+ def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>;
+ def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>;
}
def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>;
@@ -4609,7 +4707,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!and(!eq(op, "ldmatrix"),
!eq(ptx_elt_type, "b8x16.b4x16_p64"),
- !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+ !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+ !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
+ !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+ !and(!eq(op, "stmatrix"),
+ !eq(ptx_elt_type, "b8"),
+ !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4890,6 +4995,42 @@ defset list<WMMA_INSTR> LDMATRIXs = {
} // transposed
} // defset
+//
+// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+ : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+ Requires<Frag.Predicates> {
+ // Build PatFrag that only matches particular address space.
+ dag PFOperands = !con((ops node:$dst),
+ !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+ PatFrag IntrFrag = PatFrag<PFOperands,
+ !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
+ !cond(!eq(Space, ".shared"): AS_match.shared,
+ true: AS_match.generic)>;
+ // Build AS-constrained pattern.
+ let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+ let OutOperandList = (outs);
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ let AsmString = "stmatrix.sync.aligned."
+ # Frag.geom
+ # "." # Frag.frag
+ # !if(Transposed, ".trans", "")
+ # Space
+ # "." # Frag.ptx_elt_type
+ # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+ foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+ def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+ } // space
+ } // transposed
+} // defset
+
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
@@ -4900,41 +5041,40 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
- def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b),
- "mapa" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a, i32:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b),
- "mapa" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a, imm:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b),
- "mapa" # suffix # ".u64",
- [(set i64:$d, (Intr i64:$a, i32:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b),
- "mapa" # suffix # ".u64",
- [(set i64:$d, (Intr i64:$a, imm:$b))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+ let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b),
+ "mapa" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a, i32:$b))]>;
+ def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b),
+ "mapa" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a, imm:$b))]>;
+ def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b),
+ "mapa" # suffix # ".u64",
+ [(set i64:$d, (Intr i64:$a, i32:$b))]>;
+ def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b),
+ "mapa" # suffix # ".u64",
+ [(set i64:$d, (Intr i64:$a, imm:$b))]>;
+ }
}
+
defm mapa : MAPA<"", int_nvvm_mapa>;
defm mapa_shared_cluster : MAPA<".shared::cluster", int_nvvm_mapa_shared_cluster>;
multiclass GETCTARANK<string suffix, Intrinsic Intr> {
- def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a),
- "getctarank" # suffix # ".u32",
- [(set i32:$d, (Intr i32:$a))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
- def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a),
- "getctarank" # suffix # ".u64",
- [(set i32:$d, (Intr i64:$a))]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+ let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a),
+ "getctarank" # suffix # ".u32",
+ [(set i32:$d, (Intr i32:$a))]>;
+ def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a),
+ "getctarank" # suffix # ".u64",
+ [(set i32:$d, (Intr i64:$a))]>;
+ }
}
defm getctarank : GETCTARANK<"", int_nvvm_getctarank>;
@@ -4973,29 +5113,25 @@ def INT_NVVM_WGMMA_WAIT_GROUP_SYNC_ALIGNED : BasicNVPTXInst<(outs), (ins i64imm:
[(int_nvvm_wgmma_wait_group_sync_aligned timm:$n)]>, Requires<[hasSM90a, hasPTX<80>]>;
} // isConvergent = true
-def GRIDDEPCONTROL_LAUNCH_DEPENDENTS :
- BasicNVPTXInst<(outs), (ins),
- "griddepcontrol.launch_dependents",
- [(int_nvvm_griddepcontrol_launch_dependents)]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
-
-def GRIDDEPCONTROL_WAIT :
- BasicNVPTXInst<(outs), (ins),
- "griddepcontrol.wait",
- [(int_nvvm_griddepcontrol_wait)]>,
- Requires<[hasSM<90>, hasPTX<78>]>;
+let Predicates = [hasSM<90>, hasPTX<78>] in {
+ def GRIDDEPCONTROL_LAUNCH_DEPENDENTS :
+ BasicNVPTXInst<(outs), (ins), "griddepcontrol.launch_dependents",
+ [(int_nvvm_griddepcontrol_launch_dependents)]>;
+ def GRIDDEPCONTROL_WAIT :
+ BasicNVPTXInst<(outs), (ins), "griddepcontrol.wait",
+ [(int_nvvm_griddepcontrol_wait)]>;
+}
def INT_EXIT : BasicNVPTXInst<(outs), (ins), "exit", [(int_nvvm_exit)]>;
// Tcgen05 intrinsics
-let isConvergent = true in {
+let isConvergent = true, Predicates = [hasTcgen05Instructions] in {
multiclass TCGEN05_ALLOC_INTR<string AS, string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins ADDR:$dst, B32:$ncols),
"tcgen05.alloc.cta_group::" # num # ".sync.aligned" # AS # ".b32",
- [(Intr addr:$dst, B32:$ncols)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$dst, B32:$ncols)]>;
}
defm TCGEN05_ALLOC_CG1 : TCGEN05_ALLOC_INTR<"", "1", int_nvvm_tcgen05_alloc_cg1>;
@@ -5008,8 +5144,7 @@ multiclass TCGEN05_DEALLOC_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins B32:$tmem_addr, B32:$ncols),
"tcgen05.dealloc.cta_group::" # num # ".sync.aligned.b32",
- [(Intr B32:$tmem_addr, B32:$ncols)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr B32:$tmem_addr, B32:$ncols)]>;
}
defm TCGEN05_DEALLOC_CG1: TCGEN05_DEALLOC_INTR<"1", int_nvvm_tcgen05_dealloc_cg1>;
defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2>;
@@ -5017,19 +5152,13 @@ defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2
multiclass TCGEN05_RELINQ_PERMIT_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs), (ins),
"tcgen05.relinquish_alloc_permit.cta_group::" # num # ".sync.aligned",
- [(Intr)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr)]>;
}
defm TCGEN05_RELINQ_CG1: TCGEN05_RELINQ_PERMIT_INTR<"1", int_nvvm_tcgen05_relinq_alloc_permit_cg1>;
defm TCGEN05_RELINQ_CG2: TCGEN05_RELINQ_PERMIT_INTR<"2", int_nvvm_tcgen05_relinq_alloc_permit_cg2>;
-def tcgen05_wait_ld: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::ld.sync.aligned",
- [(int_nvvm_tcgen05_wait_ld)]>,
- Requires<[hasTcgen05Instructions]>;
-
-def tcgen05_wait_st: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::st.sync.aligned",
- [(int_nvvm_tcgen05_wait_st)]>,
- Requires<[hasTcgen05Instructions]>;
+def tcgen05_wait_ld: NullaryInst<"tcgen05.wait::ld.sync.aligned", int_nvvm_tcgen05_wait_ld>;
+def tcgen05_wait_st: NullaryInst<"tcgen05.wait::st.sync.aligned", int_nvvm_tcgen05_wait_st>;
multiclass TCGEN05_COMMIT_INTR<string AS, string num> {
defvar prefix = "tcgen05.commit.cta_group::" # num #".mbarrier::arrive::one.shared::cluster";
@@ -5040,12 +5169,10 @@ multiclass TCGEN05_COMMIT_INTR<string AS, string num> {
def "" : BasicNVPTXInst<(outs), (ins ADDR:$mbar),
prefix # ".b64",
- [(Intr addr:$mbar)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$mbar)]>;
def _MC : BasicNVPTXInst<(outs), (ins ADDR:$mbar, B16:$mc),
prefix # ".multicast::cluster.b64",
- [(IntrMC addr:$mbar, B16:$mc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrMC addr:$mbar, B16:$mc)]>;
}
defm TCGEN05_COMMIT_CG1 : TCGEN05_COMMIT_INTR<"", "1">;
@@ -5057,8 +5184,7 @@ multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> {
def "" : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr),
"tcgen05.shift.cta_group::" # num # ".down",
- [(Intr addr:$tmem_addr)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(Intr addr:$tmem_addr)]>;
}
defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>;
defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>;
@@ -5078,13 +5204,11 @@ multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> {
def _cg1 : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr, B64:$sdesc),
"tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm,
- [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>;
def _cg2 : BasicNVPTXInst<(outs),
(ins ADDR:$tmem_addr, B64:$sdesc),
"tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm,
- [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>,
- Requires<[hasTcgen05Instructions]>;
+ [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>;
}
foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
@@ -5097,17 +5221,13 @@ foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in {
}
} // isConvergent
-let hasSideEffects = 1 in {
+let hasSideEffects = 1, Predicates = [hasTcgen05Instructions] in {
-def tcgen05_fence_before_thread_sync: BasicNVPTXInst<(outs), (ins),
- "tcgen05.fence::before_thread_sync",
- [(int_nvvm_tcgen05_fence_before_thread_sync)]>,
- Requires<[hasTcgen05Instructions]>;
+ def tcgen05_fence_before_thread_sync: NullaryInst<
+ "tcgen05.fence::before_thread_sync", int_nvvm_tcgen05_fence_before_thread_sync>;
-def tcgen05_fence_after_thread_sync: BasicNVPTXInst<(outs), (ins),
- "tcgen05.fence::after_thread_sync",
- [(int_nvvm_tcgen05_fence_after_thread_sync)]>,
- Requires<[hasTcgen05Instructions]>;
+ def tcgen05_fence_after_thread_sync: NullaryInst<
+ "tcgen05.fence::after_thread_sync", int_nvvm_tcgen05_fence_after_thread_sync>;
} // hasSideEffects
@@ -5200,17 +5320,17 @@ foreach shape = ["16x64b", "16x128b", "16x256b", "32x32b", "16x32bx2"] in {
// Bulk store instructions
def st_bulk_imm : TImmLeaf<i64, [{ return Imm == 0; }]>;
-def INT_NVVM_ST_BULK_GENERIC :
- BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
- "st.bulk",
- [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>,
- Requires<[hasSM<100>, hasPTX<86>]>;
+let Predicates = [hasSM<100>, hasPTX<86>] in {
+ def INT_NVVM_ST_BULK_GENERIC :
+ BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
+ "st.bulk",
+ [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>;
-def INT_NVVM_ST_BULK_SHARED_CTA:
- BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
- "st.bulk.shared::cta",
- [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>,
- Requires<[hasSM<100>, hasPTX<86>]>;
+ def INT_NVVM_ST_BULK_SHARED_CTA:
+ BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value),
+ "st.bulk.shared::cta",
+ [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>;
+}
//
// clusterlaunchcontorl Instructions