diff options
Diffstat (limited to 'llvm/tools/llvm-ir2vec')
-rw-r--r-- | llvm/tools/llvm-ir2vec/CMakeLists.txt | 15 | ||||
-rw-r--r-- | llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 305 |
2 files changed, 281 insertions, 39 deletions
diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt index a4cf969..2bb6686 100644 --- a/llvm/tools/llvm-ir2vec/CMakeLists.txt +++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt @@ -1,10 +1,25 @@ set(LLVM_LINK_COMPONENTS + # Core LLVM components for IR processing Analysis Core IRReader Support + + # Machine IR components (for -mode=mir) + CodeGen + MIRParser + + # Target initialization (required for MIR parsing) + AllTargetsAsmParsers + AllTargetsCodeGens + AllTargetsDescs + AllTargetsInfos + TargetParser ) add_llvm_tool(llvm-ir2vec llvm-ir2vec.cpp + + DEPENDS + intrinsics_gen ) diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index 1031932..a723d37 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -1,4 +1,4 @@ -//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===// +//===- llvm-ir2vec.cpp - IR2Vec/MIR2Vec Embedding Generation Tool --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,9 +7,13 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file implements the IR2Vec embedding generation tool. +/// This file implements the IR2Vec and MIR2Vec embedding generation tool. /// -/// This tool provides three main subcommands: +/// This tool supports two modes: +/// - LLVM IR mode (-mode=llvm): Process LLVM IR +/// - Machine IR mode (-mode=mir): Process Machine IR +/// +/// Available subcommands: /// /// 1. Triplet Generation (triplets): /// Generates numeric triplets (head, tail, relation) for vocabulary @@ -23,16 +27,24 @@ /// Usage: llvm-ir2vec entities input.bc -o entity2id.txt /// /// 3. Embedding Generation (embeddings): -/// Generates IR2Vec embeddings using a trained vocabulary. -/// Usage: llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json -/// --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt -/// Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware +/// Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary. +/// +/// For LLVM IR: +/// llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json +/// --ir2vec-kind=<kind> --level=<level> input.bc -o embeddings.txt +/// Kind: --ir2vec-kind=symbolic (default), --ir2vec-kind=flow-aware +/// +/// For Machine IR: +/// llvm-ir2vec embeddings -mode=mir --mir2vec-vocab-path=vocab.json +/// --level=<level> input.mir -o embeddings.txt +/// /// Levels: --level=inst (instructions), --level=bb (basic blocks), -/// --level=func (functions) (See IR2Vec.cpp for more embedding generation -/// options) +/// --level=func (functions) (See IR2Vec.cpp/MIR2Vec.cpp for more embedding +/// generation options) /// //===----------------------------------------------------------------------===// +#include "llvm/ADT/ArrayRef.h" #include "llvm/Analysis/IR2Vec.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -50,10 +62,38 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/CodeGen/MIR2Vec.h" +#include "llvm/CodeGen/MIRParser/MIRParser.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/WithColor.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Host.h" + #define DEBUG_TYPE "ir2vec" namespace llvm { -namespace ir2vec { + +static const char *ToolName = "llvm-ir2vec"; + +// Common option category for options shared between IR2Vec and MIR2Vec +static cl::OptionCategory CommonCategory("Common Options", + "Options applicable to both IR2Vec " + "and MIR2Vec modes"); + +enum IRKind { + LLVMIR = 0, ///< LLVM IR + MIR ///< Machine IR +}; + +static cl::opt<IRKind> + IRMode("mode", cl::desc("Tool operation mode:"), + cl::values(clEnumValN(LLVMIR, "llvm", "Process LLVM IR"), + clEnumValN(MIR, "mir", "Process Machine IR")), + cl::init(LLVMIR), cl::cat(CommonCategory)); // Subcommands static cl::SubCommand @@ -70,18 +110,18 @@ static cl::opt<std::string> InputFilename(cl::Positional, cl::desc("<input bitcode file or '-' for stdin>"), cl::init("-"), cl::sub(TripletsSubCmd), - cl::sub(EmbeddingsSubCmd), cl::cat(ir2vec::IR2VecCategory)); + cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-"), - cl::cat(ir2vec::IR2VecCategory)); + cl::cat(CommonCategory)); // Embedding-specific options static cl::opt<std::string> FunctionName("function", cl::desc("Process specific function only"), cl::value_desc("name"), cl::Optional, cl::init(""), - cl::sub(EmbeddingsSubCmd), cl::cat(ir2vec::IR2VecCategory)); + cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); enum EmbeddingLevel { InstructionLevel, // Generate instruction-level embeddings @@ -98,9 +138,9 @@ static cl::opt<EmbeddingLevel> clEnumValN(FunctionLevel, "func", "Generate function-level embeddings")), cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd), - cl::cat(ir2vec::IR2VecCategory)); + cl::cat(CommonCategory)); -namespace { +namespace ir2vec { /// Relation types for triplet generation enum RelationType { @@ -220,7 +260,8 @@ public: /// Generate embeddings for the entire module void generateEmbeddings(raw_ostream &OS) const { if (!Vocab->isValid()) { - OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n"; + WithColor::error(errs(), ToolName) + << "Vocabulary is not valid. IR2VecTool not initialized.\n"; return; } @@ -239,8 +280,8 @@ public: assert(Vocab->isValid() && "Vocabulary is not valid"); auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab); if (!Emb) { - OS << "Error: Failed to create embedder for function " << F.getName() - << "\n"; + WithColor::error(errs(), ToolName) + << "Failed to create embedder for function " << F.getName() << "\n"; return; } @@ -300,20 +341,119 @@ Error processModule(Module &M, raw_ostream &OS) { } return Error::success(); } -} // namespace } // namespace ir2vec + +namespace mir2vec { + +/// Helper class for MIR2Vec embedding generation +class MIR2VecTool { +private: + MachineModuleInfo &MMI; + std::unique_ptr<MIRVocabulary> Vocab; + +public: + explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {} + + /// Initialize MIR2Vec vocabulary + bool initializeVocabulary(const Module &M) { + MIR2VecVocabProvider Provider(MMI); + auto VocabOrErr = Provider.getVocabulary(M); + if (!VocabOrErr) { + WithColor::error(errs(), ToolName) + << "Failed to load MIR2Vec vocabulary - " + << toString(VocabOrErr.takeError()) << "\n"; + return false; + } + Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr)); + return true; + } + + /// Generate embeddings for all machine functions in the module + void generateEmbeddings(const Module &M, raw_ostream &OS) const { + if (!Vocab) { + WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n"; + return; + } + + for (const Function &F : M) { + if (F.isDeclaration()) + continue; + + MachineFunction *MF = MMI.getMachineFunction(F); + if (!MF) { + WithColor::warning(errs(), ToolName) + << "No MachineFunction for " << F.getName() << "\n"; + continue; + } + + generateEmbeddings(*MF, OS); + } + } + + /// Generate embeddings for a specific machine function + void generateEmbeddings(MachineFunction &MF, raw_ostream &OS) const { + if (!Vocab) { + WithColor::error(errs(), ToolName) << "Vocabulary not initialized.\n"; + return; + } + + auto Emb = MIREmbedder::create(MIR2VecKind::Symbolic, MF, *Vocab); + if (!Emb) { + WithColor::error(errs(), ToolName) + << "Failed to create embedder for " << MF.getName() << "\n"; + return; + } + + OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n"; + + // Generate embeddings based on the specified level + switch (Level) { + case FunctionLevel: { + OS << "Function vector: "; + Emb->getMFunctionVector().print(OS); + break; + } + case BasicBlockLevel: { + OS << "Basic block vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + OS << "MBB " << MBB.getName() << ": "; + Emb->getMBBVector(MBB).print(OS); + } + break; + } + case InstructionLevel: { + OS << "Instruction vectors:\n"; + for (const MachineBasicBlock &MBB : MF) { + for (const MachineInstr &MI : MBB) { + OS << MI << " -> "; + Emb->getMInstVector(MI).print(OS); + } + } + break; + } + } + } + + const MIRVocabulary *getVocabulary() const { return Vocab.get(); } +}; + +} // namespace mir2vec + } // namespace llvm int main(int argc, char **argv) { using namespace llvm; using namespace llvm::ir2vec; + using namespace llvm::mir2vec; InitLLVM X(argc, argv); - cl::HideUnrelatedOptions(ir2vec::IR2VecCategory); + // Show Common, IR2Vec and MIR2Vec option categories + cl::HideUnrelatedOptions(ArrayRef<const cl::OptionCategory *>{ + &CommonCategory, &ir2vec::IR2VecCategory, &mir2vec::MIR2VecCategory}); cl::ParseCommandLineOptions( argc, argv, - "IR2Vec - Embedding Generation Tool\n" - "Generates embeddings for a given LLVM IR and " + "IR2Vec/MIR2Vec - Embedding Generation Tool\n" + "Generates embeddings for a given LLVM IR or MIR and " "supports triplet generation for vocabulary " "training and embedding generation.\n\n" "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more " @@ -322,30 +462,117 @@ int main(int argc, char **argv) { std::error_code EC; raw_fd_ostream OS(OutputFilename, EC); if (EC) { - errs() << "Error opening output file: " << EC.message() << "\n"; + WithColor::error(errs(), ToolName) + << "opening output file: " << EC.message() << "\n"; return 1; } - if (EntitiesSubCmd) { - // Just dump entity mappings without processing any IR - IR2VecTool::generateEntityMappings(OS); + if (IRMode == IRKind::LLVMIR) { + if (EntitiesSubCmd) { + // Just dump entity mappings without processing any IR + IR2VecTool::generateEntityMappings(OS); + return 0; + } + + // Parse the input LLVM IR file or stdin + SMDiagnostic Err; + LLVMContext Context; + std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context); + if (!M) { + Err.print(ToolName, errs()); + return 1; + } + + if (Error Err = processModule(*M, OS)) { + handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) { + WithColor::error(errs(), ToolName) << EIB.message() << "\n"; + }); + return 1; + } return 0; } + if (IRMode == IRKind::MIR) { + // Initialize targets for Machine IR processing + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmParsers(); + InitializeAllAsmPrinters(); + static codegen::RegisterCodeGenFlags CGF; + + // Parse MIR input file + SMDiagnostic Err; + LLVMContext Context; + std::unique_ptr<TargetMachine> TM; + + auto MIR = createMIRParserFromFile(InputFilename, Err, Context); + if (!MIR) { + Err.print(ToolName, errs()); + return 1; + } - // Parse the input LLVM IR file or stdin - SMDiagnostic Err; - LLVMContext Context; - std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context); - if (!M) { - Err.print(argv[0], errs()); - return 1; - } + auto SetDataLayout = [&](StringRef DataLayoutTargetTriple, + StringRef OldDLStr) -> std::optional<std::string> { + std::string IRTargetTriple = DataLayoutTargetTriple.str(); + Triple TheTriple = Triple(IRTargetTriple); + if (TheTriple.getTriple().empty()) + TheTriple.setTriple(sys::getDefaultTargetTriple()); + auto TMOrErr = codegen::createTargetMachineForTriple(TheTriple.str()); + if (!TMOrErr) { + Err.print(ToolName, errs()); + exit(1); + } + TM = std::move(*TMOrErr); + return TM->createDataLayout().getStringRepresentation(); + }; + + std::unique_ptr<Module> M = MIR->parseIRModule(SetDataLayout); + if (!M) { + Err.print(ToolName, errs()); + return 1; + } - if (Error Err = processModule(*M, OS)) { - handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) { - errs() << "Error: " << EIB.message() << "\n"; - }); - return 1; + // Parse machine functions + auto MMI = std::make_unique<MachineModuleInfo>(TM.get()); + if (!MMI || MIR->parseMachineFunctions(*M, *MMI)) { + Err.print(ToolName, errs()); + return 1; + } + + // Create MIR2Vec tool and initialize vocabulary + MIR2VecTool Tool(*MMI); + if (!Tool.initializeVocabulary(*M)) + return 1; + + LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n" + << "Vocabulary dimension: " + << Tool.getVocabulary()->getDimension() << "\n" + << "Vocabulary size: " + << Tool.getVocabulary()->getCanonicalSize() << "\n"); + + // Generate embeddings based on subcommand + if (!FunctionName.empty()) { + // Process single function + Function *F = M->getFunction(FunctionName); + if (!F) { + WithColor::error(errs(), ToolName) + << "Function '" << FunctionName << "' not found\n"; + return 1; + } + + MachineFunction *MF = MMI->getMachineFunction(*F); + if (!MF) { + WithColor::error(errs(), ToolName) + << "No MachineFunction for " << FunctionName << "\n"; + return 1; + } + + Tool.generateEmbeddings(*MF, OS); + } else { + // Process all functions + Tool.generateEmbeddings(*M, OS); + } + + return 0; } return 0; |