aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuray Ozen <guray.ozen@gmail.com>2025-05-15 15:12:09 +0200
committerGitHub <noreply@github.com>2025-05-15 15:12:09 +0200
commitd08b176edc82b8fbd417e0ead7b77abc5f45fce2 (patch)
tree2c0f18ae8abcb3b9587f0ef0257e6d9240ad3cba
parent6fc031291955d5d3c69f7df9b9f7460c473d114a (diff)
downloadllvm-d08b176edc82b8fbd417e0ead7b77abc5f45fce2.zip
llvm-d08b176edc82b8fbd417e0ead7b77abc5f45fce2.tar.gz
llvm-d08b176edc82b8fbd417e0ead7b77abc5f45fce2.tar.bz2
[MLIR][NVVM] Add `inline_ptx` op (#139923)
This op allows using PTX directly within the NVVM dialect, while greatly simplifying llvm.inline_asm generation. **Example 1: Read-only Parameters** Sets `"l,r"` automatically. ``` nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32 // Lowers to: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> () ``` **Example 2: Read-only and Write-only Parameters** Sets `=f,f"` automatically. `=` is set because there is store. ``` %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32 // Lowers to: %0 = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32 ``` **Example 3: Predicate Usage** Now `@$2` is set automatically for predication. ``` nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1 // Lowers to: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3 : (!llvm.ptr, i32, i1) -> () ``` --------- Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td70
-rw-r--r--mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir25
2 files changed, 95 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 654aff7..a8e7dcb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -237,6 +237,76 @@ foreach index = !range(0, 32) in {
}
//===----------------------------------------------------------------------===//
+// Inline PTX op definition
+//===----------------------------------------------------------------------===//
+
+def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>
+{
+ let summary = "Inline PTX Op";
+ let description = [{This op allows using PTX directly within the NVVM
+ dialect, while greatly simplifying llvm.inline_asm generation. It
+ automatically handles register size selection and sets the correct
+ read/write access for each operand. The operation leverages the
+ `BasicPtxBuilderInterface` to abstract away low-level details of
+ PTX assembly formatting.
+
+ The `predicate` attribute is used to specify a predicate for the
+ PTX instruction.
+
+ Example 1: Read-only Parameters
+ ```mlir
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
+
+ // Lowers to:
+ llvm.inline_asm has_side_effects asm_dialect = att
+ "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
+ ```
+
+ Example 2: Read-only and Write-only Parameters
+ ```mlir
+ %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+
+ // Lowers to:
+ %0 = llvm.inline_asm has_side_effects asm_dialect = att
+ "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32
+ ```
+
+ Example 3: Predicate Usage
+ ```mlir
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count),
+ predicate = %pred : !llvm.ptr, i32, i1
+
+ // Lowers to:
+ llvm.inline_asm has_side_effects asm_dialect = att
+ "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3
+ : (!llvm.ptr, i32, i1) -> ()
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
+ StrAttr:$ptxCode,
+ PtxPredicate:$predicate);
+
+ let results = (outs Variadic<AnyType>:$writeOnlyArgs);
+
+ let assemblyFormat = [{
+ $ptxCode `(` $readOnlyArgs `)`
+ (`,` `predicate` `=` $predicate^)? attr-dict
+ `:` type(operands)
+ (`->` type($writeOnlyArgs)^)?
+ }];
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ StringRef ptxInstStr = getPtxCode();
+ return std::string(ptxInstStr.data());
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// NVVM approximate op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index c7a6eca..8d720ce 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -680,3 +680,28 @@ llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
llvm.return
}
+
+
+// -----
+
+llvm.func @init_mbarrier(
+ %barrier_gen : !llvm.ptr,
+ %barrier : !llvm.ptr<3>,
+ %count : i32,
+ %pred : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r"
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
+ llvm.return
+}
+// -----
+
+llvm.func @ex2(%input : f32, %pred : i1) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
+ %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
+ %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred : f32, i1 -> f32
+ llvm.return
+}