aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorarthurqiu <arthurq@nvidia.com>2024-11-21 01:31:01 +0800
committerGitHub <noreply@github.com>2024-11-20 18:31:01 +0100
commit81055ff070e128bff78c8fa2d8ffe4c92ae692a6 (patch)
treee18635ab16808bee7350b42e53294f9c87241837
parent0733f384142b02558b80b3e9a4633dc4d202a14b (diff)
downloadllvm-81055ff070e128bff78c8fa2d8ffe4c92ae692a6.zip
llvm-81055ff070e128bff78c8fa2d8ffe4c92ae692a6.tar.gz
llvm-81055ff070e128bff78c8fa2d8ffe4c92ae692a6.tar.bz2
[mlir][nvvm] Add attributes for cluster dimension PTX directives (#116973)
PTX programming models provides cluster dimension directives, which are leveraged by the downstream `ptxas` compiler. See https://docs.nvidia.com/cuda/nvvm-ir-spec/#supported-properties and https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives This PR introduces the cluster dimension directives to MLIR's NVVM dialect as listed below: ``` cluster_dim_{x,y,z} -> exact number of CTAs per cluster cluster_max_blocks -> max number of CTAs per cluster ```
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td12
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp12
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp14
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir.mlir22
4 files changed, 56 insertions, 4 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6b462de..296a3c3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -53,6 +53,18 @@ def NVVM_Dialect : Dialect {
static StringRef getReqntidYName() { return "reqntidy"; }
static StringRef getReqntidZName() { return "reqntidz"; }
+ /// Get the name of the attribute used to annotate exact CTAs required
+ /// per cluster for kernel functions.
+ static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; }
+ /// Get the name of the metadata names for each dimension
+ static StringRef getClusterDimXName() { return "cluster_dim_x"; }
+ static StringRef getClusterDimYName() { return "cluster_dim_y"; }
+ static StringRef getClusterDimZName() { return "cluster_dim_z"; }
+
+ /// Get the name of the attribute used to annotate maximum number of
+ /// CTAs per cluster for kernel functions.
+ static StringRef getClusterMaxBlocksAttrName() { return "nvvm.cluster_max_blocks"; }
+
/// Get the name of the attribute used to annotate min CTA required
/// per SM for kernel functions.
static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d28194d..ca04af0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1126,18 +1126,22 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
<< "' attribute attached to unexpected op";
}
}
- // If maxntid and reqntid exist, it must be an array with max 3 dim
+ // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3
+ // dim
if (attrName == NVVMDialect::getMaxntidAttrName() ||
- attrName == NVVMDialect::getReqntidAttrName()) {
+ attrName == NVVMDialect::getReqntidAttrName() ||
+ attrName == NVVMDialect::getClusterDimAttrName()) {
auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
if (!values || values.empty() || values.size() > 3)
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
}
- // If minctasm and maxnreg exist, it must be an integer attribute
+ // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer
+ // attribute
if (attrName == NVVMDialect::getMinctasmAttrName() ||
- attrName == NVVMDialect::getMaxnregAttrName()) {
+ attrName == NVVMDialect::getMaxnregAttrName() ||
+ attrName == NVVMDialect::getClusterMaxBlocksAttrName()) {
if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
return op->emitError()
<< "'" << attrName << "' attribute must be integer constant";
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 9cc6620..cf58bc5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -215,6 +215,20 @@ public:
if (values.size() > 2)
generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
} else if (attribute.getName() ==
+ NVVM::NVVMDialect::getClusterDimAttrName()) {
+ if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
+ return failure();
+ auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
+ generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName());
+ if (values.size() > 1)
+ generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName());
+ if (values.size() > 2)
+ generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName());
+ } else if (attribute.getName() ==
+ NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
+ auto value = dyn_cast<IntegerAttr>(attribute.getValue());
+ generateMetadata(value.getInt(), "cluster_max_blocks");
+ } else if (attribute.getName() ==
NVVM::NVVMDialect::getMinctasmAttrName()) {
auto value = dyn_cast<IntegerAttr>(attribute.getValue());
generateMetadata(value.getInt(), "minctasm");
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index e5ea03f..a4a3581 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -586,6 +586,28 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 2
// CHECK: {ptr @kernel_func, !"reqntidz", i32 32}
// -----
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_dim = array<i32: 3, 5, 7>} {
+ llvm.return
+}
+
+// CHECK: !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"cluster_dim_x", i32 3}
+// CHECK: {ptr @kernel_func, !"cluster_dim_y", i32 5}
+// CHECK: {ptr @kernel_func, !"cluster_dim_z", i32 7}
+// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// -----
+
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_max_blocks = 8} {
+ llvm.return
+}
+
+// CHECK: !nvvm.annotations =
+// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1}
+// CHECK: {ptr @kernel_func, !"cluster_max_blocks", i32 8}
+// CHECK: {ptr @kernel_func, !"kernel", i32 1}
+// -----
+
llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} {
llvm.return
}