about summary refs log tree commit diff
path: root/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp')
-rw-r--r-- llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp 65
1 files changed, 28 insertions, 37 deletions
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index f6ed94b..8e17a4a 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -9,22 +9,22 @@
/// \file
/// This file implements the IR2Vec embedding generation tool.
///
-/// This tool provides three main modes:
+/// This tool provides three main subcommands:
///
-/// 1. Triplet Generation Mode (--mode=triplets):
+/// 1. Triplet Generation (triplets):
/// Generates numeric triplets (head, tail, relation) for vocabulary
/// training. Output format: MAX_RELATION=N header followed by
/// head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
-/// Usage: llvm-ir2vec --mode=triplets input.bc -o train2id.txt
+/// Usage: llvm-ir2vec triplets input.bc -o train2id.txt
///
-/// 2. Entities Generation Mode (--mode=entities):
+/// 2. Entity Mappings (entities):
/// Generates entity mappings for vocabulary training.
/// Output format: <total_entities> header followed by entity\tid lines.
-/// Usage: llvm-ir2vec --mode=entities input.bc -o entity2id.txt
+/// Usage: llvm-ir2vec entities -o entity2id.txt
///
-/// 3. Embedding Generation Mode (--mode=embeddings):
+/// 3. Embedding Generation (embeddings):
/// Generates IR2Vec embeddings using a trained vocabulary.
-/// Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
+/// Usage: llvm-ir2vec embeddings --ir2vec-vocab-path=vocab.json
/// --level=func input.bc -o embeddings.txt Levels: --level=inst
/// (instructions), --level=bb (basic blocks), --level=func (functions)
/// (See IR2Vec.cpp for more embedding generation options)
@@ -55,36 +55,33 @@ namespace ir2vec {
static cl::OptionCategory IR2VecToolCategory("IR2Vec Tool Options");
+// Subcommands
+static cl::SubCommand
+ TripletsSubCmd("triplets", "Generate triplets for vocabulary training");
+static cl::SubCommand
+ EntitiesSubCmd("entities",
+ "Generate entity mappings for vocabulary training");
+static cl::SubCommand
+ EmbeddingsSubCmd("embeddings",
+ "Generate embeddings using trained vocabulary");
+
+// Common options
static cl::opt<std::string>
InputFilename(cl::Positional,
cl::desc("<input bitcode file or '-' for stdin>"),
- cl::init("-"), cl::cat(IR2VecToolCategory));
+ cl::init("-"), cl::sub(TripletsSubCmd),
+ cl::sub(EmbeddingsSubCmd), cl::cat(IR2VecToolCategory));
static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
cl::value_desc("filename"),
cl::init("-"),
cl::cat(IR2VecToolCategory));
-enum ToolMode {
- TripletMode, // Generate triplets for vocabulary training
- EntityMode, // Generate entity mappings for vocabulary training
- EmbeddingMode // Generate embeddings using trained vocabulary
-};
-
-static cl::opt<ToolMode> Mode(
- "mode", cl::desc("Tool operation mode:"),
- cl::values(clEnumValN(TripletMode, "triplets",
- "Generate triplets for vocabulary training"),
- clEnumValN(EntityMode, "entities",
- "Generate entity mappings for vocabulary training"),
- clEnumValN(EmbeddingMode, "embeddings",
- "Generate embeddings using trained vocabulary")),
- cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
-
+// Embedding-specific options
static cl::opt<std::string>
FunctionName("function", cl::desc("Process specific function only"),
cl::value_desc("name"), cl::Optional, cl::init(""),
- cl::cat(IR2VecToolCategory));
+ cl::sub(EmbeddingsSubCmd), cl::cat(IR2VecToolCategory));
enum EmbeddingLevel {
InstructionLevel, // Generate instruction-level embeddings
@@ -93,14 +90,15 @@ enum EmbeddingLevel {
};
static cl::opt<EmbeddingLevel>
- Level("level", cl::desc("Embedding generation level (for embedding mode):"),
+ Level("level", cl::desc("Embedding generation level:"),
cl::values(clEnumValN(InstructionLevel, "inst",
"Generate instruction-level embeddings"),
clEnumValN(BasicBlockLevel, "bb",
"Generate basic block-level embeddings"),
clEnumValN(FunctionLevel, "func",
"Generate function-level embeddings")),
- cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
+ cl::init(FunctionLevel), cl::sub(EmbeddingsSubCmd),
+ cl::cat(IR2VecToolCategory));
namespace {
@@ -291,7 +289,7 @@ public:
Error processModule(Module &M, raw_ostream &OS) {
IR2VecTool Tool(M);
- if (Mode == EmbeddingMode) {
+ if (EmbeddingsSubCmd) {
// Initialize vocabulary for embedding generation
// Note: Requires --ir2vec-vocab-path option to be set
auto VocabStatus = Tool.initializeVocabulary();
@@ -311,6 +309,7 @@ Error processModule(Module &M, raw_ostream &OS) {
Tool.generateEmbeddings(OS);
}
} else {
+ // Triplet generation; entities are dumped in main without reading IR
Tool.generateTriplets(OS);
}
return Error::success();
@@ -334,14 +333,6 @@ int main(int argc, char **argv) {
"See https://llvm.org/docs/CommandGuide/llvm-ir2vec.html for more "
"information.\n");
- // Validate command line options
- if (Mode != EmbeddingMode) {
- if (Level.getNumOccurrences() > 0)
- errs() << "Warning: --level option is ignored\n";
- if (FunctionName.getNumOccurrences() > 0)
- errs() << "Warning: --function option is ignored\n";
- }
-
std::error_code EC;
raw_fd_ostream OS(OutputFilename, EC);
if (EC) {
@@ -349,7 +340,7 @@ int main(int argc, char **argv) {
return 1;
}
- if (Mode == EntityMode) {
+ if (EntitiesSubCmd) {
// Just dump entity mappings without processing any IR
IR2VecTool::generateEntityMappings(OS);
return 0;