diff options
Diffstat (limited to 'llvm/tools')
| -rw-r--r-- | llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 244 | ||||
| -rw-r--r-- | llvm/tools/llvm-lto2/llvm-lto2.cpp | 11 | ||||
| -rw-r--r-- | llvm/tools/llvm-profdata/llvm-profdata.cpp | 3 | 
3 files changed, 227 insertions, 31 deletions
| diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp index a723d37..7402782 100644 --- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -19,12 +19,22 @@  ///    Generates numeric triplets (head, tail, relation) for vocabulary  ///    training. Output format: MAX_RELATION=N header followed by  ///    head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,... -///    Usage: llvm-ir2vec triplets input.bc -o train2id.txt +/// +///    For LLVM IR: +///      llvm-ir2vec triplets input.bc -o train2id.txt +/// +///    For Machine IR: +///      llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt  ///  /// 2. Entity Mappings (entities):  ///    Generates entity mappings for vocabulary training.  ///    Output format: <total_entities> header followed by entity\tid lines. -///    Usage: llvm-ir2vec entities input.bc -o entity2id.txt +/// +///    For LLVM IR: +///      llvm-ir2vec entities input.bc -o entity2id.txt +/// +///    For Machine IR: +///      llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt  ///  /// 3. Embedding Generation (embeddings):  ///    Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary. @@ -67,6 +77,8 @@  #include "llvm/CodeGen/MIRParser/MIRParser.h"  #include "llvm/CodeGen/MachineFunction.h"  #include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h"  #include "llvm/MC/TargetRegistry.h"  #include "llvm/Support/TargetSelect.h"  #include "llvm/Support/WithColor.h" @@ -106,11 +118,10 @@ static cl::SubCommand                       "Generate embeddings using trained vocabulary");  // Common options -static cl::opt<std::string> -    InputFilename(cl::Positional, -                  cl::desc("<input bitcode file or '-' for stdin>"), -                  cl::init("-"), cl::sub(TripletsSubCmd), -                  cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory)); +static cl::opt<std::string> InputFilename( +    cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"), +    cl::init("-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd), +    cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));  static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),                                             cl::value_desc("filename"), @@ -345,6 +356,12 @@ Error processModule(Module &M, raw_ostream &OS) {  namespace mir2vec { +/// Relation types for MIR2Vec triplet generation +enum MIRRelationType { +  MIRNextRelation = 0, ///< Sequential instruction relationship +  MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N) +}; +  /// Helper class for MIR2Vec embedding generation  class MIR2VecTool {  private: @@ -354,7 +371,7 @@ private:  public:    explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {} -  /// Initialize MIR2Vec vocabulary +  /// Initialize MIR2Vec vocabulary from file (for embeddings generation)    bool initializeVocabulary(const Module &M) {      MIR2VecVocabProvider Provider(MMI);      auto VocabOrErr = Provider.getVocabulary(M); @@ -368,6 +385,146 @@ public:      return true;    } +  /// Initialize vocabulary with layout information only. +  /// This creates a minimal vocabulary with correct layout but no actual +  /// embeddings. Sufficient for generating training data and entity mappings. +  /// +  /// Note: Requires target-specific information from the first machine function +  /// to determine the vocabulary layout (number of opcodes, register classes). +  /// +  /// FIXME: Use --target option to get target info directly, avoiding the need +  /// to parse machine functions for pre-training operations. +  bool initializeVocabularyForLayout(const Module &M) { +    for (const Function &F : M) { +      if (F.isDeclaration()) +        continue; + +      MachineFunction *MF = MMI.getMachineFunction(F); +      if (!MF) +        continue; + +      const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo(); +      const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo(); +      const MachineRegisterInfo &MRI = MF->getRegInfo(); + +      auto VocabOrErr = +          MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1); +      if (!VocabOrErr) { +        WithColor::error(errs(), ToolName) +            << "Failed to create dummy vocabulary - " +            << toString(VocabOrErr.takeError()) << "\n"; +        return false; +      } +      Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr)); +      return true; +    } + +    WithColor::error(errs(), ToolName) +        << "No machine functions found to initialize vocabulary\n"; +    return false; +  } + +  /// Generate triplets for the module +  /// Output format: MAX_RELATION=N header followed by relationships +  void generateTriplets(const Module &M, raw_ostream &OS) const { +    unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID +    std::string Relationships; +    raw_string_ostream RelOS(Relationships); + +    for (const Function &F : M) { +      if (F.isDeclaration()) +        continue; + +      MachineFunction *MF = MMI.getMachineFunction(F); +      if (!MF) { +        WithColor::warning(errs(), ToolName) +            << "No MachineFunction for " << F.getName() << "\n"; +        continue; +      } + +      unsigned FuncMaxRelation = generateTriplets(*MF, RelOS); +      MaxRelation = std::max(MaxRelation, FuncMaxRelation); +    } + +    RelOS.flush(); + +    // Write metadata header followed by relationships +    OS << "MAX_RELATION=" << MaxRelation << '\n'; +    OS << Relationships; +  } + +  /// Generate triplets for a single machine function +  /// Returns the maximum relation ID used in this function +  unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const { +    unsigned MaxRelation = MIRNextRelation; +    unsigned PrevOpcode = 0; +    bool HasPrevOpcode = false; + +    if (!Vocab) { +      WithColor::error(errs(), ToolName) +          << "MIR Vocabulary must be initialized for triplet generation.\n"; +      return MaxRelation; +    } + +    for (const MachineBasicBlock &MBB : MF) { +      for (const MachineInstr &MI : MBB) { +        // Skip debug instructions +        if (MI.isDebugInstr()) +          continue; + +        // Get opcode entity ID +        unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode()); + +        // Add "Next" relationship with previous instruction +        if (HasPrevOpcode) { +          OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation +             << '\n'; +          LLVM_DEBUG(dbgs() +                     << Vocab->getStringKey(PrevOpcode) << '\t' +                     << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n"); +        } + +        // Add "Arg" relationships for operands +        unsigned ArgIndex = 0; +        for (const MachineOperand &MO : MI.operands()) { +          auto OperandID = Vocab->getEntityIDForMachineOperand(MO); +          unsigned RelationID = MIRArgRelation + ArgIndex; +          OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n'; +          LLVM_DEBUG({ +            std::string OperandStr = Vocab->getStringKey(OperandID); +            dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr +                   << '\t' << "Arg" << ArgIndex << '\n'; +          }); + +          ++ArgIndex; +        } + +        // Update MaxRelation if there were operands +        if (ArgIndex > 0) +          MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1); + +        PrevOpcode = OpcodeID; +        HasPrevOpcode = true; +      } +    } + +    return MaxRelation; +  } + +  /// Generate entity mappings with vocabulary +  void generateEntityMappings(raw_ostream &OS) const { +    if (!Vocab) { +      WithColor::error(errs(), ToolName) +          << "Vocabulary must be initialized for entity mappings.\n"; +      return; +    } + +    const unsigned EntityCount = Vocab->getCanonicalSize(); +    OS << EntityCount << "\n"; +    for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID) +      OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n'; +  } +    /// Generate embeddings for all machine functions in the module    void generateEmbeddings(const Module &M, raw_ostream &OS) const {      if (!Vocab) { @@ -538,38 +695,67 @@ int main(int argc, char **argv) {        return 1;      } -    // Create MIR2Vec tool and initialize vocabulary +    // Create MIR2Vec tool      MIR2VecTool Tool(*MMI); -    if (!Tool.initializeVocabulary(*M)) -      return 1; +    // Initialize vocabulary. For triplet/entity generation, only layout is +    // needed For embedding generation, the full vocabulary is needed. +    // +    // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires +    // target-specific information for generating the vocabulary layout. So, we +    // always initialize the vocabulary in this case. +    if (TripletsSubCmd || EntitiesSubCmd) { +      if (!Tool.initializeVocabularyForLayout(*M)) { +        WithColor::error(errs(), ToolName) +            << "Failed to initialize MIR2Vec vocabulary for layout.\n"; +        return 1; +      } +    } else { +      if (!Tool.initializeVocabulary(*M)) { +        WithColor::error(errs(), ToolName) +            << "Failed to initialize MIR2Vec vocabulary.\n"; +        return 1; +      } +    } +    assert(Tool.getVocabulary() && +           "MIR2Vec vocabulary should be initialized at this point");      LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"                        << "Vocabulary dimension: "                        << Tool.getVocabulary()->getDimension() << "\n"                        << "Vocabulary size: "                        << Tool.getVocabulary()->getCanonicalSize() << "\n"); -    // Generate embeddings based on subcommand -    if (!FunctionName.empty()) { -      // Process single function -      Function *F = M->getFunction(FunctionName); -      if (!F) { -        WithColor::error(errs(), ToolName) -            << "Function '" << FunctionName << "' not found\n"; -        return 1; -      } +    // Handle subcommands +    if (TripletsSubCmd) { +      Tool.generateTriplets(*M, OS); +    } else if (EntitiesSubCmd) { +      Tool.generateEntityMappings(OS); +    } else if (EmbeddingsSubCmd) { +      if (!FunctionName.empty()) { +        // Process single function +        Function *F = M->getFunction(FunctionName); +        if (!F) { +          WithColor::error(errs(), ToolName) +              << "Function '" << FunctionName << "' not found\n"; +          return 1; +        } -      MachineFunction *MF = MMI->getMachineFunction(*F); -      if (!MF) { -        WithColor::error(errs(), ToolName) -            << "No MachineFunction for " << FunctionName << "\n"; -        return 1; -      } +        MachineFunction *MF = MMI->getMachineFunction(*F); +        if (!MF) { +          WithColor::error(errs(), ToolName) +              << "No MachineFunction for " << FunctionName << "\n"; +          return 1; +        } -      Tool.generateEmbeddings(*MF, OS); +        Tool.generateEmbeddings(*MF, OS); +      } else { +        // Process all functions +        Tool.generateEmbeddings(*M, OS); +      }      } else { -      // Process all functions -      Tool.generateEmbeddings(*M, OS); +      WithColor::error(errs(), ToolName) +          << "Please specify a subcommand: triplets, entities, or embeddings\n"; +      return 1;      }      return 0; diff --git a/llvm/tools/llvm-lto2/llvm-lto2.cpp b/llvm/tools/llvm-lto2/llvm-lto2.cpp index ff75bb5..399306f 100644 --- a/llvm/tools/llvm-lto2/llvm-lto2.cpp +++ b/llvm/tools/llvm-lto2/llvm-lto2.cpp @@ -111,6 +111,12 @@ static cl::opt<std::string> DTLTOCompiler(      "dtlto-compiler",      cl::desc("Compiler to use for DTLTO ThinLTO backend compilations.")); +static cl::list<std::string> DTLTOCompilerPrependArgs( +    "dtlto-compiler-prepend-arg", cl::CommaSeparated, +    cl::desc("Prepend arguments to pass to the remote compiler for backend " +             "compilations."), +    cl::value_desc("arg")); +  static cl::list<std::string> DTLTOCompilerArgs(      "dtlto-compiler-arg", cl::CommaSeparated,      cl::desc("Arguments to pass to the remote compiler for backend " @@ -371,6 +377,9 @@ static int run(int argc, char **argv) {                      "with -dtlto-distributor\n";    auto DTLTODistributorArgsSV = llvm::to_vector<0>(llvm::map_range(        DTLTODistributorArgs, [](const std::string &S) { return StringRef(S); })); +  auto DTLTOCompilerPrependArgsSV = llvm::to_vector<0>( +      llvm::map_range(DTLTOCompilerPrependArgs, +                      [](const std::string &S) { return StringRef(S); }));    auto DTLTOCompilerArgsSV = llvm::to_vector<0>(llvm::map_range(        DTLTOCompilerArgs, [](const std::string &S) { return StringRef(S); })); @@ -388,7 +397,7 @@ static int run(int argc, char **argv) {          llvm::heavyweight_hardware_concurrency(Threads),          /*OnWrite=*/{}, ThinLTOEmitIndexes, ThinLTOEmitImports, OutputFilename,          DTLTODistributor, DTLTODistributorArgsSV, DTLTOCompiler, -        DTLTOCompilerArgsSV, SaveTemps); +        DTLTOCompilerPrependArgsSV, DTLTOCompilerArgsSV, SaveTemps);    } else      Backend = createInProcessThinBackend(          llvm::heavyweight_hardware_concurrency(Threads), diff --git a/llvm/tools/llvm-profdata/llvm-profdata.cpp b/llvm/tools/llvm-profdata/llvm-profdata.cpp index 15ddb05..a356bcd 100644 --- a/llvm/tools/llvm-profdata/llvm-profdata.cpp +++ b/llvm/tools/llvm-profdata/llvm-profdata.cpp @@ -34,7 +34,7 @@  #include "llvm/Support/FileSystem.h"  #include "llvm/Support/Format.h"  #include "llvm/Support/FormattedStream.h" -#include "llvm/Support/LLVMDriver.h" +#include "llvm/Support/InitLLVM.h"  #include "llvm/Support/MD5.h"  #include "llvm/Support/MemoryBuffer.h"  #include "llvm/Support/Path.h" @@ -3465,6 +3465,7 @@ static int order_main() {  }  int main(int argc, const char *argv[]) { +  InitLLVM X(argc, argv);    StringRef ProgName(sys::path::filename(argv[0]));    if (argc < 2) { | 
