diff options
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp | 11 | ||||
-rw-r--r-- | mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir | 14 |
2 files changed, 25 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 17371ec..6d54bb6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -180,6 +181,15 @@ void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) { // ROCDLDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// +namespace { +struct ROCDLInlinerInterface final : DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { + return true; + } +}; +} // namespace + // TODO: This should be the llvm.rocdl dialect once this is supported. void ROCDLDialect::initialize() { addOperations< @@ -194,6 +204,7 @@ void ROCDLDialect::initialize() { // Support unknown operations because not all ROCDL operations are registered. allowUnknownOperations(); + addInterfaces<ROCDLInlinerInterface>(); declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>(); } diff --git a/mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir b/mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir new file mode 100644 index 0000000..7fd97ef --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/inlining-rocdl.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s --inline | FileCheck %s + +llvm.func @threadidx() -> i32 { + %tid = rocdl.workitem.id.x : i32 + llvm.return %tid : i32 +} + +// CHECK-LABEL: func @caller +llvm.func @caller() -> i32 { + // CHECK-NOT: llvm.call @threadidx + // CHECK: rocdl.workitem.id.x + %z = llvm.call @threadidx() : () -> (i32) + llvm.return %z : i32 +} |