diff options
Diffstat (limited to 'llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp')
| -rw-r--r-- | llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 244 |
1 files changed, 215 insertions, 29 deletions
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index a723d37..7402782 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -19,12 +19,22 @@ /// Generates numeric triplets (head, tail, relation) for vocabulary /// training. Output format: MAX_RELATION=N header followed by /// head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,... -/// Usage: llvm-ir2vec triplets input.bc -o train2id.txt +/// +/// For LLVM IR: +/// llvm-ir2vec triplets input.bc -o train2id.txt +/// +/// For Machine IR: +/// llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt /// /// 2. Entity Mappings (entities): /// Generates entity mappings for vocabulary training. /// Output format: <total_entities> header followed by entity\tid lines. -/// Usage: llvm-ir2vec entities input.bc -o entity2id.txt +/// +/// For LLVM IR: +/// llvm-ir2vec entities input.bc -o entity2id.txt +/// +/// For Machine IR: +/// llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt /// /// 3. Embedding Generation (embeddings): /// Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary. @@ -67,6 +77,8 @@ #include "llvm/CodeGen/MIRParser/MIRParser.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/WithColor.h" @@ -106,11 +118,10 @@ static cl::SubCommand "Generate embeddings using trained vocabulary"); // Common options -static cl::opt<std::string> - InputFilename(cl::Positional, - cl::desc("<input bitcode file or '-' for stdin>"), - cl::init("-"), cl::sub(TripletsSubCmd), - cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); +static cl::opt<std::string> InputFilename( + cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"), + cl::init("-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd), + cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), @@ -345,6 +356,12 @@ Error processModule(Module &M, raw_ostream &OS) { namespace mir2vec { +/// Relation types for MIR2Vec triplet generation +enum MIRRelationType { + MIRNextRelation = 0, ///< Sequential instruction relationship + MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N) +}; + /// Helper class for MIR2Vec embedding generation class MIR2VecTool { private: @@ -354,7 +371,7 @@ private: public: explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {} - /// Initialize MIR2Vec vocabulary + /// Initialize MIR2Vec vocabulary from file (for embeddings generation) bool initializeVocabulary(const Module &M) { MIR2VecVocabProvider Provider(MMI); auto VocabOrErr = Provider.getVocabulary(M); @@ -368,6 +385,146 @@ public: return true; } + /// Initialize vocabulary with layout information only. + /// This creates a minimal vocabulary with correct layout but no actual + /// embeddings. Sufficient for generating training data and entity mappings. + /// + /// Note: Requires target-specific information from the first machine function + /// to determine the vocabulary layout (number of opcodes, register classes). + /// + /// FIXME: Use --target option to get target info directly, avoiding the need + /// to parse machine functions for pre-training operations. + bool initializeVocabularyForLayout(const Module &M) { + for (const Function &F : M) { + if (F.isDeclaration()) + continue; + + MachineFunction *MF = MMI.getMachineFunction(F); + if (!MF) + continue; + + const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo(); + const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo(); + const MachineRegisterInfo &MRI = MF->getRegInfo(); + + auto VocabOrErr = + MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1); + if (!VocabOrErr) { + WithColor::error(errs(), ToolName) + << "Failed to create dummy vocabulary - " + << toString(VocabOrErr.takeError()) << "\n"; + return false; + } + Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr)); + return true; + } + + WithColor::error(errs(), ToolName) + << "No machine functions found to initialize vocabulary\n"; + return false; + } + + /// Generate triplets for the module + /// Output format: MAX_RELATION=N header followed by relationships + void generateTriplets(const Module &M, raw_ostream &OS) const { + unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID + std::string Relationships; + raw_string_ostream RelOS(Relationships); + + for (const Function &F : M) { + if (F.isDeclaration()) + continue; + + MachineFunction *MF = MMI.getMachineFunction(F); + if (!MF) { + WithColor::warning(errs(), ToolName) + << "No MachineFunction for " << F.getName() << "\n"; + continue; + } + + unsigned FuncMaxRelation = generateTriplets(*MF, RelOS); + MaxRelation = std::max(MaxRelation, FuncMaxRelation); + } + + RelOS.flush(); + + // Write metadata header followed by relationships + OS << "MAX_RELATION=" << MaxRelation << '\n'; + OS << Relationships; + } + + /// Generate triplets for a single machine function + /// Returns the maximum relation ID used in this function + unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const { + unsigned MaxRelation = MIRNextRelation; + unsigned PrevOpcode = 0; + bool HasPrevOpcode = false; + + if (!Vocab) { + WithColor::error(errs(), ToolName) + << "MIR Vocabulary must be initialized for triplet generation.\n"; + return MaxRelation; + } + + for (const MachineBasicBlock &MBB : MF) { + for (const MachineInstr &MI : MBB) { + // Skip debug instructions + if (MI.isDebugInstr()) + continue; + + // Get opcode entity ID + unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode()); + + // Add "Next" relationship with previous instruction + if (HasPrevOpcode) { + OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation + << '\n'; + LLVM_DEBUG(dbgs() + << Vocab->getStringKey(PrevOpcode) << '\t' + << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n"); + } + + // Add "Arg" relationships for operands + unsigned ArgIndex = 0; + for (const MachineOperand &MO : MI.operands()) { + auto OperandID = Vocab->getEntityIDForMachineOperand(MO); + unsigned RelationID = MIRArgRelation + ArgIndex; + OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n'; + LLVM_DEBUG({ + std::string OperandStr = Vocab->getStringKey(OperandID); + dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr + << '\t' << "Arg" << ArgIndex << '\n'; + }); + + ++ArgIndex; + } + + // Update MaxRelation if there were operands + if (ArgIndex > 0) + MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1); + + PrevOpcode = OpcodeID; + HasPrevOpcode = true; + } + } + + return MaxRelation; + } + + /// Generate entity mappings with vocabulary + void generateEntityMappings(raw_ostream &OS) const { + if (!Vocab) { + WithColor::error(errs(), ToolName) + << "Vocabulary must be initialized for entity mappings.\n"; + return; + } + + const unsigned EntityCount = Vocab->getCanonicalSize(); + OS << EntityCount << "\n"; + for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID) + OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n'; + } + /// Generate embeddings for all machine functions in the module void generateEmbeddings(const Module &M, raw_ostream &OS) const { if (!Vocab) { @@ -538,38 +695,67 @@ int main(int argc, char **argv) { return 1; } - // Create MIR2Vec tool and initialize vocabulary + // Create MIR2Vec tool MIR2VecTool Tool(*MMI); - if (!Tool.initializeVocabulary(*M)) - return 1; + // Initialize vocabulary. For triplet/entity generation, only layout is + // needed For embedding generation, the full vocabulary is needed. + // + // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires + // target-specific information for generating the vocabulary layout. So, we + // always initialize the vocabulary in this case. + if (TripletsSubCmd || EntitiesSubCmd) { + if (!Tool.initializeVocabularyForLayout(*M)) { + WithColor::error(errs(), ToolName) + << "Failed to initialize MIR2Vec vocabulary for layout.\n"; + return 1; + } + } else { + if (!Tool.initializeVocabulary(*M)) { + WithColor::error(errs(), ToolName) + << "Failed to initialize MIR2Vec vocabulary.\n"; + return 1; + } + } + assert(Tool.getVocabulary() && + "MIR2Vec vocabulary should be initialized at this point"); LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n" << "Vocabulary dimension: " << Tool.getVocabulary()->getDimension() << "\n" << "Vocabulary size: " << Tool.getVocabulary()->getCanonicalSize() << "\n"); - // Generate embeddings based on subcommand - if (!FunctionName.empty()) { - // Process single function - Function *F = M->getFunction(FunctionName); - if (!F) { - WithColor::error(errs(), ToolName) - << "Function '" << FunctionName << "' not found\n"; - return 1; - } + // Handle subcommands + if (TripletsSubCmd) { + Tool.generateTriplets(*M, OS); + } else if (EntitiesSubCmd) { + Tool.generateEntityMappings(OS); + } else if (EmbeddingsSubCmd) { + if (!FunctionName.empty()) { + // Process single function + Function *F = M->getFunction(FunctionName); + if (!F) { + WithColor::error(errs(), ToolName) + << "Function '" << FunctionName << "' not found\n"; + return 1; + } - MachineFunction *MF = MMI->getMachineFunction(*F); - if (!MF) { - WithColor::error(errs(), ToolName) - << "No MachineFunction for " << FunctionName << "\n"; - return 1; - } + MachineFunction *MF = MMI->getMachineFunction(*F); + if (!MF) { + WithColor::error(errs(), ToolName) + << "No MachineFunction for " << FunctionName << "\n"; + return 1; + } - Tool.generateEmbeddings(*MF, OS); + Tool.generateEmbeddings(*MF, OS); + } else { + // Process all functions + Tool.generateEmbeddings(*M, OS); + } } else { - // Process all functions - Tool.generateEmbeddings(*M, OS); + WithColor::error(errs(), ToolName) + << "Please specify a subcommand: triplets, entities, or embeddings\n"; + return 1; } return 0; |
