author | Guray Ozen <guray.ozen@gmail.com> | 2023-11-10 16:53:43 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-10 16:53:43 +0100 |
commit | 51916f0c924f2ed4e970dd043a14d70b6b1d3f71 (patch) | |
tree | 9519555b47f418b1c6a0bab4d375c4f078d70dd4 | |
parent | a00caad6bf318a7497d477b434464ca75ecb41fc (diff) | |
[mlir] Add sm_90a GEMM test 128x128x128 (F32 += F16 * F16) (#69913)
This PR adds a test that performs a 128x128x128 GEMM (F32 += F16 * F16).
It uses `sm_90a` features in the NVGPU dialect.
The simplified algorithm is as follows:
**Prologue**
```
mgroup = mbarriers.init x 2
tma.load ... shmem_buffer_lhs<0 x 128 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
mbarrier.expect_tx 32768
tma.load ... shmem_buffer_lhs<1 x 128 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
mbarrier.expect_tx 32768
```
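The `32768` passed to `mbarrier.expect_tx` appears consistent with the number of bytes that the three TMA loads of one pipeline stage write into shared memory. The following is a minimal sketch of that arithmetic, assuming 2-byte F16 elements and that the transaction count is expressed in bytes; `tile_bytes` is a hypothetical helper, not part of the test.

```python
F16_BYTES = 2  # assumption: one F16 element occupies 2 bytes


def tile_bytes(rows, cols, elem_bytes=F16_BYTES):
    """Bytes occupied by a rows x cols tile (hypothetical helper)."""
    return rows * cols * elem_bytes


# One pipeline stage, as listed in the prologue:
lhs_bytes = tile_bytes(128, 64)     # one 128 x 64 LHS load
rhs_bytes = 2 * tile_bytes(64, 64)  # two 64 x 64 RHS loads
expect_tx = lhs_bytes + rhs_bytes

print(lhs_bytes, rhs_bytes, expect_tx)
```

Under these assumptions both operands contribute 16384 bytes, which adds up to the 32768 used in the prologue.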
**Mainloop**
```
matrixD =
for(i = 0;...2) {
  mbarrier.try_wait [i]
  lhs = shmem_buffer_lhs<pipe x 128 x 64>
  rhs = shmem_buffer_rhs<pipe x 64 x 128>
  yield nvgpu.warpgroup.mma (lhs, rhs)
  // Expanded : nvgpu.warpgroup.mma [128][128]+=[128][64]*[64][128]
  //   wgmma.m64n128k16(A[0:64][0:16] * B[0:16][0:128])
  //   wgmma.m64n128k16(A[0:64][16:32] * B[16:32][0:128])
  //   wgmma.m64n128k16(A[0:64][32:48] * B[32:48][0:128])
  //   wgmma.m64n128k16(A[0:64][48:64] * B[48:64][0:128])
  //   wgmma.m64n128k16(A[64:128][0:16] * B[0:16][0:128])
  //   wgmma.m64n128k16(A[64:128][16:32] * B[16:32][0:128])
  //   wgmma.m64n128k16(A[64:128][32:48] * B[32:48][0:128])
  //   wgmma.m64n128k16(A[64:128][48:64] * B[48:64][0:128])
}
```
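The expansion listed in the mainloop can be modeled on the CPU to check that the eight `m64n128k16` steps cover the `[128][64] * [64][128]` product exactly once. This is only a sketch in plain Python, not the NVGPU lowering: `wgmma_tiles`, `tiled_mma`, and `naive_mma` are hypothetical names, and the numeric check uses small integer matrices as a stand-in for the real F16 tiles.

```python
def wgmma_tiles(m=128, k=64, tile_m=64, tile_k=16):
    """Return (m0, m1, k0, k1) per step: M blocks outer, K blocks inner."""
    return [(m0, m0 + tile_m, k0, k0 + tile_k)
            for m0 in range(0, m, tile_m)
            for k0 in range(0, k, tile_k)]


def tiled_mma(a, b, d, tile_m, tile_k):
    """Accumulate d += a * b one tile at a time (lists of lists)."""
    n = len(b[0])
    for m0, m1, k0, k1 in wgmma_tiles(len(a), len(b), tile_m, tile_k):
        for i in range(m0, m1):
            for j in range(n):
                d[i][j] += sum(a[i][kk] * b[kk][j] for kk in range(k0, k1))
    return d


def naive_mma(a, b):
    """Reference product without tiling."""
    return [[sum(a[i][kk] * b[kk][j] for kk in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


# Small integer stand-in: 4 x 8 times 8 x 6, tiled as 2 x 6 x 4 steps.
a = [[i + kk for kk in range(8)] for i in range(4)]
b = [[kk - j for j in range(6)] for kk in range(8)]
d = tiled_mma(a, b, [[0] * 6 for _ in range(4)], tile_m=2, tile_k=4)
assert d == naive_mma(a, b)

# Index ranges for the sizes used in the mainloop.
for m0, m1, k0, k1 in wgmma_tiles():
    print(f"A[{m0}:{m1}][{k0}:{k1}] * B[{k0}:{k1}][0:128]")
```

With the default arguments the loop prints the same eight index ranges, in the same order, as the expanded comment in the mainloop.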
**Epilogue**
```
// reg -> shmem
warpgroup.mma.store matrixD, shmem
// shmem -> global memory
parallel-for(i=0;...128)
  parallel-for(j=0;...128)
    store shmem, globalmem
```