aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp 11
-rw-r--r-- mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir 14
2 files changed, 25 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
index 17371ec..6d54bb6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
@@ -180,6 +181,15 @@ void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) {
// ROCDLDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
+namespace { // file-local: the interface type is only referenced from this file
+struct ROCDLInlinerInterface final : DialectInlinerInterface {
+ using DialectInlinerInterface::DialectInlinerInterface; // reuse the base-class constructors
+ bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
+ return true; // unconditional: none of the (unnamed) arguments is inspected
+ }
+};
+} // namespace
+
// TODO: This should be the llvm.rocdl dialect once this is supported.
void ROCDLDialect::initialize() {
addOperations<
@@ -194,6 +204,7 @@ void ROCDLDialect::initialize() {
// Support unknown operations because not all ROCDL operations are registered.
allowUnknownOperations();
+ addInterfaces<ROCDLInlinerInterface>();
declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
}
diff --git a/mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir b/mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir
new file mode 100644
index 0000000..7fd97ef
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s --inline | FileCheck %s
+
+llvm.func @threadidx() -> i32 { // callee: a single ROCDL op whose result is returned
+ %tid = rocdl.workitem.id.x : i32
+ llvm.return %tid : i32
+}
+
+// CHECK-LABEL: func @caller
+llvm.func @caller() -> i32 {
+ // CHECK-NOT: llvm.call @threadidx
+ // CHECK: rocdl.workitem.id.x
+ %z = llvm.call @threadidx() : () -> (i32) // expected to be replaced by the callee body under --inline
+ llvm.return %z : i32
+}