author     weiwei chen <weiwei.chen@modular.com>  2024-06-24 22:15:58 -0400
committer  GitHub <noreply@github.com>            2024-06-24 22:15:58 -0400
commit     b0e9b00ce7d623175c5e60e82afe24e7f8a200be
tree       0661e08ecd34a9c2adb6b0b48adb24f57e3d49ca
parent     7ea63b9db4198688873036f3b0b81f9124076f7a
[NVPTX] Make nvptx mma instructions convergent. (#96521)
We are running into the NVPTX backend generating wrong code for an input of the form:

```
%0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
if laneid == 0:
  ret
else:
  store %0
```

The backend reorders the instruction (as an effect of the `MachineSink` pass) to:

```
if laneid == 0:
  ret
else:
  %0 = llvm.nvvm.mma.m?n?k?.row.col.??? (...)
  store %0
```

This is incorrect because `mma` is a warp instruction that needs all threads to sync before performing the operation, rather than being guarded by a specific thread id. In terms of warp-level sync it should behave like the shuffle instruction `shfl`, which is marked `isConvergent = true`. Apply `isConvergent = true` to the `mma` instructions as well.
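For illustration, here is a minimal CUDA sketch (a hypothetical kernel, not part of this commit; the kernel name, tile shape, and types are assumptions) of the kind of source that lowers to this pattern: all 32 lanes of the warp must execute the matrix op, while only lane 0 consumes the result, so the compiler must not sink the op past the lane-id check.

```
// Hypothetical example kernel (not from this patch); requires sm_70+ and mma.h.
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

__global__ void mma_then_lane0_store(const half *a, const half *b, float *out) {
  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> fa;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> fb;
  wmma::fragment<wmma::accumulator, 16, 16, 16, float> fc;

  wmma::fill_fragment(fc, 0.0f);
  wmma::load_matrix_sync(fa, a, 16);   // every lane of the warp participates
  wmma::load_matrix_sync(fb, b, 16);
  wmma::mma_sync(fc, fa, fb, fc);      // convergent warp-level matrix multiply

  // Only lane 0 uses the result. If the compiler sank mma_sync into this branch,
  // a single lane would execute the warp-level op and the result would be wrong,
  // which is exactly the reordering this patch prevents.
  if ((threadIdx.x & 31) == 0)
    out[0] = fc.x[0];
}
```

Marking the instructions `isConvergent = true` tells the optimizer they must not be made control-dependent on additional conditions, so `MachineSink` leaves them above the lane-id branch.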
-rw-r--r--  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td                    4
-rw-r--r--  llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll  26
2 files changed, 30 insertions, 0 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c050905..c81dfa6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6725,6 +6725,7 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> WMMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6746,6 +6747,7 @@ defset list<WMMA_INSTR> WMMAs = {
} // layout_b
} // layout_a
} // defset
+}
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
@@ -6775,6 +6777,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragC.regstring # ";";
}
+let isConvergent = true in {
defset list<WMMA_INSTR> MMAs = {
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
@@ -6794,6 +6797,7 @@ defset list<WMMA_INSTR> MMAs = {
} // layout_b
} // layout_a
} // defset
+}
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
diff --git a/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
new file mode 100644
index 0000000..a88bc4d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/mma-no-sink-after-laneid-check.ll
@@ -0,0 +1,26 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx81 | FileCheck %s
+
+declare { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32, i32, i32, float, float, float, float) #1
+
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.laneid() #0
+
+; COM: llvm.nvvm.mma should not be sunk into the next block and reordered to be after the laneid check.
+; CHECK-LABEL: no_reorder_mma_and_laneid_check
+define dso_local void @no_reorder_mma_and_laneid_check(ptr %arg, ptr %arg1) {
+bb:
+ ; CHECK: mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
+ ; CHECK: laneid
+ %i = tail call { float, float, float, float } @llvm.nvvm.mma.m16n8k4.row.col.tf32(i32 10, i32 10, i32 8, float 0.0, float 0.0, float 0.0, float 0.0)
+ %i3 = tail call i32 @llvm.nvvm.read.ptx.sreg.laneid()
+ %i4 = icmp eq i32 %i3, 0
+ br i1 %i4, label %bb5, label %bb8
+
+bb5: ; preds = %bb
+ %i6 = extractvalue { float, float, float, float } %i, 0
+ %i7 = getelementptr float, ptr %arg, i64 0
+ store float %i6, ptr %i7, align 4
+ br label %bb8
+
+bb8: ; preds = %bb5, %bb
+ ret void
+}