diff options
author | arthurqiu <arthurq@nvidia.com> | 2024-11-21 01:31:01 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-20 18:31:01 +0100 |
commit | 81055ff070e128bff78c8fa2d8ffe4c92ae692a6 (patch) | |
tree | e18635ab16808bee7350b42e53294f9c87241837 | |
parent | 0733f384142b02558b80b3e9a4633dc4d202a14b (diff) | |
download | llvm-81055ff070e128bff78c8fa2d8ffe4c92ae692a6.zip llvm-81055ff070e128bff78c8fa2d8ffe4c92ae692a6.tar.gz llvm-81055ff070e128bff78c8fa2d8ffe4c92ae692a6.tar.bz2 |
[mlir][nvvm] Add attributes for cluster dimension PTX directives (#116973)
PTX programming models provides cluster dimension directives, which are
leveraged by the downstream `ptxas` compiler. See
https://docs.nvidia.com/cuda/nvvm-ir-spec/#supported-properties and
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#cluster-dimension-directives
This PR introduces the cluster dimension directives to MLIR's NVVM
dialect as listed below:
```
cluster_dim_{x,y,z} -> exact number of CTAs per cluster
cluster_max_blocks -> max number of CTAs per cluster
```
-rw-r--r-- | mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 12 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 12 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 14 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/nvvmir.mlir | 22 |
4 files changed, 56 insertions, 4 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6b462de..296a3c3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -53,6 +53,18 @@ def NVVM_Dialect : Dialect { static StringRef getReqntidYName() { return "reqntidy"; } static StringRef getReqntidZName() { return "reqntidz"; } + /// Get the name of the attribute used to annotate exact CTAs required + /// per cluster for kernel functions. + static StringRef getClusterDimAttrName() { return "nvvm.cluster_dim"; } + /// Get the name of the metadata names for each dimension + static StringRef getClusterDimXName() { return "cluster_dim_x"; } + static StringRef getClusterDimYName() { return "cluster_dim_y"; } + static StringRef getClusterDimZName() { return "cluster_dim_z"; } + + /// Get the name of the attribute used to annotate maximum number of + /// CTAs per cluster for kernel functions. + static StringRef getClusterMaxBlocksAttrName() { return "nvvm.cluster_max_blocks"; } + /// Get the name of the attribute used to annotate min CTA required /// per SM for kernel functions. static StringRef getMinctasmAttrName() { return "nvvm.minctasm"; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index d28194d..ca04af0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1126,18 +1126,22 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, << "' attribute attached to unexpected op"; } } - // If maxntid and reqntid exist, it must be an array with max 3 dim + // If maxntid / reqntid / cluster_dim exist, it must be an array with max 3 + // dim if (attrName == NVVMDialect::getMaxntidAttrName() || - attrName == NVVMDialect::getReqntidAttrName()) { + attrName == NVVMDialect::getReqntidAttrName() || + attrName == NVVMDialect::getClusterDimAttrName()) { auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); if (!values || values.empty() || values.size() > 3) return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; } - // If minctasm and maxnreg exist, it must be an integer attribute + // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer + // attribute if (attrName == NVVMDialect::getMinctasmAttrName() || - attrName == NVVMDialect::getMaxnregAttrName()) { + attrName == NVVMDialect::getMaxnregAttrName() || + attrName == NVVMDialect::getClusterMaxBlocksAttrName()) { if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) return op->emitError() << "'" << attrName << "' attribute must be integer constant"; diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 9cc6620..cf58bc5 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -215,6 +215,20 @@ public: if (values.size() > 2) generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName()); } else if (attribute.getName() == + NVVM::NVVMDialect::getClusterDimAttrName()) { + if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) + return failure(); + auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); + generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName()); + if (values.size() > 1) + generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName()); + if (values.size() > 2) + generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName()); + } else if (attribute.getName() == + NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) { + auto value = dyn_cast<IntegerAttr>(attribute.getValue()); + generateMetadata(value.getInt(), "cluster_max_blocks"); + } else if (attribute.getName() == NVVM::NVVMDialect::getMinctasmAttrName()) { auto value = dyn_cast<IntegerAttr>(attribute.getValue()); generateMetadata(value.getInt(), "minctasm"); diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index e5ea03f..a4a3581 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -586,6 +586,28 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 2 // CHECK: {ptr @kernel_func, !"reqntidz", i32 32} // ----- +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_dim = array<i32: 3, 5, 7>} { + llvm.return +} + +// CHECK: !nvvm.annotations = +// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} +// CHECK: {ptr @kernel_func, !"cluster_dim_x", i32 3} +// CHECK: {ptr @kernel_func, !"cluster_dim_y", i32 5} +// CHECK: {ptr @kernel_func, !"cluster_dim_z", i32 7} +// CHECK: {ptr @kernel_func, !"kernel", i32 1} +// ----- + +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.cluster_max_blocks = 8} { + llvm.return +} + +// CHECK: !nvvm.annotations = +// CHECK-NOT: {ptr @nvvm_special_regs, !"kernel", i32 1} +// CHECK: {ptr @kernel_func, !"cluster_max_blocks", i32 8} +// CHECK: {ptr @kernel_func, !"kernel", i32 1} +// ----- + llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.minctasm = 16} { llvm.return } |