aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOleksandr "Alex" Zinenko <zinenko@google.com>2023-12-19 14:18:16 +0100
committerGitHub <noreply@github.com>2023-12-19 14:18:16 +0100
commit9519e3ecbf6ed251c5ab7c74549fe86df1efc14c (patch)
treef1c49c69c3855a3e76692c4a9b67584a3ed5c4ef
parent133de6c1510b15108f729f0d981d45cb7e936b85 (diff)
downloadllvm-9519e3ecbf6ed251c5ab7c74549fe86df1efc14c.zip
llvm-9519e3ecbf6ed251c5ab7c74549fe86df1efc14c.tar.gz
llvm-9519e3ecbf6ed251c5ab7c74549fe86df1efc14c.tar.bz2
[mlir] support dialect attribute translation to LLVM IR (#75309)
Extend the `amendOperation` mechanism for translating dialect attributes attached to operations from another dialect when translating MLIR to LLVM IR. Previously, this mechanism would have no knowledge of the LLVM IR instructions created for the given operation, making it impossible for it to perform local modifications such as attaching operation-level metadata. Collect instructions inserted by the LLVM IR builder and pass them to `amendOperation`.
-rw-r--r--mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h10
-rw-r--r--mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h15
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp3
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp6
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp3
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp141
-rw-r--r--mlir/test/Target/LLVMIR/test.mlir24
-rw-r--r--mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp18
8 files changed, 194 insertions, 26 deletions
diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
index 0531c0ec..19991a6 100644
--- a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
+++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h
@@ -18,6 +18,7 @@
#include "mlir/Support/LogicalResult.h"
namespace llvm {
+class Instruction;
class IRBuilderBase;
} // namespace llvm
@@ -52,7 +53,8 @@ public:
/// translation results and amend the corresponding IR constructs. Does
/// nothing and succeeds by default.
virtual LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
return success();
}
@@ -78,11 +80,13 @@ public:
/// Acts on the given operation using the interface implemented by the dialect
/// of one of the operation's dialect attributes.
virtual LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
if (const LLVMTranslationDialectInterface *iface =
getInterfaceFor(attribute.getNameDialect())) {
- return iface->amendOperation(op, attribute, moduleTranslation);
+ return iface->amendOperation(op, instructions, attribute,
+ moduleTranslation);
}
return success();
}
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 3f797f4..d6b03ac 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -209,7 +209,10 @@ public:
/// PHI nodes are constructed for block arguments but are _not_ connected to
/// the predecessors that may not exist yet.
LogicalResult convertBlock(Block &bb, bool ignoreArguments,
- llvm::IRBuilderBase &builder);
+ llvm::IRBuilderBase &builder) {
+ return convertBlockImpl(bb, ignoreArguments, builder,
+ /*recordInsertions=*/false);
+ }
/// Gets the named metadata in the LLVM IR module being constructed, creating
/// it if it does not exist.
@@ -299,12 +302,16 @@ private:
~ModuleTranslation();
/// Converts individual components.
- LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder);
+ LogicalResult convertOperation(Operation &op, llvm::IRBuilderBase &builder,
+ bool recordInsertions = false);
LogicalResult convertFunctionSignatures();
LogicalResult convertFunctions();
LogicalResult convertComdats();
LogicalResult convertGlobals();
LogicalResult convertOneFunction(LLVMFuncOp func);
+ LogicalResult convertBlockImpl(Block &bb, bool ignoreArguments,
+ llvm::IRBuilderBase &builder,
+ bool recordInsertions);
/// Returns the LLVM metadata corresponding to the given mlir LLVM dialect
/// TBAATagAttr.
@@ -315,7 +322,9 @@ private:
LogicalResult createTBAAMetadata();
/// Translates dialect attributes attached to the given operation.
- LogicalResult convertDialectAttributes(Operation *op);
+ LogicalResult
+ convertDialectAttributes(Operation *op,
+ ArrayRef<llvm::Instruction *> instructions);
/// Translates parameter attributes and adds them to the returned AttrBuilder.
llvm::AttrBuilder convertParameterAttrs(DictionaryAttr paramAttrs);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 88e3a45..0d6bca5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -142,7 +142,8 @@ public:
/// Attaches module-level metadata for functions marked as kernels.
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
if (!func)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 088e7ae..6295846 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2572,14 +2572,16 @@ public:
/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
/// calls, or operation amendments
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final;
};
} // namespace
LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
- Operation *op, NamedAttribute attribute,
+ Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
attribute.getName())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
index 5ab7028..55a6285 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp
@@ -81,7 +81,8 @@ public:
/// Attaches module-level metadata for functions marked as kernels.
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final {
if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) {
auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 9f0e1f3..1722d74 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -59,6 +59,113 @@ using namespace mlir::LLVM::detail;
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
+namespace {
+/// A customized inserter for LLVM's IRBuilder that captures all LLVM IR
+/// instructions that are created for future reference.
+///
+/// This is intended to be used with the `CollectionScope` RAII object:
+///
+/// llvm::IRBuilder<..., InstructionCapturingInserter> builder;
+/// {
+/// InstructionCapturingInserter::CollectionScope scope(builder);
+/// // Call IRBuilder methods as usual.
+///
+/// // This will return a list of all instructions created by the builder,
+/// // in order of creation.
+/// builder.getInserter().getCapturedInstructions();
+/// }
+/// // This will return an empty list.
+/// builder.getInserter().getCapturedInstructions();
+///
+/// The capturing functionality is _disabled_ by default for performance
+/// consideration. It needs to be explicitly enabled, which is achieved by
+/// creating a `CollectionScope`.
+class InstructionCapturingInserter : public llvm::IRBuilderCallbackInserter {
+public:
+ /// Constructs the inserter.
+ InstructionCapturingInserter()
+ : llvm::IRBuilderCallbackInserter([this](llvm::Instruction *instruction) {
+ if (LLVM_LIKELY(enabled))
+ capturedInstructions.push_back(instruction);
+ }) {}
+
+ /// Returns the list of LLVM IR instructions captured since the last cleanup.
+ ArrayRef<llvm::Instruction *> getCapturedInstructions() const {
+ return capturedInstructions;
+ }
+
+ /// Clears the list of captured LLVM IR instructions.
+ void clearCapturedInstructions() { capturedInstructions.clear(); }
+
+ /// RAII object enabling the capture of created LLVM IR instructions.
+ class CollectionScope {
+ public:
+ /// Creates the scope for the given inserter.
+ CollectionScope(llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing);
+
+ /// Ends the scope.
+ ~CollectionScope();
+
+ ArrayRef<llvm::Instruction *> getCapturedInstructions() {
+ if (!inserter)
+ return {};
+ return inserter->getCapturedInstructions();
+ }
+
+ private:
+ /// Back reference to the inserter.
+ InstructionCapturingInserter *inserter = nullptr;
+
+ /// List of instructions in the inserter prior to this scope.
+ SmallVector<llvm::Instruction *> previouslyCollectedInstructions;
+
+ /// Whether the inserter was enabled prior to this scope.
+ bool wasEnabled;
+ };
+
+ /// Enable or disable the capturing mechanism.
+ void setEnabled(bool enabled = true) { this->enabled = enabled; }
+
+private:
+ /// List of captured instructions.
+ SmallVector<llvm::Instruction *> capturedInstructions;
+
+ /// Whether the collection is enabled.
+ bool enabled = false;
+};
+
+using CapturingIRBuilder =
+ llvm::IRBuilder<llvm::ConstantFolder, InstructionCapturingInserter>;
+} // namespace
+
+InstructionCapturingInserter::CollectionScope::CollectionScope(
+ llvm::IRBuilderBase &irBuilder, bool isBuilderCapturing) {
+
+ if (!isBuilderCapturing)
+ return;
+
+ auto &capturingIRBuilder = static_cast<CapturingIRBuilder &>(irBuilder);
+ inserter = &capturingIRBuilder.getInserter();
+ wasEnabled = inserter->enabled;
+ if (wasEnabled)
+ previouslyCollectedInstructions.swap(inserter->capturedInstructions);
+ inserter->setEnabled(true);
+}
+
+InstructionCapturingInserter::CollectionScope::~CollectionScope() {
+ if (!inserter)
+ return;
+
+ previouslyCollectedInstructions.swap(inserter->capturedInstructions);
+ // If collection was enabled (likely in another, surrounding scope), keep
+ // the instructions collected in this scope.
+ if (wasEnabled) {
+ llvm::append_range(inserter->capturedInstructions,
+ previouslyCollectedInstructions);
+ }
+ inserter->setEnabled(wasEnabled);
+}
+
/// Translates the given data layout spec attribute to the LLVM IR data layout.
/// Only integer, float, pointer and endianness entries are currently supported.
static FailureOr<llvm::DataLayout>
@@ -631,9 +738,9 @@ llvm::CallInst *mlir::LLVM::detail::createIntrinsicCall(
/// Given a single MLIR operation, create the corresponding LLVM IR operation
/// using the `builder`.
-LogicalResult
-ModuleTranslation::convertOperation(Operation &op,
- llvm::IRBuilderBase &builder) {
+LogicalResult ModuleTranslation::convertOperation(Operation &op,
+ llvm::IRBuilderBase &builder,
+ bool recordInsertions) {
const LLVMTranslationDialectInterface *opIface = iface.getInterfaceFor(&op);
if (!opIface)
return op.emitError("cannot be converted to LLVM IR: missing "
@@ -641,11 +748,13 @@ ModuleTranslation::convertOperation(Operation &op,
"dialect for op: ")
<< op.getName();
+ InstructionCapturingInserter::CollectionScope scope(builder,
+ recordInsertions);
if (failed(opIface->convertOperation(&op, builder, *this)))
return op.emitError("LLVM Translation failed for operation: ")
<< op.getName();
- return convertDialectAttributes(&op);
+ return convertDialectAttributes(&op, scope.getCapturedInstructions());
}
/// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes
@@ -655,8 +764,10 @@ ModuleTranslation::convertOperation(Operation &op,
/// been created for `bb` and included in the block mapping. Inserts new
/// instructions at the end of the block and leaves `builder` in a state
/// suitable for further insertion into the end of the block.
-LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
- llvm::IRBuilderBase &builder) {
+LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
+ bool ignoreArguments,
+ llvm::IRBuilderBase &builder,
+ bool recordInsertions) {
builder.SetInsertPoint(lookupBlock(&bb));
auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
@@ -687,7 +798,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
builder.SetCurrentDebugLocation(
debugTranslation->translateLoc(op.getLoc(), subprogram));
- if (failed(convertOperation(op, builder)))
+ if (failed(convertOperation(op, builder, recordInsertions)))
return failure();
// Set the branch weight metadata on the translated instruction.
@@ -844,7 +955,7 @@ LogicalResult ModuleTranslation::convertGlobals() {
}
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>())
- if (failed(convertDialectAttributes(op)))
+ if (failed(convertDialectAttributes(op, {})))
return failure();
// Finally, update the compile units their respective sets of global variables
@@ -997,8 +1108,9 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
// converted before uses.
auto blocks = getTopologicallySortedBlocks(func.getBody());
for (Block *bb : blocks) {
- llvm::IRBuilder<> builder(llvmContext);
- if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
+ CapturingIRBuilder builder(llvmContext);
+ if (failed(convertBlockImpl(*bb, bb->isEntryBlock(), builder,
+ /*recordInsertions=*/true)))
return failure();
}
@@ -1007,12 +1119,13 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
detail::connectPHINodes(func.getBody(), *this);
// Finally, convert dialect attributes attached to the function.
- return convertDialectAttributes(func);
+ return convertDialectAttributes(func, {});
}
-LogicalResult ModuleTranslation::convertDialectAttributes(Operation *op) {
+LogicalResult ModuleTranslation::convertDialectAttributes(
+ Operation *op, ArrayRef<llvm::Instruction *> instructions) {
for (NamedAttribute attribute : op->getDialectAttrs())
- if (failed(iface.amendOperation(op, attribute, *this)))
+ if (failed(iface.amendOperation(op, instructions, attribute, *this)))
return failure();
return success();
}
@@ -1134,7 +1247,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
// Do not convert external functions, but do process dialect attributes
// attached to them.
if (function.isExternal()) {
- if (failed(convertDialectAttributes(function)))
+ if (failed(convertDialectAttributes(function, {})))
return failure();
continue;
}
diff --git a/mlir/test/Target/LLVMIR/test.mlir b/mlir/test/Target/LLVMIR/test.mlir
index f48738f..0ab1b72 100644
--- a/mlir/test/Target/LLVMIR/test.mlir
+++ b/mlir/test/Target/LLVMIR/test.mlir
@@ -16,3 +16,27 @@ module {
module attributes {test.discardable_mod_attr = true} {}
// CHECK: @sym_from_attr = external global i32
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation
+llvm.func @dialect_attr_translation() {
+ // CHECK: ret void, !annotation ![[MD_ID:.+]]
+ llvm.return {test.add_annotation}
+}
+// CHECK: ![[MD_ID]] = !{!"annotation_from_test"}
+
+// -----
+
+// CHECK-LABEL: @dialect_attr_translation_multi
+llvm.func @dialect_attr_translation_multi(%a: i64, %b: i64, %c: i64) -> i64 {
+ // CHECK: add {{.*}}, !annotation ![[MD_ID_ADD:.+]]
+ // CHECK: mul {{.*}}, !annotation ![[MD_ID_MUL:.+]]
+ // CHECK: ret {{.*}}, !annotation ![[MD_ID_RET:.+]]
+ %ab = llvm.add %a, %b {test.add_annotation = "add"} : i64
+ %r = llvm.mul %ab, %c {test.add_annotation = "mul"} : i64
+ llvm.return {test.add_annotation = "ret"} %r : i64
+}
+// CHECK-DAG: ![[MD_ID_ADD]] = !{!"annotation_from_test: add"}
+// CHECK-DAG: ![[MD_ID_MUL]] = !{!"annotation_from_test: mul"}
+// CHECK-DAG: ![[MD_ID_RET]] = !{!"annotation_from_test: ret"}
diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
index 7110d999..2dd99c6 100644
--- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
+++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp
@@ -32,7 +32,8 @@ public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
LogicalResult
- amendOperation(Operation *op, NamedAttribute attribute,
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const final;
LogicalResult
@@ -43,7 +44,8 @@ public:
} // namespace
LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
- Operation *op, NamedAttribute attribute,
+ Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
LLVM::ModuleTranslation &moduleTranslation) const {
return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
attribute.getName())
@@ -72,6 +74,18 @@ LogicalResult TestDialectLLVMIRTranslationInterface::amendOperation(
return success();
})
+ .Case("test.add_annotation",
+ [&](Attribute attr) {
+ for (llvm::Instruction *instruction : instructions) {
+ if (auto strAttr = dyn_cast<StringAttr>(attr)) {
+ instruction->addAnnotationMetadata("annotation_from_test: " +
+ strAttr.getValue().str());
+ } else {
+ instruction->addAnnotationMetadata("annotation_from_test");
+ }
+ }
+ return success();
+ })
.Default([](Attribute) {
// Skip other discardable dialect attributes.
return success();