diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/ModuleImport.cpp')
-rw-r--r-- | mlir/lib/Target/LLVMIR/ModuleImport.cpp | 40 |
1 files changed, 39 insertions, 1 deletions
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 3f80002..d73c84a 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -519,6 +519,39 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) { debugIntrinsics.insert(intrinsic); } +static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule, + llvm::MDTuple *mdTuple) { + auto getFunctionSymbol = [&](const llvm::MDOperand &funcMDO) { + auto *f = cast<llvm::ValueAsMetadata>(funcMDO); + auto *llvmFn = cast<llvm::Function>(f->getValue()->stripPointerCasts()); + return FlatSymbolRefAttr::get(mlirModule->getContext(), llvmFn->getName()); + }; + + // Each tuple element becomes one ModuleFlagCGProfileEntryAttr. + SmallVector<Attribute> cgProfile; + for (unsigned i = 0; i < mdTuple->getNumOperands(); i++) { + const llvm::MDOperand &mdo = mdTuple->getOperand(i); + auto *cgEntry = cast<llvm::MDNode>(mdo); + llvm::Constant *llvmConstant = + cast<llvm::ConstantAsMetadata>(cgEntry->getOperand(2))->getValue(); + uint64_t count = cast<llvm::ConstantInt>(llvmConstant)->getZExtValue(); + cgProfile.push_back(ModuleFlagCGProfileEntryAttr::get( + mlirModule->getContext(), getFunctionSymbol(cgEntry->getOperand(0)), + getFunctionSymbol(cgEntry->getOperand(1)), count)); + } + return ArrayAttr::get(mlirModule->getContext(), cgProfile); +} + +/// Invoke specific handlers for each known module flag value, returns nullptr +/// if the key is unknown or unimplemented. +static Attribute convertModuleFlagValueFromMDTuple(ModuleOp mlirModule, + StringRef key, + llvm::MDTuple *mdTuple) { + if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) + return convertCGProfileModuleFlagValue(mlirModule, mdTuple); + return nullptr; +} + LogicalResult ModuleImport::convertModuleFlagsMetadata() { SmallVector<llvm::Module::ModuleFlagEntry> llvmModuleFlags; llvmModule->getModuleFlagsMetadata(llvmModuleFlags); @@ -530,7 +563,12 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() { valAttr = builder.getI32IntegerAttr(constInt->getZExtValue()); } else if (auto *mdString = dyn_cast<llvm::MDString>(val)) { valAttr = builder.getStringAttr(mdString->getString()); - } else { + } else if (auto *mdTuple = dyn_cast<llvm::MDTuple>(val)) { + valAttr = convertModuleFlagValueFromMDTuple(mlirModule, key->getString(), + mdTuple); + } + + if (!valAttr) { emitWarning(mlirModule.getLoc()) << "unsupported module flag value: " << diagMD(val, llvmModule.get()); continue; |