aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/AMDGPU/SISchedule.td
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/AMDGPU/SISchedule.td')
-rw-r--r--llvm/lib/Target/AMDGPU/SISchedule.td15
1 files changed, 15 insertions, 0 deletions
diff --git a/llvm/lib/Target/AMDGPU/SISchedule.td b/llvm/lib/Target/AMDGPU/SISchedule.td
index ef8faff..8eecb1c 100644
--- a/llvm/lib/Target/AMDGPU/SISchedule.td
+++ b/llvm/lib/Target/AMDGPU/SISchedule.td
@@ -464,6 +464,20 @@ def : InstRW<[WriteCopy], (instrs COPY)>;
} // End SchedModel = GFX12SpeedModel
+// Check if any matrix inputs are interpreted as f8 in an f8f6f4
+// wmma instruction.
+def PredIsF8_WMMA_SCALE : SchedPredicate<[{
+ TII->getNamedOperand(*MI, AMDGPU::OpName::matrix_a_fmt)->getImm() <= AMDGPU::WMMA::MATRIX_FMT_BF8 ||
+ TII->getNamedOperand(*MI, AMDGPU::OpName::matrix_b_fmt)->getImm() <= AMDGPU::WMMA::MATRIX_FMT_BF8
+}]>;
+
+// If either matrix format is f8, the instruction takes 2x as many
+// cycles. TODO: This isn't reflected in MCA.
+def WriteWMMAScale_16X16X128_F8F6F4 : SchedWriteVariant<[
+ SchedVar<PredIsF8_WMMA_SCALE, [WriteXDL4PassWMMA]>,
+ SchedVar<NoSchedPred, [WriteXDL2PassWMMA]>
+]>;
+
multiclass GFX125xCommonWriteRes {
let ReleaseAtCycles = [8] in
@@ -495,6 +509,7 @@ def : InstRW<[WriteCopy], (instrs COPY)>;
def : InstRW<[WriteXDL2PassWMMA], (instregex "^V_[S]*WMMA[C]*_.*_(FP8|BF8|BF16|F16)_w32")>;
def : InstRW<[WriteXDL4PassWMMA], (instregex "^V_[S]*WMMA[C]*_.*_(IU8|IU4)_w32")>;
+def : InstRW<[WriteWMMAScale_16X16X128_F8F6F4], (instregex "^V_WMMA_.*_16X16X128_F8F6F4.*_w32")>;
def : InstRW<[Write4PassWMMA], (instregex "^V_WMMA_F32_16X16X4_F32_w32")>;
def : InstRW<[WriteXDL2PassWMMA], (instregex "^V_WMMA.*_F32_32X16X128_F4")>;
} // End GFX125xCommonWriteRes