Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXIntrinsics.td')
-rw-r--r--   llvm/lib/Target/NVPTX/NVPTXIntrinsics.td   314
1 file changed, 272 insertions(+), 42 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index e8758aa..d18c7e2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -364,7 +364,42 @@ def INT_FENCE_SC_CLUSTER:
NullaryInst<"fence.sc.cluster", int_nvvm_fence_sc_cluster>,
Requires<[hasPTX<78>, hasSM<90>]>;
+def INT_FENCE_MBARRIER_INIT_RELEASE_CLUSTER:
+ NullaryInst<"fence.mbarrier_init.release.cluster",
+ int_nvvm_fence_mbarrier_init_release_cluster>,
+ Requires<[hasPTX<80>, hasSM<90>]>;
+
+let Predicates = [hasPTX<86>, hasSM<90>] in {
+def INT_FENCE_ACQUIRE_SYNC_RESTRICT_CLUSTER_CLUSTER:
+ NullaryInst<"fence.acquire.sync_restrict::shared::cluster.cluster",
+ int_nvvm_fence_acquire_sync_restrict_space_cluster_scope_cluster>;
+
+def INT_FENCE_RELEASE_SYNC_RESTRICT_CTA_CLUSTER:
+ NullaryInst<"fence.release.sync_restrict::shared::cta.cluster",
+ int_nvvm_fence_release_sync_restrict_space_cta_scope_cluster>;
+}
+
// Proxy fence (uni-directional)
+let Predicates = [hasPTX<86>, hasSM<90>] in {
+def INT_NVVM_FENCE_PROXY_ASYNC_GENERIC_ACQUIRE_SYNC_RESTRICT_SPACE_CLUSTER_SCOPE_CLUSTER:
+ NullaryInst<"fence.proxy.async::generic.acquire.sync_restrict::shared::cluster.cluster",
+ int_nvvm_fence_proxy_async_generic_acquire_sync_restrict_space_cluster_scope_cluster>;
+
+def INT_NVVM_FENCE_PROXY_ASYNC_GENERIC_RELEASE_SYNC_RESTRICT_SPACE_CTA_SCOPE_CLUSTER:
+ NullaryInst<"fence.proxy.async::generic.release.sync_restrict::shared::cta.cluster",
+ int_nvvm_fence_proxy_async_generic_release_sync_restrict_space_cta_scope_cluster>;
+}
+
+// Proxy fence (bi-directional)
+foreach proxykind = ["alias", "async", "async.global", "async.shared_cta",
+ "async.shared_cluster"] in {
+ defvar Preds = !if(!eq(proxykind, "alias"), [hasPTX<75>, hasSM<70>],
+ [hasPTX<80>, hasSM<90>]);
+ defvar Intr = IntrinsicName<"llvm.nvvm.fence.proxy." # proxykind>;
+ def : NullaryInst<"fence.proxy." # !subst("_", "::", proxykind),
+ !cast<Intrinsic>(Intr.record_name)>, Requires<Preds>;
+}
+
class FENCE_PROXY_TENSORMAP_GENERIC_RELEASE<string Scope, Intrinsic Intr> :
NullaryInst<"fence.proxy.tensormap::generic.release." # Scope, Intr>,
Requires<[hasPTX<83>, hasSM<90>]>;
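Aside (not part of the patch): the fence intrinsics added in this hunk are all nullary, so each one lowers to a bare PTX fence instruction carrying exactly the string named in its NullaryInst, gated by the Requires clauses above (sm_90 with PTX 8.0 for fence.mbarrier_init, PTX 8.6 for the sync_restrict forms). A minimal CUDA inline-asm sketch of reaching two of those instructions directly is below; the helper names are invented for illustration and are not part of this change.

// Illustrative sketch only -- helper names are made up; the PTX strings are
// the same ones the NullaryInsts above emit (sm_90, recent PTX ISA required).
__device__ inline void fence_mbarrier_init_release_cluster() {
  asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
}

__device__ inline void fence_proxy_async() {
  // One of the bi-directional proxy fences covered by the foreach above.
  asm volatile("fence.proxy.async;" ::: "memory");
}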
@@ -497,6 +532,10 @@ class CpAsyncBulkStr<bit mc, bit ch, bit mask = 0> {
# !if(mc, ".multicast::cluster", "")
# !if(ch, ".L2::cache_hint", "");
+ // Global to Shared CTA memory
+ string G2S_CTA = "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes"
+ # !if(ch, ".L2::cache_hint", "");
+
// Shared CTA to Cluster memory
string C2C = "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes";
}
@@ -543,6 +582,21 @@ multiclass CP_ASYNC_BULK_G2S_INTR<bit has_ch> {
defm CP_ASYNC_BULK_G2S : CP_ASYNC_BULK_G2S_INTR<has_ch = 0>;
defm CP_ASYNC_BULK_G2S_CH : CP_ASYNC_BULK_G2S_INTR<has_ch = 1>;
+multiclass CP_ASYNC_BULK_G2S_CTA_INTR<bit has_ch> {
+ defvar Intr = int_nvvm_cp_async_bulk_global_to_shared_cta;
+
+ def "" : NVPTXInst<(outs),
+ (ins ADDR:$dst, ADDR:$mbar, ADDR:$src,
+ B32:$size, B64:$ch),
+ !if(has_ch,
+ CpAsyncBulkStr<0, 1>.G2S_CTA # " [$dst], [$src], $size, [$mbar], $ch;",
+ CpAsyncBulkStr<0, 0>.G2S_CTA # " [$dst], [$src], $size, [$mbar];"),
+ [(Intr addr:$dst, addr:$mbar, addr:$src, i32:$size, i64:$ch, !if(has_ch, -1, 0))]>,
+ Requires<[hasPTX<86>, hasSM<90>]>;
+}
+defm CP_ASYNC_BULK_G2S_CTA : CP_ASYNC_BULK_G2S_CTA_INTR<has_ch = 0>;
+defm CP_ASYNC_BULK_G2S_CTA_CH : CP_ASYNC_BULK_G2S_CTA_INTR<has_ch = 1>;
+
def CP_ASYNC_BULK_CTA_TO_CLUSTER : NVPTXInst<(outs),
(ins ADDR:$dst, ADDR:$mbar, ADDR:$src, B32:$size),
CpAsyncBulkStr<0, 0>.C2C # " [$dst], [$src], $size, [$mbar];",
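Aside (not part of the patch): the new G2S_CTA string performs a bulk copy from global memory into a CTA's shared memory and signals an mbarrier on completion, with operands ordered [dst], [src], size, [mbar] as in the multiclass above. A rough CUDA inline-asm equivalent is sketched below; the helper name and the constraint choices ("r" for 32-bit shared addresses obtained via __cvta_generic_to_shared, "l" for the 64-bit global pointer) are assumptions about one workable encoding, not a description of what this pattern generates.

// Illustrative sketch only -- not part of the patch.
__device__ inline void bulk_copy_g2s_cta(void *dst_smem, void *mbar_smem,
                                         const void *src_gmem, unsigned bytes) {
  // Shared-state-space addresses fit in 32 bits; the source stays a 64-bit
  // global pointer. "bytes" must satisfy cp.async.bulk's alignment rules.
  unsigned dst  = static_cast<unsigned>(__cvta_generic_to_shared(dst_smem));
  unsigned mbar = static_cast<unsigned>(__cvta_generic_to_shared(mbar_smem));
  asm volatile(
      "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes"
      " [%0], [%1], %2, [%3];"
      :: "r"(dst), "l"(src_gmem), "r"(bytes), "r"(mbar)
      : "memory");
}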
@@ -1562,12 +1616,17 @@ def : Pat<(int_nvvm_saturate_d f64:$a), (CVT_f64_f64 $a, CvtSAT)>;
// Exp2 Log2
//
-def : Pat<(int_nvvm_ex2_approx_ftz_f f32:$a), (EX2_APPROX_f32 $a, FTZ)>;
-def : Pat<(int_nvvm_ex2_approx_f f32:$a), (EX2_APPROX_f32 $a, NoFTZ)>;
+def : Pat<(f32 (int_nvvm_ex2_approx_ftz f32:$a)), (EX2_APPROX_f32 $a, FTZ)>;
+def : Pat<(f32 (int_nvvm_ex2_approx f32:$a)), (EX2_APPROX_f32 $a, NoFTZ)>;
let Predicates = [hasPTX<70>, hasSM<75>] in {
- def : Pat<(int_nvvm_ex2_approx_f16 f16:$a), (EX2_APPROX_f16 $a)>;
- def : Pat<(int_nvvm_ex2_approx_f16x2 v2f16:$a), (EX2_APPROX_f16x2 $a)>;
+ def : Pat<(f16 (int_nvvm_ex2_approx f16:$a)), (EX2_APPROX_f16 $a)>;
+ def : Pat<(v2f16 (int_nvvm_ex2_approx v2f16:$a)), (EX2_APPROX_f16x2 $a)>;
+}
+
+let Predicates = [hasPTX<78>, hasSM<90>] in {
+ def : Pat<(bf16 (int_nvvm_ex2_approx_ftz bf16:$a)), (EX2_APPROX_bf16 $a)>;
+ def : Pat<(v2bf16 (int_nvvm_ex2_approx_ftz v2bf16:$a)), (EX2_APPROX_bf16x2 $a)>;
}
def LG2_APPROX_f32 :
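Aside (not part of the patch): this hunk folds the per-type ex2 intrinsics into the overloaded int_nvvm_ex2_approx / int_nvvm_ex2_approx_ftz records, selected by result type (f32, f16, v2f16, and the ftz-flavoured bf16/v2bf16 forms on sm_90). At the PTX level the scalar f32 case is still plain ex2.approx.f32; a small hedged CUDA sketch of reaching that instruction directly via inline asm, with an invented helper name:

// Illustrative sketch only -- not part of the patch.
__device__ inline float ex2_approx_f32(float x) {
  float y;
  asm("ex2.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}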
@@ -1893,7 +1952,12 @@ def : Pat<(int_nvvm_ff2bf16x2_rn f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, C
def : Pat<(int_nvvm_ff2bf16x2_rn_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2bf16x2_rz f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2bf16x2_rz_relu f32:$a, f32:$b), (CVT_bf16x2_f32 $a, $b, CvtRZ_RELU)>;
-
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+ def : Pat<(int_nvvm_ff2bf16x2_rn_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRN)>;
+ def : Pat<(int_nvvm_ff2bf16x2_rn_relu_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_ff2bf16x2_rz_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRZ)>;
+ def : Pat<(int_nvvm_ff2bf16x2_rz_relu_satfinite f32:$a, f32:$b), (CVT_bf16x2_f32_sf $a, $b, CvtRZ_RELU)>;
+}
let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
def : Pat<(int_nvvm_ff2bf16x2_rs f32:$a, f32:$b, i32:$c),
(CVT_bf16x2_f32_rs $a, $b, $c, CvtRS)>;
@@ -1909,6 +1973,12 @@ def : Pat<(int_nvvm_ff2f16x2_rn f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, Cvt
def : Pat<(int_nvvm_ff2f16x2_rn_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRN_RELU)>;
def : Pat<(int_nvvm_ff2f16x2_rz f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ)>;
def : Pat<(int_nvvm_ff2f16x2_rz_relu f32:$a, f32:$b), (CVT_f16x2_f32 $a, $b, CvtRZ_RELU)>;
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+ def : Pat<(int_nvvm_ff2f16x2_rn_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRN)>;
+ def : Pat<(int_nvvm_ff2f16x2_rn_relu_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_ff2f16x2_rz_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRZ)>;
+ def : Pat<(int_nvvm_ff2f16x2_rz_relu_satfinite f32:$a, f32:$b), (CVT_f16x2_f32_sf $a, $b, CvtRZ_RELU)>;
+}
let Predicates = [hasPTX<87>, hasSM100aOrSM103a] in {
def : Pat<(int_nvvm_ff2f16x2_rs f32:$a, f32:$b, i32:$c),
@@ -1924,6 +1994,23 @@ def : Pat<(int_nvvm_f2bf16_rn f32:$a), (CVT_bf16_f32 $a, CvtRN)>;
def : Pat<(int_nvvm_f2bf16_rn_relu f32:$a), (CVT_bf16_f32 $a, CvtRN_RELU)>;
def : Pat<(int_nvvm_f2bf16_rz f32:$a), (CVT_bf16_f32 $a, CvtRZ)>;
def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a), (CVT_bf16_f32 $a, CvtRZ_RELU)>;
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+ def : Pat<(int_nvvm_f2bf16_rz_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRZ)>;
+ def : Pat<(int_nvvm_f2bf16_rz_relu_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRZ_RELU)>;
+ def : Pat<(int_nvvm_f2bf16_rn_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRN)>;
+ def : Pat<(int_nvvm_f2bf16_rn_relu_satfinite f32:$a), (CVT_bf16_f32_sf $a, CvtRN_RELU)>;
+}
+
+def : Pat<(int_nvvm_f2f16_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>;
+def : Pat<(int_nvvm_f2f16_rn_relu f32:$a), (CVT_f16_f32 $a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f2f16_rz f32:$a), (CVT_f16_f32 $a, CvtRZ)>;
+def : Pat<(int_nvvm_f2f16_rz_relu f32:$a), (CVT_f16_f32 $a, CvtRZ_RELU)>;
+let Predicates = [hasPTX<81>, hasSM<80>] in {
+ def : Pat<(int_nvvm_f2f16_rz_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRZ)>;
+ def : Pat<(int_nvvm_f2f16_rz_relu_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRZ_RELU)>;
+ def : Pat<(int_nvvm_f2f16_rn_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRN)>;
+ def : Pat<(int_nvvm_f2f16_rn_relu_satfinite f32:$a), (CVT_f16_f32_sf $a, CvtRN_RELU)>;
+}
def : Pat<(int_nvvm_lohi_i2d i32:$a, i32:$b), (V2I32toI64 $a, $b)>;
def : Pat<(int_nvvm_d2i_lo f64:$a), (I64toI32L $a)>;
@@ -1984,34 +2071,36 @@ def : Pat<(int_nvvm_ull2d_rp i64:$a), (CVT_f64_u64 $a, CvtRP)>;
def : Pat<(int_nvvm_f2h_rn_ftz f32:$a), (CVT_f16_f32 $a, CvtRN_FTZ)>;
def : Pat<(int_nvvm_f2h_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
- (CVT_e4m3x2_f32 $a, $b, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
- (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
- (CVT_e5m2x2_f32 $a, $b, CvtRN)>;
-def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
- (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
-
-def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
- (CVT_e4m3x2_f16x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
- (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
- (CVT_e5m2x2_f16x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
- (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
-
-def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
- (CVT_f16x2_e4m3x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
- (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
- (CVT_f16x2_e5m2x2 $a, CvtRN)>;
-def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
- (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
-
-let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in {
+let Predicates = [callSubtarget<"hasFP8ConversionSupport">] in {
+ def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
+ (CVT_e4m3x2_f32 $a, $b, CvtRN)>;
+ def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
+ (CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
+ (CVT_e5m2x2_f32 $a, $b, CvtRN)>;
+ def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
+ (CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
+
+ def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
+ (CVT_e4m3x2_f16x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
+ (CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
+ (CVT_e5m2x2_f16x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
+ (CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
+
+ def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
+ (CVT_f16x2_e4m3x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
+ (CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
+ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
+ (CVT_f16x2_e5m2x2 $a, CvtRN)>;
+ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
+ (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
+}
+
+let Predicates = [callSubtarget<"hasNarrowFPConversionSupport">] in {
def : Pat<(int_nvvm_ff_to_e2m3x2_rn_satfinite f32:$a, f32:$b),
(CVT_e2m3x2_f32_sf $a, $b, CvtRN)>;
def : Pat<(int_nvvm_ff_to_e2m3x2_rn_relu_satfinite f32:$a, f32:$b),
@@ -2463,7 +2552,10 @@ def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4<B32>;
// during the lifetime of the kernel.
class LDG_G<NVPTXRegClass regclass>
- : NVPTXInst<(outs regclass:$result), (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ : NVPTXInst<(outs regclass:$result),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth,
+ UsedBytesMask:$usedBytes, ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.${Sign:sign}$fromWidth \t$result, [$src];">;
def LD_GLOBAL_NC_i16 : LDG_G<B16>;
@@ -2475,19 +2567,25 @@ def LD_GLOBAL_NC_i64 : LDG_G<B64>;
// Elementized vector ldg
class VLDG_G_ELE_V2<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v2.${Sign:sign}$fromWidth \t{{$dst1, $dst2}}, [$src];">;
class VLDG_G_ELE_V4<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v4.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];">;
class VLDG_G_ELE_V8<NVPTXRegClass regclass> :
NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4,
regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8),
- (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src),
+ (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes,
+ ADDR:$src),
+ "${usedBytes}"
"ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];">;
// FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads.
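Aside (not part of the patch): the hunks above thread a UsedBytesMask operand through the non-coherent global-load instructions (ld.global.nc and its v2/v4/v8 forms); the "${usedBytes}" piece is expanded in front of the load mnemonic by the asm string. At the CUDA source level these are the loads behind __ldg(), e.g.:

// Illustrative sketch only -- not part of the patch: __ldg() is the usual
// CUDA-level way to request a non-coherent (ld.global.nc) read-only load.
__global__ void scale(const float *__restrict__ in, float *__restrict__ out,
                      float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = __ldg(&in[i]) * s;
}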
@@ -4595,7 +4693,8 @@ def INT_PTX_SREG_WARPSIZE :
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
- : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
+ !or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
!eq(ptx_elt_type, "e4m3") : B32,
@@ -4635,6 +4734,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
+ !or(!eq(op, "mma.block_scale"),
+ !eq(op, "mma.sp.block_scale")) : [hasSM120a, hasPTX<88>],
+
!or(!eq(ptx_elt_type, "e3m2"),
!eq(ptx_elt_type, "e2m3"),
!eq(ptx_elt_type, "e2m1"),
@@ -4647,9 +4749,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!or(!eq(ptx_elt_type, "e4m3"),
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
- !and(!eq(op, "mma.sp"),
+ !and(isSparse,
!ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
- !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+ isSparse : [hasSM<80>, hasPTX<71>],
// fp16 -> fp16/fp32 @ m16n16k16
!and(!eq(geom, "m16n16k16"),
@@ -4998,6 +5100,67 @@ defset list<WMMA_INSTR> MMAs = {
} // defset
}
+// MMA.block_scale
+class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record_name,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$scale_a, B16:$byte_id_a,
+ B16:$thread_id_a, B32:$scale_b,
+ B16:$byte_id_b, B16:$thread_id_b)]>,
+  // Requires does not seem to have an effect on an Instruction without Patterns.
+  // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<FragA.Predicates> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = !interleave([FragD.ptx_elt_type,
+ FragA.ptx_elt_type,
+ FragB.ptx_elt_type,
+ FragC.ptx_elt_type], ".");
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # "." # TypeList
+ # "." # SType # " \n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
// MMA SP
class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
@@ -5054,6 +5217,72 @@ defset list<WMMA_INSTR> MMA_SPs = {
} // defset
}
+// MMA SP BLOCK SCALE
+class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record_name,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$metadata, i32imm:$selector,
+ B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
+ B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>,
+  // Requires does not seem to have an effect on an Instruction without Patterns.
+  // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<!listconcat(FragA.Predicates,
+ FragB.Predicates,
+ FragC.Predicates,
+ FragD.Predicates)> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sp::ordered_metadata.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # TypeList
+ # "." # SType # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$metadata" # ",\n\t\t"
+ # "$selector" # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
@@ -5135,7 +5364,8 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
+foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs,
+ STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
@@ -5601,7 +5831,7 @@ class Tcgen05MMADisableOutputLaneSDNode<bit Sp, string ASpace,
# "_DISABLE_OUTPUT_LANE_CG" # CtaGroup
# !if(!eq(AShift, 1), "_ASHIFT", ""),
Tcgen05MMADisableOutputLaneTypeProfile<Sp, ASpace, CtaGroup, ScaleInput>,
- [SDNPHasChain, SDNPSideEffect]>;
+ [SDNPHasChain, SDNPSideEffect, SDNPMemOperand]>;
class Tcgen05MMADisableOutputLaneInst<bit Sp, string ASpace,
string Kind, int CtaGroup, string CollectorUsageStr,