Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXIntrinsics.td')
| -rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 314 |
1 file changed, 272 insertions, 42 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index e8758aa..d18c7e2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -364,7 +364,42 @@ def INT_FENCE_SC_CLUSTER:
   NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
   Requires<[hasPTX<78>, hasSM<90>]>;
+def INT_FENCE_MBARRIER_INIT_RELEASE_CLUSTER:
+  NullaryInst<"fence.mbarrier_init.release.cluster",
+              int_nvvm_fence_mbarrier_init_release_cluster>,
+  Requires<[hasPTX<80>, hasSM<90>]>;
+
+let Predicates = [hasPTX<86>, hasSM<90>] in {
+def INT_FENCE_ACQUIRE_SYNC_RESTRICT_CLUSTER_CLUSTER:
+  NullaryInst<"fence.acquire.sync_restrict::shared::cluster.cluster",
+              int_nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster>;
+
+def INT_FENCE_RELEASE_SYNC_RESTRICT_CTA_CLUSTER:
+  NullaryInst<"fence.release.sync_restrict::shared::cta.cluster",
+              int_nvvm_fence_release_sync_restrict_space_cta_scope_cluster>;
+}
+
 // Proxy fence (uni-directional)
+let Predicates = [hasPTX<86>, hasSM<90>] in {
+def INT_NVVM_FENCE_PROXY_ASYNC_GENERIC_ACQUIRE_SYNC_RESTRICT_SPACE_CLUSTER_SCOPE_CLUSTER:
+  NullaryInst<"fence.proxy.async::generic.acquire.sync_restrict::shared::cluster.cluster",
+              int_nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster>;
+
+def INT_NVVM_FENCE_PROXY_ASYNC_GENERIC_RELEASE_SYNC_RESTRICT_SPACE_CTA_SCOPE_CLUSTER:
+  NullaryInst<"fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster",
+              int_nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster>;
+}
+
+// Proxy fence (bi-directional)
+foreach proxykind = ["alias", "async", "async.global", "async.shared_cta",
+                     "async.shared_cluster"] in {
+  defvar Preds = !if(!eq(proxykind, "alias"), [hasPTX<75>, hasSM<70>],
+                     [hasPTX<80>, hasSM<90>]);
+  defvar Intr = IntrinsicName<"llvm.nvvm.fence.proxy." # proxykind>;
+  def : NullaryInst<"fence.proxy." # !subst("_", "::", proxykind),
+                    !cast<Intrinsic>(Intr.record_name)>, Requires<Preds>;
+}
+
 class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> :
   NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>,
   Requires<[hasPTX<83>, hasSM<90>]>;
@@ -497,6 +532,10 @@ class CpAsyncBulkStr<bit mc, bit ch, bit mask = 0> {
                # !if(mc, ".multicast::cluster", "")
                # !if(ch, ".L2::cache_hint", "");
 
+  // Global to Shared CTA memory
+  string G2S_CTA = "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes"
+                   # !if(ch, ".L2::cache_hint", "");
+
   // Shared CTA to Cluster memory
   string C2C = "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes";
 }
@@ -543,6 +582,21 @@ multiclass CP_ASYNC_BULK_G2S_INTR<bit has_ch> {
 defm CP_ASYNC_BULK_G2S : CP_ASYNC_BULK_G2S_INTR<has_ch = 0>;
 defm CP_ASYNC_BULK_G2S_CH : CP_ASYNC_BULK_G2S_INTR<has_ch = 1>;
+multiclass CP_ASYNC_BULK_G2S_CTA_INTR<bit has_ch> {
+  defvar Intr = int_nvvm_cp_async_bulk_global_to_shared_cta;
+
+  def "" : NVPTXInst<(outs),
+             (ins ADDR:$dst, ADDR:$mbar, ADDR:$src,
+                  B32:$size, B64:$ch),
+             !if(has_ch,
+                 CpAsyncBulkStr<0, 1>.G2S_CTA # " [$dst], [$src], $size, [$mbar], $ch;",
+                 CpAsyncBulkStr<0, 0>.G2S_CTA # " [$dst], [$src], $size, [$mbar];"),
+             [(Intr addr:$dst, addr:$mbar, addr:$src, i32:$size, i64:$ch, !if(has_ch, -1, 0))]>,
+             Requires<[hasPTX<86>, hasSM<90>]>;
+}
+defm CP_ASYNC_BULK_G2S_CTA : CP_ASYNC_BULK_G2S_CTA_INTR<has_ch = 0>;
+defm CP_ASYNC_BULK_G2S_CTA_CH : CP_ASYNC_BULK_G2S_CTA_INTR<has_ch = 1>;
+
 def CP_ASYNC_BULK_CTA_TO_CLUSTER : NVPTXInst<(outs),
   (ins ADDR:$dst, ADDR:$mbar, ADDR:$src, B32:$size),
   CpAsyncBulkStr<0, 0>.C2C # " [$dst], [$src], $size, [$mbar];",
@@ -1562,12 +1616,17 @@ def : Pat<(int_nvvm_saturate_d f64:$a), (CVT_f64_f64 $a, CvtSAT)>;
 
 // Exp2 Log2
 //
-def : Pat<(int_nvvm_ex2_approx_ftz_f f32:$a), (EX2_APPROX_f32 $a, FTZ)>;
-def : Pat<(int_nvvm_ex2_approx_f f32:$a), (EX2_APPROX_f32 $a, NoFTZ)>;
+def : Pat<(f32 (int_nvvm_ex2_approx_ftz f32:$a)), (EX2_APPROX_f32 $a, FTZ)>;
+def : Pat<(f32 (int_nvvm_ex2_approx f32:$a)), (EX2_APPROX_f32 $a, NoFTZ)>;
 
 let Predicates = [hasPTX<70>, hasSM<75>] in {
-  def : Pat<(int_nvvm_ex2_approx_f16 f16:$a), (EX2_APPROX_f16 $a)>;
-  def : Pat<(int_nvvm_ex2_approx_f16x2 v2f16:$a), (EX2_APPROX_f16x2 $a)>;
+  def : Pat<(f16 (int_nvvm_ex2_approx f16:$a)), (EX2_APPROX_f16 $a)>;
+  def : Pat<(v2f16 (int_nvvm_ex2_approx v2f16:$a)), (EX2_APPROX_f16x2 $a)>;
+}
+
+let Predicates = [hasPTX<78>, hasSM<90>] in {
+  def : Pat<(bf16 (int_nvvm_ex2_approx_ftz bf16:$a)), (EX2_APPROX_bf16 $a)>;
+  def : Pat<(v2bf16 (int_nvvm_ex2_approx_ftz v2bf16:$a)), (EX2_APPROX_bf16x2 $a)>;
 }
 
 def LG2_APPROX_f32 :
@@ -1893,7 +1952,12 @@ def : Pat<(int_nvvm_ff2bf16x2_rn f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRN)>;
 def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRN_RELU)>;
 def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ)>;
 def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>;
-
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+  def : Pat<(int_nvvm_ff2bf16x2_rn_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRN)>;
+  def : Pat<(int_nvvm_ff2bf16x2_rn_relu_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRN_RELU)>;
+  def : Pat<(int_nvvm_ff2bf16x2_rz_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRZ)>;
+  def : Pat<(int_nvvm_ff2bf16x2_rz_relu_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRZ_RELU)>;
+}
 let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
   def : Pat<(int_nvvm_ff2bf16x2_rs f32:$a, f32:$b, i32:$c),
             (CVT_bf16x2_f32_rs $a, $b, $c, CvtRS)>;
@@ -1909,6 +1973,12 @@ def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN)>;
 def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>;
 def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ)>;
 def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>;
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+  def : Pat<(int_nvvm_ff2f16x2_rn_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRN)>;
+  def : Pat<(int_nvvm_ff2f16x2_rn_relu_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRN_RELU)>;
+  def : Pat<(int_nvvm_ff2f16x2_rz_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRZ)>;
+  def : Pat<(int_nvvm_ff2f16x2_rz_relu_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRZ_RELU)>;
+}
 let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
   def : Pat<(int_nvvm_ff2f16x2_rs f32:$a, f32:$b, i32:$c),
@@ -1924,6 +1994,23 @@ def : Pat<(int_nvvm_f2bf16_rn f32:$a), (CVT_bf16_f32 $a, CvtRN)>;
 def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a), (CVT_bf16_f32 $a, CvtRN_RELU)>;
 def : Pat<(int_nvvm_f2bf16_rz f32:$a), (CVT_bf16_f32 $a, CvtRZ)>;
 def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a), (CVT_bf16_f32 $a, CvtRZ_RELU)>;
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+  def : Pat<(int_nvvm_f2bf16_rz_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRZ)>;
+  def : Pat<(int_nvvm_f2bf16_rz_relu_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRZ_RELU)>;
+  def : Pat<(int_nvvm_f2bf16_rn_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRN)>;
+  def : Pat<(int_nvvm_f2bf16_rn_relu_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRN_RELU)>;
+}
+
+def : Pat<(int_nvvm_f2f16_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>;
+def : Pat<(int_nvvm_f2f16_rn_relu f32:$a), (CVT_f16_f32 $a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f2f16_rz f32:$a), (CVT_f16_f32 $a, CvtRZ)>;
+def : Pat<(int_nvvm_f2f16_rz_relu f32:$a), (CVT_f16_f32 $a, CvtRZ_RELU)>;
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+  def : Pat<(int_nvvm_f2f16_rz_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRZ)>;
+  def : Pat<(int_nvvm_f2f16_rz_relu_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRZ_RELU)>;
+  def : Pat<(int_nvvm_f2f16_rn_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRN)>;
+  def : Pat<(int_nvvm_f2f16_rn_relu_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRN_RELU)>;
+}
 
 def : Pat<(int_nvvm_lohi_i2d i32:$a, i32:$b), (V2I32toI64 $a, $b)>;
 def : Pat<(int_nvvm_d2i_lo f64:$a), (I64toI32L $a)>;
@@ -1984,34 +2071,36 @@ def : Pat<(int_nvvm_ull2d_rp i64:$a), (CVT_f64_u64 $a, CvtRP)>;
 def : Pat<(int_nvvm_f2h_rn_ftz f32:$a), (CVT_f16_f32 $a, CvtRN_FTZ)>;
 def : Pat<(int_nvvm_f2h_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>;
 
-def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
-          (CVT_e4m3x2_f32 $a, $b, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
-          (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
-          (CVT_e5m2x2_f32 $a, $b, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
-          (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
-
-def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
-          (CVT_e4m3x2_f16x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
-          (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
-          (CVT_e5m2x2_f16x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
-          (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
-
-def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
-          (CVT_f16x2_e4m3x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
-          (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
-          (CVT_f16x2_e5m2x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
-          (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
-
-let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in {
+let Predicates = [callSubtarget<"hasFP8ConversionSupport">] in {
+  def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
+            (CVT_e4m3x2_f32 $a, $b, CvtRN)>;
+  def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
+            (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
+  def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
+            (CVT_e5m2x2_f32 $a, $b, CvtRN)>;
+  def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
+            (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
+
+  def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
+            (CVT_e4m3x2_f16x2 $a, CvtRN)>;
+  def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
+            (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
+  def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
+            (CVT_e5m2x2_f16x2 $a, CvtRN)>;
+  def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
+            (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
+
+  def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
+            (CVT_f16x2_e4m3x2 $a, CvtRN)>;
+  def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
+            (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
+  def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
+            (CVT_f16x2_e5m2x2 $a, CvtRN)>;
+  def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
+            (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
+}
+
+let Predicates = [callSubtarget<"hasNarrowFPConversionSupport">] in {
   def : Pat<(int_nvvm_ff_to_e2m3x2_rn_satfinite f32:$a, f32:$b),
             (CVT_e2m3x2_f32_sf $a, $b, CvtRN)>;
   def : Pat<(int_nvvm_ff_to_e2m3x2_rn_relu_satfinite f32:$a, f32:$b),
@@ -2463,7 +2552,10 @@ def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
 // during the lifetime of the kernel.
 
 class LDG_G<NVPTXRegClass regclass>
-  : NVPTXInst<(outs regclass:$result), (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+  : NVPTXInst<(outs regclass:$result),
+              (ins AtomicCode:$Sign, i32imm:$fromWidth,
+                   UsedBytesMask:$usedBytes, ADDR:$src),
+              "${usedBytes}"
               "ld.global.nc.${Sign:sign}$fromWidth \t$result, [$src];">;
 
 def LD_GLOBAL_NC_i16 : LDG_G<B16>;
@@ -2475,19 +2567,25 @@ def LD_GLOBAL_NC_i64 : LDG_G<B64>;
 
 // Elementized vector ldg
 class VLDG_G_ELE_V2<NVPTXRegClass regclass>
   : NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
-              (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+              (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+                   ADDR:$src),
+              "${usedBytes}"
               "ld.global.nc.v2.${Sign:sign}$fromWidth \t{{$dst1, $dst2}}, [$src];">;
 
 class VLDG_G_ELE_V4<NVPTXRegClass regclass>
   : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
-              (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+              (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+                   ADDR:$src),
+              "${usedBytes}"
              "ld.global.nc.v4.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];">;
 
 class VLDG_G_ELE_V8<NVPTXRegClass regclass>
   : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
                     regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
-              (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+              (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+                   ADDR:$src),
+              "${usedBytes}"
              "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];">;
 
 // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
@@ -4595,7 +4693,8 @@ def INT_PTX_SREG_WARPSIZE :
 // the fields commonly used to implement specific PTX instruction -- register
 // types and names, constraints, parts of assembly, etc.
 class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
-  : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
+  : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
+              !or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
   // NVPTX register types used to carry fragment data.
   NVPTXRegClass regclass = !cond(
     !eq(ptx_elt_type, "e4m3") : B32,
@@ -4635,6 +4734,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
   // longer the case, we can concat all per-fragment predicates to enforce that
   // all fragments of the instruction are viable.
   list<Predicate> Predicates = !cond(
+    !or(!eq(op, "mma.block_scale"),
+        !eq(op, "mma.sp.block_scale")) : [hasSM120a, hasPTX<88>],
+
     !or(!eq(ptx_elt_type, "e3m2"),
         !eq(ptx_elt_type, "e2m3"),
         !eq(ptx_elt_type, "e2m1"),
@@ -4647,9 +4749,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
     !or(!eq(ptx_elt_type, "e4m3"),
        !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
 
-    !and(!eq(op, "mma.sp"),
+    !and(isSparse,
          !ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
-    !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+    isSparse : [hasSM<80>, hasPTX<71>],
 
     // fp16 -> fp16/fp32 @ m16n16k16
     !and(!eq(geom, "m16n16k16"),
@@ -4998,6 +5100,67 @@ defset list<WMMA_INSTR> MMAs = {
 } // defset
 }
+// MMA.block_scale
+class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+                      WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+                      string Kind, string SType, string ScaleVecSize>
+  : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+                                    FragA, FragB, FragC, FragD>.record_name,
+               [FragA.Ins, FragB.Ins, FragC.Ins,
+                (ins B32:$scale_a, B16:$byte_id_a,
+                     B16:$thread_id_a, B32:$scale_b,
+                     B16:$byte_id_b, B16:$thread_id_b)]>,
+    // Requires does not seem to have effect on Instruction w/o Patterns.
+    // We set it here anyways and propagate to the Pat<> we construct below.
+    Requires<FragA.Predicates> {
+  let OutOperandList = FragD.Outs;
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  string TypeList = !interleave([FragD.ptx_elt_type,
+                                 FragA.ptx_elt_type,
+                                 FragB.ptx_elt_type,
+                                 FragC.ptx_elt_type], ".");
+  string ScaleVecSizeStr = !cond(
+    !eq(ScaleVecSize, "") : "",
+    !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+    !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+    !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+  );
+  let AsmString = "mma.sync.aligned."
+                  # FragA.geom
+                  # ".row.col"
+                  # ".kind::" # Kind
+                  # ".block_scale"
+                  # ScaleVecSizeStr
+                  # "." # TypeList
+                  # "." # SType # " \n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ",\n\t\t"
+                  # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+                  # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_BLOCK_SCALEs = {
+  foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+    foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+      foreach stype = ["ue8m0", "ue4m3"] in {
+        foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+          if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+            def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+                                  kind, stype, scale_vec_size>;
+          }
+        } // op
+      } // stype
+    } // scale_vec_size
+  } // kind
} // defset
+}
+
 // MMA SP
 class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
              WMMA_REGINFO FragC, WMMA_REGINFO FragD,
@@ -5054,6 +5217,72 @@ defset list<WMMA_INSTR> MMA_SPs = {
 } // defset
 }
+// MMA SP BLOCK SCALE
+class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+                         WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+                         string Kind, string SType, string ScaleVecSize>
+  : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+                                       FragA, FragB, FragC, FragD>.record_name,
+               [FragA.Ins, FragB.Ins, FragC.Ins,
+                (ins B32:$metadata, i32imm:$selector,
+                     B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
+                     B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>,
+    // Requires does not seem to have effect on Instruction w/o Patterns.
+    // We set it here anyways and propagate to the Pat<> we construct below.
+    Requires<!listconcat(FragA.Predicates,
+                         FragB.Predicates,
+                         FragC.Predicates,
+                         FragD.Predicates)> {
+  let OutOperandList = FragD.Outs;
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  string TypeList = "." # FragD.ptx_elt_type
+                    # "." # FragA.ptx_elt_type
+                    # "." # FragB.ptx_elt_type
+                    # "." # FragC.ptx_elt_type;
+  string ScaleVecSizeStr = !cond(
+    !eq(ScaleVecSize, "") : "",
+    !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+    !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+    !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+  );
+  let AsmString = "mma.sp::ordered_metadata.sync.aligned."
+                  # FragA.geom
+                  # ".row.col"
+                  # ".kind::" # Kind
+                  # ".block_scale"
+                  # ScaleVecSizeStr
+                  # TypeList
+                  # "." # SType # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ",\n\t\t"
+                  # "$metadata" # ",\n\t\t"
+                  # "$selector" # ",\n\t\t"
+                  # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+                  # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
+  foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+    foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+      foreach stype = ["ue8m0", "ue4m3"] in {
+        foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+          if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+            def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+                                     WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+                                     WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+                                     WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+                                     kind, stype, scale_vec_size>;
+          }
+        } // op
+      } // stype
+    } // scale_vec_size
+  } // kind
+} // defset
+}
+
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
 //
@@ -5135,7 +5364,8 @@ class MMA_PAT<WMMA_INSTR wi>
   Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
+foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs,
+                          STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
@@ -5601,7 +5831,7 @@ class Tcgen05MMADisableOutputLaneSDNode<bit Sp, string ASpace,
              # "_DISABLE_OUTPUT_LANE_CG" # CtaGroup
             # !if(!eq(AShift, 1), "_ASHIFT", ""),
             Tcgen05MMADisableOutputLaneTypeProfile<Sp, ASpace, CtaGroup, ScaleInput>,
-            [SDNPHasChain, SDNPSideEffect]>;
+            [SDNPHasChain, SDNPSideEffect, SDNPMemOperand]>;
 
 class Tcgen05MMADisableOutputLaneInst<bit Sp, string ASpace, string Kind,
                                       int CtaGroup, string CollectorUsageStr,
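For illustration only (not part of the commit): a minimal CUDA sketch of the kind of inline PTX that exercises the uni-directional and bi-directional fences this patch exposes. The fence mnemonics are copied verbatim from the NullaryInst strings in the diff; the kernel name, the sm_90 guard, and the assumption that the toolchain's ptxas accepts PTX ISA 8.0+ are illustrative assumptions.

// Sketch under stated assumptions: sm_90 target, ptxas with PTX ISA 8.0+.
// The asm strings below mirror instruction mnemonics defined in this patch.
__global__ void fence_demo() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Release a preceding mbarrier initialization at cluster scope
  // (corresponds to int_nvvm_fence_mbarrier_init_release_cluster above).
  asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");

  // Bi-directional alias proxy fence (llvm.nvvm.fence.proxy.alias, gated on
  // hasPTX<75>/hasSM<70> by the foreach loop in the diff).
  asm volatile("fence.proxy.alias;" ::: "memory");
#endif
}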
