aboutsummaryrefslogtreecommitdiff
path: root/llvm/tools/llvm-ir2vec
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/tools/llvm-ir2vec')
-rw-r--r--llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp244
1 files changed, 215 insertions, 29 deletions
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index a723d37..7402782 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -19,12 +19,22 @@
/// Generates numeric triplets (head, tail, relation) for vocabulary
/// training. Output format: MAX_RELATION=N header followed by
/// head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
-/// Usage: llvm-ir2vec triplets input.bc -o train2id.txt
+///
+/// For LLVM IR:
+/// llvm-ir2vec triplets input.bc -o train2id.txt
+///
+/// For Machine IR:
+/// llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt
///
/// 2. Entity Mappings (entities):
/// Generates entity mappings for vocabulary training.
/// Output format: <total_entities> header followed by entity\tid lines.
-/// Usage: llvm-ir2vec entities input.bc -o entity2id.txt
+///
+/// For LLVM IR:
+/// llvm-ir2vec entities input.bc -o entity2id.txt
+///
+/// For Machine IR:
+/// llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt
///
/// 3. Embedding Generation (embeddings):
/// Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
@@ -67,6 +77,8 @@
#include "llvm/CodeGen/MIRParser/MIRParser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/WithColor.h"
@@ -106,11 +118,10 @@ static cl::SubCommand
"Generate embeddings using trained vocabulary");
// Common options
-static cl::opt<std::string>
- InputFilename(cl::Positional,
- cl::desc("<input bitcode file or '-' for stdin>"),
- cl::init("-"), cl::sub(TripletsSubCmd),
- cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
+static cl::opt<std::string> InputFilename(
+ cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"),
+ cl::init("-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd),
+ cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
@@ -345,6 +356,12 @@ Error processModule(Module &M, raw_ostream &OS) {
namespace mir2vec {
+/// Relation types (triplet relation IDs) for MIR2Vec triplet generation
+enum MIRRelationType {
+ MIRNextRelation = 0, ///< Sequential instruction relationship
+ MIRArgRelation = 1 ///< Operand relationship base (MIRArgRelation + ArgIndex)
+};
+
/// Helper class for MIR2Vec embedding generation
class MIR2VecTool {
private:
@@ -354,7 +371,7 @@ private:
public:
explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
- /// Initialize MIR2Vec vocabulary
+ /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
bool initializeVocabulary(const Module &M) {
MIR2VecVocabProvider Provider(MMI);
auto VocabOrErr = Provider.getVocabulary(M);
@@ -368,6 +385,146 @@ public:
return true;
}
+ /// Initialize vocabulary with layout information only.
+ /// This creates a minimal vocabulary with correct layout but no actual
+ /// embeddings. Sufficient for generating training data and entity mappings.
+ /// Returns true on success; on failure reports an error and returns false.
+ /// Note: Requires target-specific information from the first machine function
+ /// to determine the vocabulary layout (number of opcodes, register classes).
+ ///
+ /// FIXME: Use --target option to get target info directly, avoiding the need
+ /// to parse machine functions for pre-training operations.
+ bool initializeVocabularyForLayout(const Module &M) {
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue; // Only function definitions are considered
+
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF)
+ continue; // Silently skipped here; keep looking for a usable function
+
+ const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
+ const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
+ const MachineRegisterInfo &MRI = MF->getRegInfo();
+
+ auto VocabOrErr = // NOTE(review): trailing 1 is presumably the dimension
+ MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+ if (!VocabOrErr) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to create dummy vocabulary - "
+ << toString(VocabOrErr.takeError()) << "\n";
+ return false;
+ }
+ Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+ return true; // The first function with a MachineFunction is sufficient
+ }
+
+ WithColor::error(errs(), ToolName) // Reached only if no function qualified
+ << "No machine functions found to initialize vocabulary\n";
+ return false;
+ }
+
+ /// Generate triplets for every defined function in the module
+ /// Output format: MAX_RELATION=N header, then head\ttail\trelation lines
+ void generateTriplets(const Module &M, raw_ostream &OS) const {
+ unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
+ std::string Relationships; // Buffered: the header needs the final maximum
+ raw_string_ostream RelOS(Relationships);
+
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue; // Only function definitions are processed
+
+ MachineFunction *MF = MMI.getMachineFunction(F);
+ if (!MF) {
+ WithColor::warning(errs(), ToolName)
+ << "No MachineFunction for " << F.getName() << "\n";
+ continue; // Warn and move on; other functions are still emitted
+ }
+
+ unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
+ MaxRelation = std::max(MaxRelation, FuncMaxRelation);
+ }
+
+ RelOS.flush();
+
+ // Write metadata header followed by relationships
+ OS << "MAX_RELATION=" << MaxRelation << '\n';
+ OS << Relationships;
+ }
+
+ /// Generate triplets for a single machine function, one per line, into OS
+ /// Returns the maximum relation ID used (MIRNextRelation if Vocab is unset)
+ unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
+ unsigned MaxRelation = MIRNextRelation;
+ unsigned PrevOpcode = 0; // Declared per function, so it spans basic blocks
+ bool HasPrevOpcode = false; // False only until the first instruction is seen
+
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName)
+ << "MIR Vocabulary must be initialized for triplet generation.\n";
+ return MaxRelation; // Nothing has been written to OS at this point
+ }
+
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ // Skip debug instructions
+ if (MI.isDebugInstr())
+ continue;
+
+ // Get opcode entity ID
+ unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+ // Add "Next" relationship with previous instruction
+ if (HasPrevOpcode) {
+ OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
+ << '\n';
+ LLVM_DEBUG(dbgs()
+ << Vocab->getStringKey(PrevOpcode) << '\t'
+ << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+ }
+
+ // Add "Arg" relationships for operands
+ unsigned ArgIndex = 0;
+ for (const MachineOperand &MO : MI.operands()) {
+ auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+ unsigned RelationID = MIRArgRelation + ArgIndex;
+ OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+ LLVM_DEBUG({
+ std::string OperandStr = Vocab->getStringKey(OperandID);
+ dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
+ << '\t' << "Arg" << ArgIndex << '\n';
+ });
+
+ ++ArgIndex;
+ }
+
+ // Update MaxRelation if there were operands; ArgIndex is now the operand
+ if (ArgIndex > 0) // count, so the last relation used is one below it
+ MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+
+ PrevOpcode = OpcodeID;
+ HasPrevOpcode = true;
+ }
+ }
+
+ return MaxRelation;
+ }
+
+ /// Generate entity mappings: entity-count header, then entity\tid lines
+ void generateEntityMappings(raw_ostream &OS) const {
+ if (!Vocab) {
+ WithColor::error(errs(), ToolName)
+ << "Vocabulary must be initialized for entity mappings.\n";
+ return; // Nothing is written to OS when the vocabulary is missing
+ }
+
+ const unsigned EntityCount = Vocab->getCanonicalSize();
+ OS << EntityCount << "\n"; // Header line: total number of entities
+ for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
+ OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+ }
+
/// Generate embeddings for all machine functions in the module
void generateEmbeddings(const Module &M, raw_ostream &OS) const {
if (!Vocab) {
@@ -538,38 +695,67 @@ int main(int argc, char **argv) {
return 1;
}
- // Create MIR2Vec tool and initialize vocabulary
+ // Create MIR2Vec tool
MIR2VecTool Tool(*MMI);
- if (!Tool.initializeVocabulary(*M))
- return 1;
+ // Initialize vocabulary. For triplet/entity generation, only layout is
+ // needed. For embedding generation, the full vocabulary is needed.
+ //
+ // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires
+ // target-specific information for generating the vocabulary layout. So, we
+ // always initialize the vocabulary in this case.
+ if (TripletsSubCmd || EntitiesSubCmd) {
+ if (!Tool.initializeVocabularyForLayout(*M)) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to initialize MIR2Vec vocabulary for layout.\n";
+ return 1;
+ }
+ } else {
+ if (!Tool.initializeVocabulary(*M)) {
+ WithColor::error(errs(), ToolName)
+ << "Failed to initialize MIR2Vec vocabulary.\n";
+ return 1;
+ }
+ }
+ assert(Tool.getVocabulary() &&
+ "MIR2Vec vocabulary should be initialized at this point");
LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
<< "Vocabulary dimension: "
<< Tool.getVocabulary()->getDimension() << "\n"
<< "Vocabulary size: "
<< Tool.getVocabulary()->getCanonicalSize() << "\n");
- // Generate embeddings based on subcommand
- if (!FunctionName.empty()) {
- // Process single function
- Function *F = M->getFunction(FunctionName);
- if (!F) {
- WithColor::error(errs(), ToolName)
- << "Function '" << FunctionName << "' not found\n";
- return 1;
- }
+ // Handle subcommands
+ if (TripletsSubCmd) {
+ Tool.generateTriplets(*M, OS);
+ } else if (EntitiesSubCmd) {
+ Tool.generateEntityMappings(OS);
+ } else if (EmbeddingsSubCmd) {
+ if (!FunctionName.empty()) {
+ // Process single function
+ Function *F = M->getFunction(FunctionName);
+ if (!F) {
+ WithColor::error(errs(), ToolName)
+ << "Function '" << FunctionName << "' not found\n";
+ return 1;
+ }
- MachineFunction *MF = MMI->getMachineFunction(*F);
- if (!MF) {
- WithColor::error(errs(), ToolName)
- << "No MachineFunction for " << FunctionName << "\n";
- return 1;
- }
+ MachineFunction *MF = MMI->getMachineFunction(*F);
+ if (!MF) {
+ WithColor::error(errs(), ToolName)
+ << "No MachineFunction for " << FunctionName << "\n";
+ return 1;
+ }
- Tool.generateEmbeddings(*MF, OS);
+ Tool.generateEmbeddings(*MF, OS);
+ } else {
+ // Process all functions
+ Tool.generateEmbeddings(*M, OS);
+ }
} else {
- // Process all functions
- Tool.generateEmbeddings(*M, OS);
+ WithColor::error(errs(), ToolName)
+ << "Please specify a subcommand: triplets, entities, or embeddings\n";
+ return 1;
}
return 0;