aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
author: Durgadoss R <durgadossr@nvidia.com>, 2024-01-22 13:09:30 +0530
committer: GitHub <noreply@github.com>, 2024-01-22 08:39:30 +0100
commit: aa4547fcc8eeb9bf4f3cf48cc926f62544e58767 (patch)
tree: e40ff4a16f3c9bd4b7d787aab377ee7be68ffa03 /mlir
parent: 12c241b3654800ab708607dbc1998975c893fc14 (diff)
download: llvm-aa4547fcc8eeb9bf4f3cf48cc926f62544e58767.zip
download: llvm-aa4547fcc8eeb9bf4f3cf48cc926f62544e58767.tar.gz
download: llvm-aa4547fcc8eeb9bf4f3cf48cc926f62544e58767.tar.bz2
[MLIR][NVVM] Update cp.async.bulk Ops to use intrinsics (#78900)
This patch updates the cp.async.bulk.{commit/wait}_group Ops to use NVVM intrinsics.

* Doc updated for the commit_group Op.
* Tests are added to verify the lowering to the intrinsics.

While we are here, fix the FileCheck directive on the 'nvvm.setmaxregister' test.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
Diffstat (limited to 'mlir')
-rw-r--r-- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 30
-rw-r--r-- mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 18
-rw-r--r-- mlir/test/Target/LLVMIR/nvvmir.mlir | 24
3 files changed, 47 insertions, 25 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index b1bd3a9..37e525a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1591,19 +1591,26 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
-def NVVM_CpAsyncBulkCommitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.commit.group">,
+def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("cp.async.bulk.commit_group;"); }
+ let description = [{
+ This Op commits all prior initiated but uncommitted cp.async.bulk
+ instructions into a cp.async.bulk-group.
+
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group)
+ }];
+
+ string llvmBuilder = [{
+ createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_bulk_commit_group);
}];
}
-def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">,
+def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
- OptionalAttr<UnitAttr>:$read)>
-{
+ OptionalAttr<UnitAttr>:$read)> {
let assemblyFormat = "$group attr-dict";
let description = [{
Op waits for completion of the most recent bulk async-groups.
@@ -1620,15 +1627,14 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_PTXBuilder_Op<"cp.async.bulk.wait_group">
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
}];
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- auto ptx = std::string("cp.async.bulk.wait_group");
- if(getRead()) ptx += ".read";
- ptx += " %0;"; return ptx; }
+ string llvmBuilder = [{
+ auto intId = op.getRead() ?
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
+ llvm::Intrinsic::nvvm_cp_async_bulk_wait_group;
+ createIntrinsicCall(builder, intId, builder.getInt32($group));
}];
}
-
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 9c7c27c..0ac7331 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -638,23 +638,19 @@ func.func @set_max_register() {
// -----
-func.func @cp_bulk_commit() {
- //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.commit_group;"
+func.func @cp_async_bulk_commit() {
+ // CHECK: nvvm.cp.async.bulk.commit.group
nvvm.cp.async.bulk.commit.group
func.return
}
// -----
-func.func @cp_bulk_wait_group() {
- // CHECK: %[[S0:.+]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S0]] : (i32) -> ()
- // CHECK: %[[S1:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group $0;", "n" %[[S1]] : (i32) -> ()
- // CHECK: %[[S2:.+]] = llvm.mlir.constant(5 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S2]] : (i32) -> ()
- // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.wait_group.read $0;", "n" %[[S3]] : (i32) -> ()
+func.func @cp_async_bulk_wait_group() {
+ // CHECK: nvvm.cp.async.bulk.wait_group 1
+ // CHECK: nvvm.cp.async.bulk.wait_group 0
+ // CHECK: nvvm.cp.async.bulk.wait_group 5 {read}
+ // CHECK: nvvm.cp.async.bulk.wait_group 0 {read}
nvvm.cp.async.bulk.wait_group 1
nvvm.cp.async.bulk.wait_group 0
nvvm.cp.async.bulk.wait_group 5 {read}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 8c5e352..49f9426 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,13 +398,33 @@ llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
// CHECK-LABEL: @llvm_nvvm_setmaxregister
llvm.func @llvm_nvvm_setmaxregister() {
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.inc.sync.aligned.u32(i32 256)
nvvm.setmaxregister increase 256
- // CHECK-LLVM: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
+ // CHECK: call void @llvm.nvvm.setmaxnreg.dec.sync.aligned.u32(i32 24)
nvvm.setmaxregister decrease 24
llvm.return
}
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_commit_group
+llvm.func @llvm_nvvm_cp_async_bulk_commit_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.commit.group()
+ nvvm.cp.async.bulk.commit.group
+ llvm.return
+}
+
+// CHECK-LABEL: @llvm_nvvm_cp_async_bulk_wait_group
+llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 0)
+ nvvm.cp.async.bulk.wait_group 0
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group(i32 3)
+ nvvm.cp.async.bulk.wait_group 3
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 0)
+ nvvm.cp.async.bulk.wait_group 0 {read}
+ // CHECK: call void @llvm.nvvm.cp.async.bulk.wait.group.read(i32 3)
+ nvvm.cp.async.bulk.wait_group 3 {read}
+ llvm.return
+}
+
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
// CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})