//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// /// \file /// This file implements the IR2Vec embedding generation tool. /// /// This tool provides two main functionalities: /// /// 1. Triplet Generation Mode (--mode=triplets): /// Generates triplets (opcode, type, operands) for vocabulary training. /// Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt /// /// 2. Embedding Generation Mode (--mode=embeddings): /// Generates IR2Vec embeddings using a trained vocabulary. /// Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json /// --level=func input.bc -o embeddings.txt Levels: --level=inst /// (instructions), --level=bb (basic blocks), --level=func (functions) /// (See IR2Vec.cpp for more embedding generation options) /// //===----------------------------------------------------------------------===// #include "llvm/Analysis/IR2Vec.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassInstrumentation.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "ir2vec" namespace llvm { namespace ir2vec { static cl::OptionCategory IR2VecToolCategory("IR2Vec Tool Options"); static cl::opt InputFilename(cl::Positional, cl::desc(""), cl::init("-"), cl::cat(IR2VecToolCategory)); static cl::opt OutputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), cl::init("-"), cl::cat(IR2VecToolCategory)); enum ToolMode { TripletMode, // Generate triplets for vocabulary training EmbeddingMode // Generate embeddings using trained vocabulary }; static cl::opt Mode("mode", cl::desc("Tool operation mode:"), cl::values(clEnumValN(TripletMode, "triplets", "Generate triplets for vocabulary training"), clEnumValN(EmbeddingMode, "embeddings", "Generate embeddings using trained vocabulary")), cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory)); static cl::opt FunctionName("function", cl::desc("Process specific function only"), cl::value_desc("name"), cl::Optional, cl::init(""), cl::cat(IR2VecToolCategory)); enum EmbeddingLevel { InstructionLevel, // Generate instruction-level embeddings BasicBlockLevel, // Generate basic block-level embeddings FunctionLevel // Generate function-level embeddings }; static cl::opt Level("level", cl::desc("Embedding generation level (for embedding mode):"), cl::values(clEnumValN(InstructionLevel, "inst", "Generate instruction-level embeddings"), clEnumValN(BasicBlockLevel, "bb", "Generate basic block-level embeddings"), clEnumValN(FunctionLevel, "func", "Generate function-level embeddings")), cl::init(FunctionLevel), cl::cat(IR2VecToolCategory)); namespace { /// Helper class for collecting IR triplets and generating embeddings class IR2VecTool { private: Module &M; ModuleAnalysisManager MAM; const Vocabulary *Vocab = nullptr; public: explicit IR2VecTool(Module &M) : M(M) {} /// Initialize the IR2Vec vocabulary analysis bool initializeVocabulary() { // Register and run the IR2Vec vocabulary analysis // The vocabulary file path is specified via --ir2vec-vocab-path global // option MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); MAM.registerPass([&] { return IR2VecVocabAnalysis(); }); Vocab = &MAM.getResult(M); return Vocab->isValid(); } /// Generate triplets for the entire module void generateTriplets(raw_ostream &OS) const { for (const Function &F : M) generateTriplets(F, OS); } /// Generate triplets for a single function void generateTriplets(const Function &F, raw_ostream &OS) const { if (F.isDeclaration()) return; std::string LocalOutput; raw_string_ostream LocalOS(LocalOutput); for (const BasicBlock &BB : F) traverseBasicBlock(BB, LocalOS); LocalOS.flush(); OS << LocalOutput; } /// Generate embeddings for the entire module void generateEmbeddings(raw_ostream &OS) const { if (!Vocab->isValid()) { OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n"; return; } for (const Function &F : M) generateEmbeddings(F, OS); } /// Generate embeddings for a single function void generateEmbeddings(const Function &F, raw_ostream &OS) const { if (F.isDeclaration()) { OS << "Function " << F.getName() << " is a declaration, skipping.\n"; return; } // Create embedder for this function assert(Vocab->isValid() && "Vocabulary is not valid"); auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab); if (!Emb) { OS << "Error: Failed to create embedder for function " << F.getName() << "\n"; return; } OS << "Function: " << F.getName() << "\n"; // Generate embeddings based on the specified level switch (Level) { case FunctionLevel: { Emb->getFunctionVector().print(OS); break; } case BasicBlockLevel: { const auto &BBVecMap = Emb->getBBVecMap(); for (const BasicBlock &BB : F) { auto It = BBVecMap.find(&BB); if (It != BBVecMap.end()) { OS << BB.getName() << ":"; It->second.print(OS); } } break; } case InstructionLevel: { const auto &InstMap = Emb->getInstVecMap(); for (const BasicBlock &BB : F) { for (const Instruction &I : BB) { auto It = InstMap.find(&I); if (It != InstMap.end()) { I.print(OS); It->second.print(OS); } } } break; } } } private: /// Process a single basic block for triplet generation void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const { // Consider only non-debug and non-pseudo instructions for (const auto &I : BB.instructionsWithoutDebug()) { StringRef OpcStr = Vocabulary::getVocabKeyForOpcode(I.getOpcode()); StringRef TypeStr = Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID()); OS << '\n' << OpcStr << ' ' << TypeStr << ' '; LLVM_DEBUG({ I.print(dbgs()); dbgs() << "\n"; I.getType()->print(dbgs()); dbgs() << " Type\n"; }); for (const Use &U : I.operands()) OS << Vocabulary::getVocabKeyForOperandKind( Vocabulary::getOperandKind(U.get())) << ' '; } } }; Error processModule(Module &M, raw_ostream &OS) { IR2VecTool Tool(M); if (Mode == EmbeddingMode) { // Initialize vocabulary for embedding generation // Note: Requires --ir2vec-vocab-path option to be set if (!Tool.initializeVocabulary()) return createStringError( errc::invalid_argument, "Failed to initialize IR2Vec vocabulary. " "Make sure to specify --ir2vec-vocab-path for embedding mode."); if (!FunctionName.empty()) { // Process single function if (const Function *F = M.getFunction(FunctionName)) Tool.generateEmbeddings(*F, OS); else return createStringError(errc::invalid_argument, "Function '%s' not found", FunctionName.c_str()); } else { // Process all functions Tool.generateEmbeddings(OS); } } else { // Triplet generation mode - no vocabulary needed if (!FunctionName.empty()) // Process single function if (const Function *F = M.getFunction(FunctionName)) Tool.generateTriplets(*F, OS); else return createStringError(errc::invalid_argument, "Function '%s' not found", FunctionName.c_str()); else // Process all functions Tool.generateTriplets(OS); } return Error::success(); } } // namespace } // namespace ir2vec } // namespace llvm int main(int argc, char **argv) { using namespace llvm; using namespace llvm::ir2vec; InitLLVM X(argc, argv); cl::HideUnrelatedOptions(IR2VecToolCategory); cl::ParseCommandLineOptions( argc, argv, "IR2Vec - Embedding Generation Tool\n" "Generates embeddings for a given LLVM IR and " "supports triplet generation for vocabulary " "training and embedding generation.\n\n" "See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more " "information.\n"); // Validate command line options if (Mode == TripletMode && Level.getNumOccurrences() > 0) errs() << "Warning: --level option is ignored in triplet mode\n"; // Parse the input LLVM IR file or stdin SMDiagnostic Err; LLVMContext Context; std::unique_ptr M = parseIRFile(InputFilename, Err, Context); if (!M) { Err.print(argv[0], errs()); return 1; } std::error_code EC; raw_fd_ostream OS(OutputFilename, EC); if (EC) { errs() << "Error opening output file: " << EC.message() << "\n"; return 1; } if (Error Err = processModule(*M, OS)) { handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) { errs() << "Error: " << EIB.message() << "\n"; }); return 1; } return 0; }