diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXIntrinsics.td')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 842 |
1 files changed, 481 insertions, 361 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 70150bd..d337192 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -243,63 +243,82 @@ foreach sync = [false, true] in { } // vote.{all,any,uni,ballot} -multiclass VOTE<NVPTXRegClass regclass, string mode, Intrinsic IntOp> { - def : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred), - "vote." # mode, - [(set regclass:$dest, (IntOp i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; -} +let Predicates = [hasPTX<60>, hasSM<30>] in { + multiclass VOTE<string mode, RegTyInfo t, Intrinsic op> { + def : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred), + "vote." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op i1:$pred))]>; + } -defm VOTE_ALL : VOTE<B1, "all.pred", int_nvvm_vote_all>; -defm VOTE_ANY : VOTE<B1, "any.pred", int_nvvm_vote_any>; -defm VOTE_UNI : VOTE<B1, "uni.pred", int_nvvm_vote_uni>; -defm VOTE_BALLOT : VOTE<B32, "ballot.b32", int_nvvm_vote_ballot>; + defm VOTE_ALL : VOTE<"all", I1RT, int_nvvm_vote_all>; + defm VOTE_ANY : VOTE<"any", I1RT, int_nvvm_vote_any>; + defm VOTE_UNI : VOTE<"uni", I1RT, int_nvvm_vote_uni>; + defm VOTE_BALLOT : VOTE<"ballot", I32RT, int_nvvm_vote_ballot>; + + // vote.sync.{all,any,uni,ballot} + multiclass VOTE_SYNC<string mode, RegTyInfo t, Intrinsic op> { + def i : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, i32imm:$mask), + "vote.sync." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op imm:$mask, i1:$pred))]>; + def r : BasicNVPTXInst<(outs t.RC:$dest), (ins B1:$pred, B32:$mask), + "vote.sync." # mode # "." # t.PtxType, + [(set t.Ty:$dest, (op i32:$mask, i1:$pred))]>; + } -// vote.sync.{all,any,uni,ballot} -multiclass VOTE_SYNC<NVPTXRegClass regclass, string mode, Intrinsic IntOp> { - def i : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, i32imm:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp imm:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; - def r : BasicNVPTXInst<(outs regclass:$dest), (ins B1:$pred, B32:$mask), - "vote.sync." # mode, - [(set regclass:$dest, (IntOp i32:$mask, i1:$pred))]>, - Requires<[hasPTX<60>, hasSM<30>]>; + defm VOTE_SYNC_ALL : VOTE_SYNC<"all", I1RT, int_nvvm_vote_all_sync>; + defm VOTE_SYNC_ANY : VOTE_SYNC<"any", I1RT, int_nvvm_vote_any_sync>; + defm VOTE_SYNC_UNI : VOTE_SYNC<"uni", I1RT, int_nvvm_vote_uni_sync>; + defm VOTE_SYNC_BALLOT : VOTE_SYNC<"ballot", I32RT, int_nvvm_vote_ballot_sync>; } - -defm VOTE_SYNC_ALL : VOTE_SYNC<B1, "all.pred", int_nvvm_vote_all_sync>; -defm VOTE_SYNC_ANY : VOTE_SYNC<B1, "any.pred", int_nvvm_vote_any_sync>; -defm VOTE_SYNC_UNI : VOTE_SYNC<B1, "uni.pred", int_nvvm_vote_uni_sync>; -defm VOTE_SYNC_BALLOT : VOTE_SYNC<B32, "ballot.b32", int_nvvm_vote_ballot_sync>; - // elect.sync +let Predicates = [hasPTX<80>, hasSM<90>] in { def INT_ELECT_SYNC_I : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins i32imm:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync imm:$mask))]>; def INT_ELECT_SYNC_R : BasicNVPTXInst<(outs B32:$dest, B1:$pred), (ins B32:$mask), "elect.sync", - [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>, - Requires<[hasPTX<80>, hasSM<90>]>; + [(set i32:$dest, i1:$pred, (int_nvvm_elect_sync i32:$mask))]>; +} + +let Predicates = [hasPTX<60>, hasSM<70>] in { + multiclass MATCH_ANY_SYNC<Intrinsic op, RegTyInfo t> { + def ii : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest), (ins t.Imm:$value, B32:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, i32imm:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest), (ins t.RC:$value, B32:$mask), + "match.any.sync." # t.PtxType, + [(set i32:$dest, (op i32:$mask, t.Ty:$value))]>; + } -multiclass MATCH_ANY_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp, - Operand ImmOp> { - def ii : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest), (ins ImmOp:$value, B32:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, i32imm:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest), (ins regclass:$value, B32:$mask), - "match.any.sync." # ptxtype, - [(set i32:$dest, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; + defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i32, I32RT>; + defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<int_nvvm_match_any_sync_i64, I64RT>; + + multiclass MATCH_ALLP_SYNC<RegTyInfo t, Intrinsic op> { + def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, imm:$value))]>; + def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.Imm:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, imm:$value))]>; + def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, i32imm:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op imm:$mask, t.Ty:$value))]>; + def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), + (ins t.RC:$value, B32:$mask), + "match.all.sync." # t.PtxType, + [(set i32:$dest, i1:$pred, (op i32:$mask, t.Ty:$value))]>; + } + defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<I32RT, int_nvvm_match_all_sync_i32p>; + defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<I64RT, int_nvvm_match_all_sync_i64p>; } // activemask.b32 @@ -308,39 +327,6 @@ def ACTIVEMASK : BasicNVPTXInst<(outs B32:$dest), (ins), [(set i32:$dest, (int_nvvm_activemask))]>, Requires<[hasPTX<62>, hasSM<30>]>; -defm MATCH_ANY_SYNC_32 : MATCH_ANY_SYNC<B32, "b32", int_nvvm_match_any_sync_i32, - i32imm>; -defm MATCH_ANY_SYNC_64 : MATCH_ANY_SYNC<B64, "b64", int_nvvm_match_any_sync_i64, - i64imm>; - -multiclass MATCH_ALLP_SYNC<NVPTXRegClass regclass, string ptxtype, Intrinsic IntOp, - Operand ImmOp> { - def ii : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ir : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins ImmOp:$value, B32:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, imm:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def ri : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, i32imm:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp imm:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; - def rr : BasicNVPTXInst<(outs B32:$dest, B1:$pred), - (ins regclass:$value, B32:$mask), - "match.all.sync." # ptxtype, - [(set i32:$dest, i1:$pred, (IntOp i32:$mask, regclass:$value))]>, - Requires<[hasPTX<60>, hasSM<70>]>; -} -defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<B32, "b32", int_nvvm_match_all_sync_i32p, - i32imm>; -defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<B64, "b64", int_nvvm_match_all_sync_i64p, - i64imm>; - multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> { def : BasicNVPTXInst<(outs B32:$dst), (ins B32:$src, B32:$mask), "redux.sync." # BinOp # "." # PTXType, @@ -381,24 +367,20 @@ defm REDUX_SYNC_FMAX_ABS_NAN: REDUX_SYNC_F<"max", ".abs", ".NaN">; //----------------------------------- // Explicit Memory Fence Functions //----------------------------------- -class MEMBAR<string StrOp, Intrinsic IntOP> : - BasicNVPTXInst<(outs), (ins), - StrOp, [(IntOP)]>; +class NullaryInst<string StrOp, Intrinsic IntOP> : + BasicNVPTXInst<(outs), (ins), StrOp, [(IntOP)]>; -def INT_MEMBAR_CTA : MEMBAR<"membar.cta", int_nvvm_membar_cta>; -def INT_MEMBAR_GL : MEMBAR<"membar.gl", int_nvvm_membar_gl>; -def INT_MEMBAR_SYS : MEMBAR<"membar.sys", int_nvvm_membar_sys>; +def INT_MEMBAR_CTA : NullaryInst<"membar.cta", int_nvvm_membar_cta>; +def INT_MEMBAR_GL : NullaryInst<"membar.gl", int_nvvm_membar_gl>; +def INT_MEMBAR_SYS : NullaryInst<"membar.sys", int_nvvm_membar_sys>; def INT_FENCE_SC_CLUSTER: - MEMBAR<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, + NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>, Requires<[hasPTX<78>, hasSM<90>]>; // Proxy fence (uni-directional) -// fence.proxy.tensormap.release variants - class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> : - BasicNVPTXInst<(outs), (ins), - "fence.proxy.tensormap::generic.release." # Scope, [(Intr)]>, + NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>, Requires<[hasPTX<83>, hasSM<90>]>; def INT_FENCE_PROXY_TENSORMAP_GENERIC_RELEASE_CTA: @@ -488,35 +470,31 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 : CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16, int_nvvm_cp_async_cg_shared_global_16_s>; -def CP_ASYNC_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.commit_group", [(int_nvvm_cp_async_commit_group)]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + def CP_ASYNC_COMMIT_GROUP : + NullaryInst<"cp.async.commit_group", int_nvvm_cp_async_commit_group>; -def CP_ASYNC_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", - [(int_nvvm_cp_async_wait_group timm:$n)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.wait_group", + [(int_nvvm_cp_async_wait_group timm:$n)]>; -def CP_ASYNC_WAIT_ALL : - BasicNVPTXInst<(outs), (ins), "cp.async.wait_all", - [(int_nvvm_cp_async_wait_all)]>, - Requires<[hasPTX<70>, hasSM<80>]>; + def CP_ASYNC_WAIT_ALL : + NullaryInst<"cp.async.wait_all", int_nvvm_cp_async_wait_all>; +} -// cp.async.bulk variants of the commit/wait group -def CP_ASYNC_BULK_COMMIT_GROUP : - BasicNVPTXInst<(outs), (ins), "cp.async.bulk.commit_group", - [(int_nvvm_cp_async_bulk_commit_group)]>, - Requires<[hasPTX<80>, hasSM<90>]>; +let Predicates = [hasPTX<80>, hasSM<90>] in { + // cp.async.bulk variants of the commit/wait group + def CP_ASYNC_BULK_COMMIT_GROUP : + NullaryInst<"cp.async.bulk.commit_group", int_nvvm_cp_async_bulk_commit_group>; -def CP_ASYNC_BULK_WAIT_GROUP : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", - [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def CP_ASYNC_BULK_WAIT_GROUP : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group", + [(int_nvvm_cp_async_bulk_wait_group timm:$n)]>; -def CP_ASYNC_BULK_WAIT_GROUP_READ : - BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", - [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def CP_ASYNC_BULK_WAIT_GROUP_READ : + BasicNVPTXInst<(outs), (ins i32imm:$n), "cp.async.bulk.wait_group.read", + [(int_nvvm_cp_async_bulk_wait_group_read timm:$n)]>; +} //------------------------------ // TMA Async Bulk Copy Functions @@ -600,12 +578,23 @@ defm CP_ASYNC_BULK_PREFETCH_CH : CP_ASYNC_BULK_PREFETCH_INTR<has_ch = 1>; // TMA Async Bulk Tensor Copy Functions //------------------------------------- -class TMA_DIMS_UTIL<int dim> { +class TMA_DIMS_UTIL<int dim, string mode = ""> { // For example, when 'dim' is 3, this generates: // an ins_dag: B32:$d0, B32:$d1, B32:$d2 // with base_str: $d0, $d1, $d2 dag ins_dag = !dag(ins, !listsplat(B32, dim), !foreach(i, !range(dim), "d" # i)); string base_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", "); + + // Tile::Gather4/scatter4 actually operate on a 2D tensor, + // though they take 5 co-ordinates. + // + // The scatter-gather happens over 4 rows with a fixed + // column-index. The first co-ordinate represents the + // col-index followed by four row-indices. + int num_dims = !cond( + !eq(mode, "tile_scatter4") : 2, + !eq(mode, "tile_gather4") : 2, + true : dim); // for all other modes } class TMA_IM2COL_UTIL<int dim, string mode> { @@ -692,14 +681,138 @@ foreach dim = [1, 2, 3, 4, 5] in { } } +multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> { + defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; + defvar dims_str = TMA_DIMS_UTIL<dim>.base_str; + defvar asm_str_base = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; + + defvar im2col_dag = TMA_IM2COL_UTIL<dim, mode>.ins_dag; + defvar im2col_str = TMA_IM2COL_UTIL<dim, mode>.base_str; + defvar asm_str = !if(!empty(im2col_str), + asm_str_base, + asm_str_base # ", {{" # im2col_str # "}}"); + + defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims; + defvar inst_name = "cp.async.bulk.tensor" + # "." # dim_val # "d" + # "." # "shared::cluster.global" + # "." # !subst("_", "::", mode) + # "." # "mbarrier::complete_tx::bytes"; + defvar intr = !cast<Intrinsic>( + "int_nvvm_cp_async_bulk_tensor_g2s_" # mode # "_" # dim_val # "d"); + + defvar ins_dag = !con( + (ins ADDR:$dst, ADDR:$mbar, B64:$tmap), + dims_dag, im2col_dag, + (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg)); + + defvar intr_dag_base = !con( + (intr addr:$dst, addr:$mbar, B64:$tmap), + !setdagop(dims_dag, intr), + !setdagop(im2col_dag, intr), + (intr B16:$mc, B64:$ch)); + defvar intr_dag_no_hints = !con(intr_dag_base, (intr 0, 0, timm:$cg)); + defvar intr_dag_with_mc = !con(intr_dag_base, (intr -1, 0, timm:$cg)); + defvar intr_dag_with_ch = !con(intr_dag_base, (intr 0, -1, timm:$cg)); + defvar intr_dag_with_mc_ch = !con(intr_dag_base, (intr -1, -1, timm:$cg)); + + def "" : NVPTXInst<(outs), ins_dag, + inst_name # asm_str # ";", + [intr_dag_no_hints]>, + Requires<pred>; + def _MC : NVPTXInst<(outs), ins_dag, + inst_name # ".multicast::cluster" # asm_str # ", $mc;", + [intr_dag_with_mc]>, + Requires<pred>; + def _CH : NVPTXInst<(outs), ins_dag, + inst_name # ".L2::cache_hint" # asm_str # ", $ch;", + [intr_dag_with_ch]>, + Requires<pred>; + def _MC_CH : NVPTXInst<(outs), ins_dag, + inst_name # ".multicast::cluster.L2::cache_hint" # asm_str # ", $mc, $ch;", + [intr_dag_with_mc_ch]>, + Requires<pred>; +} +foreach dim = 3...5 in { + foreach mode = ["im2col_w", "im2col_w_128"] in { + defm TMA_G2S_ # !toupper(mode) # "_" # dim # "D" + : TMA_TENSOR_G2S_INTR<dim, mode, [hasTMACTAGroupSupport]>; + } +} +defm TMA_G2S_TILE_GATHER4_2D : TMA_TENSOR_G2S_INTR<5, "tile_gather4", + [hasTMACTAGroupSupport]>; + +multiclass TMA_TENSOR_G2S_CTA_INTR<int dim, string mode, list<Predicate> pred = []> { + defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; + defvar dims_str = TMA_DIMS_UTIL<dim>.base_str; + defvar asm_str_base = " [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]"; + + defvar im2col_dag = TMA_IM2COL_UTIL<dim, mode>.ins_dag; + defvar im2col_str = TMA_IM2COL_UTIL<dim, mode>.base_str; + defvar asm_str = !if(!empty(im2col_str), + asm_str_base, + asm_str_base # ", {{" # im2col_str # "}}"); + + defvar ins_dag = !con( + (ins ADDR:$dst, ADDR:$mbar, B64:$tmap), + dims_dag, im2col_dag, + (ins B64:$ch)); + + defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims; + defvar intr = !cast<Intrinsic>( + "int_nvvm_cp_async_bulk_tensor_g2s_cta_" # mode # "_" # dim_val # "d"); + defvar intr_dag = !con( + (intr addr:$dst, addr:$mbar, B64:$tmap), + !setdagop(dims_dag, intr), + !setdagop(im2col_dag, intr), + (intr B64:$ch, 0)); + defvar intr_dag_with_ch = !con( + (intr addr:$dst, addr:$mbar, B64:$tmap), + !setdagop(dims_dag, intr), + !setdagop(im2col_dag, intr), + (intr B64:$ch, -1)); + defvar inst_name = "cp.async.bulk.tensor" + # "." # dim_val # "d" + # "." # "shared::cta.global" + # "." # !subst("_", "::", mode) + # "." # "mbarrier::complete_tx::bytes"; + + def "" : NVPTXInst<(outs), ins_dag, + inst_name # asm_str # ";", + [intr_dag]>, + Requires<pred>; + def _CH : NVPTXInst<(outs), ins_dag, + inst_name # ".L2::cache_hint" # asm_str # ", $ch;", + [intr_dag_with_ch]>, + Requires<pred>; +} +foreach dim = 1...5 in { + defm TMA_G2S_CTA_TILE_ # dim # "D" + : TMA_TENSOR_G2S_CTA_INTR<dim, "tile", [hasPTX<86>, hasSM<90>]>; +} +foreach dim = 3...5 in { + defm TMA_G2S_CTA_IM2COL_ # dim # "D" + : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col", [hasPTX<86>, hasSM<90>]>; + + defm TMA_G2S_CTA_IM2COL_W_ # dim # "D" + : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w", [hasPTX<86>, hasSM<100>]>; + + defm TMA_G2S_CTA_IM2COL_W_128_ # dim # "D" + : TMA_TENSOR_G2S_CTA_INTR<dim, "im2col_w_128", [hasTMACTAGroupSupport]>; +} +defm TMA_G2S_CTA_TILE_GATHER4_2D : TMA_TENSOR_G2S_CTA_INTR<5, "tile_gather4", + [hasPTX<86>, hasSM<100>]>; + multiclass TMA_TENSOR_S2G_INTR<int dim, string mode, list<Predicate> pred = [hasPTX<80>, hasSM<90>]> { defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag; defvar dims_str = TMA_DIMS_UTIL<dim>.base_str; defvar asm_str = " [$tmap, {{" # dims_str # "}}], [$src]"; + defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims; defvar intr = !cast<Intrinsic>( - "int_nvvm_cp_async_bulk_tensor_s2g_" # mode # "_" # dim # d); + "int_nvvm_cp_async_bulk_tensor_s2g_" # mode # "_" # dim_val # "d"); + defvar intr_dag = !con((intr addr:$src, B64:$tmap), !setdagop(dims_dag, intr), (intr B64:$ch, 0)); @@ -707,11 +820,13 @@ multiclass TMA_TENSOR_S2G_INTR<int dim, string mode, !setdagop(dims_dag, intr), (intr B64:$ch, -1)); - // For im2col mode, the actual asm_str is "im2col_no_offs" - defvar mode_asm_str = !if(!eq(mode, "im2col"), - "im2col_no_offs", mode); + // Fix-up the asm_str when it is im2col/scatter4. + defvar mode_asm_str = !cond( + !eq(mode, "im2col") : "im2col_no_offs", + !eq(mode, "tile_scatter4") : "tile::scatter4", + true : mode); defvar prefix = "cp.async.bulk.tensor" - # "." # dim # "d" + # "." # dim_val # "d" # ".global.shared::cta" # "." # mode_asm_str # ".bulk_group"; @@ -729,10 +844,12 @@ multiclass TMA_TENSOR_S2G_INTR<int dim, string mode, } foreach dim = 1...5 in { foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in { - defvar suffix = !toupper(mode) # "_" # dim # D; + defvar suffix = !toupper(mode) # "_" # dim # "D"; defm TMA_TENSOR_S2G_ # suffix : TMA_TENSOR_S2G_INTR<dim, mode>; } } +defm TMA_S2G_TILE_SCATTER4_2D : TMA_TENSOR_S2G_INTR<5, "tile_scatter4", + [hasTMACTAGroupSupport]>; def TMAReductionFlags : Operand<i32> { let PrintMethod = "printTmaReductionMode"; @@ -786,13 +903,14 @@ multiclass TMA_TENSOR_PREFETCH_INTR<int dim, string mode, asm_str_base, asm_str_base # ", {{" # im2col_str # "}}"); + defvar dim_val = TMA_DIMS_UTIL<dim, mode>.num_dims; defvar inst_name = "cp.async.bulk.prefetch.tensor" - # "." # dim # "d" + # "." # dim_val # "d" # "." # "L2.global" - # "." # mode; + # "." # !subst("_", "::", mode); defvar intr = !cast<Intrinsic>( - "int_nvvm_cp_async_bulk_tensor_prefetch_" # mode # "_" # dim # d); + "int_nvvm_cp_async_bulk_tensor_prefetch_" # mode # "_" # dim_val # "d"); defvar ins_dag = !con((ins B64:$tmap), dims_dag, @@ -818,40 +936,46 @@ multiclass TMA_TENSOR_PREFETCH_INTR<int dim, string mode, } foreach dim = 1...5 in { foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in { - defvar suffix = !toupper(mode) # "_" # dim # D; + defvar suffix = !toupper(mode) # "_" # dim # "D"; defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode>; } } +foreach dim = 3...5 in { + foreach mode = ["im2col_w", "im2col_w_128"] in { + defvar suffix = !toupper(mode) # "_" # dim # "D"; + defm TMA_TENSOR_PF_ # suffix : TMA_TENSOR_PREFETCH_INTR<dim, mode, + [hasTMACTAGroupSupport]>; + } +} +defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4", + [hasTMACTAGroupSupport]>; //Prefetch and Prefetchu -class PREFETCH_INTRS<string InstName> : - BasicNVPTXInst<(outs), (ins ADDR:$addr), - InstName, - [(!cast<Intrinsic>(!strconcat("int_nvvm_", - !subst(".", "_", InstName))) addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; - - -def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; -def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; -def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; -def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; -def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; -def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; +let Predicates = [hasPTX<80>, hasSM<90>] in { + class PREFETCH_INTRS<string InstName> : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + InstName, + [(!cast<Intrinsic>(!strconcat("int_nvvm_", + !subst(".", "_", InstName))) addr:$addr)]>; -def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_normal", - [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">; + def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">; + def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">; + def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">; + def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">; + def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">; -def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "prefetch.global.L2::evict_last", - [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>, - Requires<[hasPTX<80>, hasSM<90>]>; + def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_normal", + [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>; + def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr), + "prefetch.global.L2::evict_last", + [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>; -def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; + def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">; +} //Applypriority intrinsics class APPLYPRIORITY_L2_INTRS<string addrspace> : @@ -882,99 +1006,82 @@ def DISCARD_GLOBAL_L2 : DISCARD_L2_INTRS<"global">; // MBarrier Functions //----------------------------------- -multiclass MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), - "mbarrier.init" # AddrSpace # ".b64", - [(Intrin addr:$addr, i32:$count)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; -defm MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", - int_nvvm_mbarrier_init_shared>; - -multiclass MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr), - "mbarrier.inval" # AddrSpace # ".b64", - [(Intrin addr:$addr)]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; -defm MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", - int_nvvm_mbarrier_inval_shared>; - -multiclass MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; -defm MBARRIER_ARRIVE_SHARED : - MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; - -multiclass MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_NOCOMPLETE : - MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; -defm MBARRIER_ARRIVE_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; - -multiclass MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), - "mbarrier.arrive_drop" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP : - MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; -defm MBARRIER_ARRIVE_DROP_SHARED : - MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; - -multiclass MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B64:$state), - (ins ADDR:$addr, B32:$count), - "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", - [(set i64:$state, (Intrin addr:$addr, i32:$count))]>, - Requires<[hasPTX<70>, hasSM<80>]>; -} - -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; -defm MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : - MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", - int_nvvm_mbarrier_arrive_drop_noComplete_shared>; - -multiclass MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> { - def "" : BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), - "mbarrier.test_wait" # AddrSpace # ".b64", - [(set i1:$res, (Intrin addr:$addr, i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; +let Predicates = [hasPTX<70>, hasSM<80>] in { + class MBARRIER_INIT<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs), (ins ADDR:$addr, B32:$count), + "mbarrier.init" # AddrSpace # ".b64", + [(Intrin addr:$addr, i32:$count)]>; + + def MBARRIER_INIT : MBARRIER_INIT<"", int_nvvm_mbarrier_init>; + def MBARRIER_INIT_SHARED : MBARRIER_INIT<".shared", + int_nvvm_mbarrier_init_shared>; + + class MBARRIER_INVAL<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs), (ins ADDR:$addr), + "mbarrier.inval" # AddrSpace # ".b64", + [(Intrin addr:$addr)]>; + + def MBARRIER_INVAL : MBARRIER_INVAL<"", int_nvvm_mbarrier_inval>; + def MBARRIER_INVAL_SHARED : MBARRIER_INVAL<".shared", + int_nvvm_mbarrier_inval_shared>; + + class MBARRIER_ARRIVE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE : MBARRIER_ARRIVE<"", int_nvvm_mbarrier_arrive>; + def MBARRIER_ARRIVE_SHARED : + MBARRIER_ARRIVE<".shared", int_nvvm_mbarrier_arrive_shared>; + + class MBARRIER_ARRIVE_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_NOCOMPLETE : + MBARRIER_ARRIVE_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_noComplete>; + def MBARRIER_ARRIVE_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_NOCOMPLETE<".shared", int_nvvm_mbarrier_arrive_noComplete_shared>; + + class MBARRIER_ARRIVE_DROP<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), (ins ADDR:$addr), + "mbarrier.arrive_drop" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr))]>; + + def MBARRIER_ARRIVE_DROP : + MBARRIER_ARRIVE_DROP<"", int_nvvm_mbarrier_arrive_drop>; + def MBARRIER_ARRIVE_DROP_SHARED : + MBARRIER_ARRIVE_DROP<".shared", int_nvvm_mbarrier_arrive_drop_shared>; + + class MBARRIER_ARRIVE_DROP_NOCOMPLETE<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B64:$state), + (ins ADDR:$addr, B32:$count), + "mbarrier.arrive_drop.noComplete" # AddrSpace # ".b64", + [(set i64:$state, (Intrin addr:$addr, i32:$count))]>; + + def MBARRIER_ARRIVE_DROP_NOCOMPLETE : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<"", int_nvvm_mbarrier_arrive_drop_noComplete>; + def MBARRIER_ARRIVE_DROP_NOCOMPLETE_SHARED : + MBARRIER_ARRIVE_DROP_NOCOMPLETE<".shared", + int_nvvm_mbarrier_arrive_drop_noComplete_shared>; + + class MBARRIER_TEST_WAIT<string AddrSpace, Intrinsic Intrin> : + BasicNVPTXInst<(outs B1:$res), (ins ADDR:$addr, B64:$state), + "mbarrier.test_wait" # AddrSpace # ".b64", + [(set i1:$res, (Intrin addr:$addr, i64:$state))]>; + + def MBARRIER_TEST_WAIT : + MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; + def MBARRIER_TEST_WAIT_SHARED : + MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; + + def MBARRIER_PENDING_COUNT : + BasicNVPTXInst<(outs B32:$res), (ins B64:$state), + "mbarrier.pending_count.b64", + [(set i32:$res, (int_nvvm_mbarrier_pending_count i64:$state))]>; } - -defm MBARRIER_TEST_WAIT : - MBARRIER_TEST_WAIT<"", int_nvvm_mbarrier_test_wait>; -defm MBARRIER_TEST_WAIT_SHARED : - MBARRIER_TEST_WAIT<".shared", int_nvvm_mbarrier_test_wait_shared>; - -class MBARRIER_PENDING_COUNT<Intrinsic Intrin> : - BasicNVPTXInst<(outs B32:$res), (ins B64:$state), - "mbarrier.pending_count.b64", - [(set i32:$res, (Intrin i64:$state))]>, - Requires<[hasPTX<70>, hasSM<80>]>; - -def MBARRIER_PENDING_COUNT : - MBARRIER_PENDING_COUNT<int_nvvm_mbarrier_pending_count>; - //----------------------------------- // Math Functions //----------------------------------- @@ -1300,15 +1407,11 @@ defm ABS_F64 : F_ABS<"f64", F64RT, support_ftz = false>; def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>; -def COPYSIGN_F : - BasicNVPTXInst<(outs B32:$dst), (ins B32:$src0, B32:$src1), - "copysign.f32", - [(set f32:$dst, (fcopysign_nvptx f32:$src1, f32:$src0))]>; - -def COPYSIGN_D : - BasicNVPTXInst<(outs B64:$dst), (ins B64:$src0, B64:$src1), - "copysign.f64", - [(set f64:$dst, (fcopysign_nvptx f64:$src1, f64:$src0))]>; +foreach t = [F32RT, F64RT] in + def COPYSIGN_ # t : + BasicNVPTXInst<(outs t.RC:$dst), (ins t.RC:$src0, t.RC:$src1), + "copysign." # t.PtxType, + [(set t.Ty:$dst, (fcopysign_nvptx t.Ty:$src1, t.Ty:$src0))]>; // // Neg bf16, bf16x2 @@ -2106,38 +2209,35 @@ defm INT_PTX_SATOM_XOR : ATOM2_bitwise_impl<"xor">; // Scalar -class LDU_G<string TyStr, NVPTXRegClass regclass> - : NVPTXInst<(outs regclass:$result), (ins ADDR:$src), - "ldu.global." # TyStr # " \t$result, [$src];", []>; +class LDU_G<NVPTXRegClass regclass> + : NVPTXInst<(outs regclass:$result), (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.b$fromWidth \t$result, [$src];", []>; -def LDU_GLOBAL_i8 : LDU_G<"b8", B16>; -def LDU_GLOBAL_i16 : LDU_G<"b16", B16>; -def LDU_GLOBAL_i32 : LDU_G<"b32", B32>; -def LDU_GLOBAL_i64 : LDU_G<"b64", B64>; +def LDU_GLOBAL_i16 : LDU_G<B16>; +def LDU_GLOBAL_i32 : LDU_G<B32>; +def LDU_GLOBAL_i64 : LDU_G<B64>; // vector // Elementized vector ldu -class VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> +class VLDU_G_ELE_V2<NVPTXRegClass regclass> : NVPTXInst<(outs regclass:$dst1, regclass:$dst2), - (ins ADDR:$src), - "ldu.global.v2." # TyStr # " \t{{$dst1, $dst2}}, [$src];", []>; + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v2.b$fromWidth \t{{$dst1, $dst2}}, [$src];", []>; -class VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> - : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, - regclass:$dst4), (ins ADDR:$src), - "ldu.global.v4." # TyStr # " \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; +class VLDU_G_ELE_V4<NVPTXRegClass regclass> + : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), + (ins i32imm:$fromWidth, ADDR:$src), + "ldu.global.v4.b$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", []>; -def LDU_GLOBAL_v2i8 : VLDU_G_ELE_V2<"b8", B16>; -def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<"b16", B16>; -def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<"b32", B32>; -def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<"b64", B64>; +def LDU_GLOBAL_v2i16 : VLDU_G_ELE_V2<B16>; +def LDU_GLOBAL_v2i32 : VLDU_G_ELE_V2<B32>; +def LDU_GLOBAL_v2i64 : VLDU_G_ELE_V2<B64>; -def LDU_GLOBAL_v4i8 : VLDU_G_ELE_V4<"b8", B16>; -def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<"b16", B16>; -def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<"b32", B32>; +def LDU_GLOBAL_v4i16 : VLDU_G_ELE_V4<B16>; +def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>; //----------------------------------- @@ -2178,12 +2278,10 @@ class VLDG_G_ELE_V8<NVPTXRegClass regclass> : "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];", []>; // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads. -def LD_GLOBAL_NC_v2i8 : VLDG_G_ELE_V2<B16>; def LD_GLOBAL_NC_v2i16 : VLDG_G_ELE_V2<B16>; def LD_GLOBAL_NC_v2i32 : VLDG_G_ELE_V2<B32>; def LD_GLOBAL_NC_v2i64 : VLDG_G_ELE_V2<B64>; -def LD_GLOBAL_NC_v4i8 : VLDG_G_ELE_V4<B16>; def LD_GLOBAL_NC_v4i16 : VLDG_G_ELE_V4<B16>; def LD_GLOBAL_NC_v4i32 : VLDG_G_ELE_V4<B32>; @@ -2193,19 +2291,19 @@ def LD_GLOBAL_NC_v8i32 : VLDG_G_ELE_V8<B32>; multiclass NG_TO_G<string Str, bit Supports32 = 1, list<Predicate> Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta." # Str # ".u32", []>, Requires<Preds>; + "cvta." # Str # ".u32">, Requires<Preds>; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta." # Str # ".u64", []>, Requires<Preds>; + "cvta." # Str # ".u64">, Requires<Preds>; } multiclass G_TO_NG<string Str, bit Supports32 = 1, list<Predicate> Preds = []> { if Supports32 then def "" : BasicNVPTXInst<(outs B32:$result), (ins B32:$src), - "cvta.to." # Str # ".u32", []>, Requires<Preds>; + "cvta.to." # Str # ".u32">, Requires<Preds>; def _64 : BasicNVPTXInst<(outs B64:$result), (ins B64:$src), - "cvta.to." # Str # ".u64", []>, Requires<Preds>; + "cvta.to." # Str # ".u64">, Requires<Preds>; } foreach space = ["local", "shared", "global", "const", "param"] in { @@ -4465,9 +4563,9 @@ def INT_PTX_SREG_LANEMASK_GT : PTX_READ_SREG_R32<"lanemask_gt", int_nvvm_read_ptx_sreg_lanemask_gt>; let hasSideEffects = 1 in { -def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; -def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; -def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; + def SREG_CLOCK : PTX_READ_SREG_R32<"clock", int_nvvm_read_ptx_sreg_clock>; + def SREG_CLOCK64 : PTX_READ_SREG_R64<"clock64", int_nvvm_read_ptx_sreg_clock64>; + def SREG_GLOBALTIMER : PTX_READ_SREG_R64<"globaltimer", int_nvvm_read_ptx_sreg_globaltimer>; } def: Pat <(i64 (readcyclecounter)), (SREG_CLOCK64)>; @@ -4609,7 +4707,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op> !and(!eq(op, "ldmatrix"), !eq(ptx_elt_type, "b8x16.b4x16_p64"), - !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); + !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], + + !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), + !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], + + !and(!eq(op, "stmatrix"), + !eq(ptx_elt_type, "b8"), + !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -4890,6 +4995,42 @@ defset list<WMMA_INSTR> LDMATRIXs = { } // transposed } // defset +// +// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 +// +class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space> + : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>, + Requires<Frag.Predicates> { + // Build PatFrag that only matches particular address space. + dag PFOperands = !con((ops node:$dst), + !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); + PatFrag IntrFrag = PatFrag<PFOperands, + !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)), + !cond(!eq(Space, ".shared"): AS_match.shared, + true: AS_match.generic)>; + // Build AS-constrained pattern. + let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret; + let OutOperandList = (outs); + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + let AsmString = "stmatrix.sync.aligned." + # Frag.geom + # "." # Frag.frag + # !if(Transposed, ".trans", "") + # Space + # "." # Frag.ptx_elt_type + # " [$dst], " # Frag.regstring # ";"; +} + +// Create all stmatrix variants +defset list<WMMA_INSTR> STMATRIXs = { + foreach transposed = [false, true] in {foreach space = [".shared", ""] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in + if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then + def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>; + } // space + } // transposed +} // defset + // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. @@ -4900,41 +5041,40 @@ class MMA_PAT<WMMA_INSTR wi> Requires<wi.Predicates>; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT<mma>; multiclass MAPA<string suffix, Intrinsic Intr> { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), - "mapa" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, i32:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), - "mapa" # suffix # ".u64", - [(set i64:$d, (Intr i64:$a, imm:$b))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, B32:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, i32:$b))]>; + def _32i: BasicNVPTXInst<(outs B32:$d), (ins B32:$a, i32imm:$b), + "mapa" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a, imm:$b))]>; + def _64: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, B32:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, i32:$b))]>; + def _64i: BasicNVPTXInst<(outs B64:$d), (ins B64:$a, i32imm:$b), + "mapa" # suffix # ".u64", + [(set i64:$d, (Intr i64:$a, imm:$b))]>; + } } + defm mapa : MAPA<"", int_nvvm_mapa>; defm mapa_shared_cluster : MAPA<".shared::cluster", int_nvvm_mapa_shared_cluster>; multiclass GETCTARANK<string suffix, Intrinsic Intr> { - def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), - "getctarank" # suffix # ".u32", - [(set i32:$d, (Intr i32:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; - def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), - "getctarank" # suffix # ".u64", - [(set i32:$d, (Intr i64:$a))]>, - Requires<[hasSM<90>, hasPTX<78>]>; + let Predicates = [hasSM<90>, hasPTX<78>] in { + def _32: BasicNVPTXInst<(outs B32:$d), (ins B32:$a), + "getctarank" # suffix # ".u32", + [(set i32:$d, (Intr i32:$a))]>; + def _64: BasicNVPTXInst<(outs B32:$d), (ins B64:$a), + "getctarank" # suffix # ".u64", + [(set i32:$d, (Intr i64:$a))]>; + } } defm getctarank : GETCTARANK<"", int_nvvm_getctarank>; @@ -4973,29 +5113,25 @@ def INT_NVVM_WGMMA_WAIT_GROUP_SYNC_ALIGNED : BasicNVPTXInst<(outs), (ins i64imm: [(int_nvvm_wgmma_wait_group_sync_aligned timm:$n)]>, Requires<[hasSM90a, hasPTX<80>]>; } // isConvergent = true -def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : - BasicNVPTXInst<(outs), (ins), - "griddepcontrol.launch_dependents", - [(int_nvvm_griddepcontrol_launch_dependents)]>, - Requires<[hasSM<90>, hasPTX<78>]>; - -def GRIDDEPCONTROL_WAIT : - BasicNVPTXInst<(outs), (ins), - "griddepcontrol.wait", - [(int_nvvm_griddepcontrol_wait)]>, - Requires<[hasSM<90>, hasPTX<78>]>; +let Predicates = [hasSM<90>, hasPTX<78>] in { + def GRIDDEPCONTROL_LAUNCH_DEPENDENTS : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.launch_dependents", + [(int_nvvm_griddepcontrol_launch_dependents)]>; + def GRIDDEPCONTROL_WAIT : + BasicNVPTXInst<(outs), (ins), "griddepcontrol.wait", + [(int_nvvm_griddepcontrol_wait)]>; +} def INT_EXIT : BasicNVPTXInst<(outs), (ins), "exit", [(int_nvvm_exit)]>; // Tcgen05 intrinsics -let isConvergent = true in { +let isConvergent = true, Predicates = [hasTcgen05Instructions] in { multiclass TCGEN05_ALLOC_INTR<string AS, string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$dst, B32:$ncols), "tcgen05.alloc.cta_group::" # num # ".sync.aligned" # AS # ".b32", - [(Intr addr:$dst, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$dst, B32:$ncols)]>; } defm TCGEN05_ALLOC_CG1 : TCGEN05_ALLOC_INTR<"", "1", int_nvvm_tcgen05_alloc_cg1>; @@ -5008,8 +5144,7 @@ multiclass TCGEN05_DEALLOC_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins B32:$tmem_addr, B32:$ncols), "tcgen05.dealloc.cta_group::" # num # ".sync.aligned.b32", - [(Intr B32:$tmem_addr, B32:$ncols)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr B32:$tmem_addr, B32:$ncols)]>; } defm TCGEN05_DEALLOC_CG1: TCGEN05_DEALLOC_INTR<"1", int_nvvm_tcgen05_dealloc_cg1>; defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2>; @@ -5017,19 +5152,13 @@ defm TCGEN05_DEALLOC_CG2: TCGEN05_DEALLOC_INTR<"2", int_nvvm_tcgen05_dealloc_cg2 multiclass TCGEN05_RELINQ_PERMIT_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins), "tcgen05.relinquish_alloc_permit.cta_group::" # num # ".sync.aligned", - [(Intr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr)]>; } defm TCGEN05_RELINQ_CG1: TCGEN05_RELINQ_PERMIT_INTR<"1", int_nvvm_tcgen05_relinq_alloc_permit_cg1>; defm TCGEN05_RELINQ_CG2: TCGEN05_RELINQ_PERMIT_INTR<"2", int_nvvm_tcgen05_relinq_alloc_permit_cg2>; -def tcgen05_wait_ld: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::ld.sync.aligned", - [(int_nvvm_tcgen05_wait_ld)]>, - Requires<[hasTcgen05Instructions]>; - -def tcgen05_wait_st: BasicNVPTXInst<(outs), (ins), "tcgen05.wait::st.sync.aligned", - [(int_nvvm_tcgen05_wait_st)]>, - Requires<[hasTcgen05Instructions]>; +def tcgen05_wait_ld: NullaryInst<"tcgen05.wait::ld.sync.aligned", int_nvvm_tcgen05_wait_ld>; +def tcgen05_wait_st: NullaryInst<"tcgen05.wait::st.sync.aligned", int_nvvm_tcgen05_wait_st>; multiclass TCGEN05_COMMIT_INTR<string AS, string num> { defvar prefix = "tcgen05.commit.cta_group::" # num #".mbarrier::arrive::one.shared::cluster"; @@ -5040,12 +5169,10 @@ multiclass TCGEN05_COMMIT_INTR<string AS, string num> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$mbar), prefix # ".b64", - [(Intr addr:$mbar)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$mbar)]>; def _MC : BasicNVPTXInst<(outs), (ins ADDR:$mbar, B16:$mc), prefix # ".multicast::cluster.b64", - [(IntrMC addr:$mbar, B16:$mc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrMC addr:$mbar, B16:$mc)]>; } defm TCGEN05_COMMIT_CG1 : TCGEN05_COMMIT_INTR<"", "1">; @@ -5057,8 +5184,7 @@ multiclass TCGEN05_SHIFT_INTR<string num, Intrinsic Intr> { def "" : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr), "tcgen05.shift.cta_group::" # num # ".down", - [(Intr addr:$tmem_addr)]>, - Requires<[hasTcgen05Instructions]>; + [(Intr addr:$tmem_addr)]>; } defm TCGEN05_SHIFT_CG1: TCGEN05_SHIFT_INTR<"1", int_nvvm_tcgen05_shift_down_cg1>; defm TCGEN05_SHIFT_CG2: TCGEN05_SHIFT_INTR<"2", int_nvvm_tcgen05_shift_down_cg2>; @@ -5078,13 +5204,11 @@ multiclass TCGEN05_CP_INTR<string shape, string src_fmt, string mc = ""> { def _cg1 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::1." # shape_mc_asm # fmt_asm, - [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG1 addr:$tmem_addr, B64:$sdesc)]>; def _cg2 : BasicNVPTXInst<(outs), (ins ADDR:$tmem_addr, B64:$sdesc), "tcgen05.cp.cta_group::2." # shape_mc_asm # fmt_asm, - [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>, - Requires<[hasTcgen05Instructions]>; + [(IntrCG2 addr:$tmem_addr, B64:$sdesc)]>; } foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { @@ -5097,17 +5221,13 @@ foreach src_fmt = ["", "b6x16_p32", "b4x16_p64"] in { } } // isConvergent -let hasSideEffects = 1 in { +let hasSideEffects = 1, Predicates = [hasTcgen05Instructions] in { -def tcgen05_fence_before_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::before_thread_sync", - [(int_nvvm_tcgen05_fence_before_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_before_thread_sync: NullaryInst< + "tcgen05.fence::before_thread_sync", int_nvvm_tcgen05_fence_before_thread_sync>; -def tcgen05_fence_after_thread_sync: BasicNVPTXInst<(outs), (ins), - "tcgen05.fence::after_thread_sync", - [(int_nvvm_tcgen05_fence_after_thread_sync)]>, - Requires<[hasTcgen05Instructions]>; + def tcgen05_fence_after_thread_sync: NullaryInst< + "tcgen05.fence::after_thread_sync", int_nvvm_tcgen05_fence_after_thread_sync>; } // hasSideEffects @@ -5200,17 +5320,17 @@ foreach shape = ["16x64b", "16x128b", "16x256b", "32x32b", "16x32bx2"] in { // Bulk store instructions def st_bulk_imm : TImmLeaf<i64, [{ return Imm == 0; }]>; -def INT_NVVM_ST_BULK_GENERIC : - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk", - [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; +let Predicates = [hasSM<100>, hasPTX<86>] in { + def INT_NVVM_ST_BULK_GENERIC : + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk", + [(int_nvvm_st_bulk addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; -def INT_NVVM_ST_BULK_SHARED_CTA: - BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), - "st.bulk.shared::cta", - [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>, - Requires<[hasSM<100>, hasPTX<86>]>; + def INT_NVVM_ST_BULK_SHARED_CTA: + BasicNVPTXInst<(outs), (ins ADDR:$dest_addr, B64:$size, i64imm:$value), + "st.bulk.shared::cta", + [(int_nvvm_st_bulk_shared_cta addr:$dest_addr, i64:$size, st_bulk_imm:$value)]>; +} // // clusterlaunchcontorl Instructions |