aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/MIR2Vec.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/MIR2Vec.cpp')
-rw-r--r--llvm/lib/CodeGen/MIR2Vec.cpp166
1 files changed, 159 insertions, 7 deletions
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp
index 5c78d98..99be1fc0 100644
--- a/llvm/lib/CodeGen/MIR2Vec.cpp
+++ b/llvm/lib/CodeGen/MIR2Vec.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MIR2Vec.h"
+#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/IR/Module.h"
@@ -29,20 +30,30 @@ using namespace mir2vec;
STATISTIC(MIRVocabMissCounter,
"Number of lookups to MIR entities not present in the vocabulary");
-cl::OptionCategory llvm::mir2vec::MIR2VecCategory("MIR2Vec Options");
+namespace llvm {
+namespace mir2vec {
+cl::OptionCategory MIR2VecCategory("MIR2Vec Options");
// FIXME: Use a default vocab when not specified
static cl::opt<std::string>
VocabFile("mir2vec-vocab-path", cl::Optional,
cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""),
cl::cat(MIR2VecCategory));
-cl::opt<float>
- llvm::mir2vec::OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
- cl::desc("Weight for machine opcode embeddings"),
- cl::cat(MIR2VecCategory));
+cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
+ cl::desc("Weight for machine opcode embeddings"),
+ cl::cat(MIR2VecCategory));
+cl::opt<MIR2VecKind> MIR2VecEmbeddingKind(
+ "mir2vec-kind", cl::Optional,
+ cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic",
+ "Generate symbolic embeddings for MIR")),
+ cl::init(MIR2VecKind::Symbolic), cl::desc("MIR2Vec embedding kind"),
+ cl::cat(MIR2VecCategory));
+
+} // namespace mir2vec
+} // namespace llvm
//===----------------------------------------------------------------------===//
-// Vocabulary Implementation
+// Vocabulary
//===----------------------------------------------------------------------===//
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
@@ -188,6 +199,28 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() {
<< " unique base opcodes\n");
}
+Expected<MIRVocabulary>
+MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII,
+ unsigned Dim) {
+ assert(Dim > 0 && "Dimension must be greater than zero");
+
+ float DummyVal = 0.1f;
+
+ // Create dummy embeddings for all canonical opcode names
+ VocabMap DummyVocabMap;
+ for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
+ std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
+ if (DummyVocabMap.count(BaseOpcode) == 0) {
+ // Only add if not already present
+ DummyVocabMap[BaseOpcode] = Embedding(Dim, DummyVal);
+ DummyVal += 0.1f;
+ }
+ }
+
+ // Create and return vocabulary with dummy embeddings
+ return MIRVocabulary::create(std::move(DummyVocabMap), TII);
+}
+
//===----------------------------------------------------------------------===//
// MIR2VecVocabLegacyAnalysis Implementation
//===----------------------------------------------------------------------===//
@@ -258,7 +291,73 @@ MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
}
//===----------------------------------------------------------------------===//
-// Printer Passes Implementation
+// MIREmbedder and its subclasses
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<MIREmbedder> MIREmbedder::create(MIR2VecKind Mode,
+ const MachineFunction &MF,
+ const MIRVocabulary &Vocab) {
+ switch (Mode) {
+ case MIR2VecKind::Symbolic:
+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+ }
+ return nullptr;
+}
+
+Embedding MIREmbedder::computeEmbeddings(const MachineBasicBlock &MBB) const {
+ Embedding MBBVector(Dimension, 0);
+
+ // Get instruction info for opcode name resolution
+ const auto &Subtarget = MF.getSubtarget();
+ const auto *TII = Subtarget.getInstrInfo();
+ if (!TII) {
+ MF.getFunction().getContext().emitError(
+ "MIR2Vec: No TargetInstrInfo available; cannot compute embeddings");
+ return MBBVector;
+ }
+
+ // Process each machine instruction in the basic block
+ for (const auto &MI : MBB) {
+ // Skip debug instructions and other metadata
+ if (MI.isDebugInstr())
+ continue;
+ MBBVector += computeEmbeddings(MI);
+ }
+
+ return MBBVector;
+}
+
+Embedding MIREmbedder::computeEmbeddings() const {
+ Embedding MFuncVector(Dimension, 0);
+
+ // Consider all reachable machine basic blocks in the function
+ for (const auto *MBB : depth_first(&MF))
+ MFuncVector += computeEmbeddings(*MBB);
+ return MFuncVector;
+}
+
+SymbolicMIREmbedder::SymbolicMIREmbedder(const MachineFunction &MF,
+ const MIRVocabulary &Vocab)
+ : MIREmbedder(MF, Vocab) {}
+
+std::unique_ptr<SymbolicMIREmbedder>
+SymbolicMIREmbedder::create(const MachineFunction &MF,
+ const MIRVocabulary &Vocab) {
+ return std::make_unique<SymbolicMIREmbedder>(MF, Vocab);
+}
+
+Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const {
+ // Skip debug instructions and other metadata
+ if (MI.isDebugInstr())
+ return Embedding(Dimension, 0);
+
+ // Todo: Add operand/argument contributions
+
+ return Vocab[MI.getOpcode()];
+}
+
+//===----------------------------------------------------------------------===//
+// Printer Passes
//===----------------------------------------------------------------------===//
char MIR2VecVocabPrinterLegacyPass::ID = 0;
@@ -297,3 +396,56 @@ MachineFunctionPass *
llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
return new MIR2VecVocabPrinterLegacyPass(OS);
}
+
+char MIR2VecPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(MIR2VecPrinterLegacyPass, "print-mir2vec",
+ "MIR2Vec Embedder Printer Pass", false, true)
+INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
+INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
+INITIALIZE_PASS_END(MIR2VecPrinterLegacyPass, "print-mir2vec",
+ "MIR2Vec Embedder Printer Pass", false, true)
+
+bool MIR2VecPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
+ auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>();
+ auto VocabOrErr =
+ Analysis.getMIR2VecVocabulary(*MF.getFunction().getParent());
+ assert(VocabOrErr && "Failed to get MIR2Vec vocabulary");
+ auto &MIRVocab = *VocabOrErr;
+
+ auto Emb = mir2vec::MIREmbedder::create(MIR2VecEmbeddingKind, MF, MIRVocab);
+ if (!Emb) {
+ OS << "Error creating MIR2Vec embeddings for function " << MF.getName()
+ << "\n";
+ return false;
+ }
+
+ OS << "MIR2Vec embeddings for machine function " << MF.getName() << ":\n";
+ OS << "Machine Function vector: ";
+ Emb->getMFunctionVector().print(OS);
+
+ OS << "Machine basic block vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ OS << "Machine basic block: " << MBB.getFullName() << ":\n";
+ Emb->getMBBVector(MBB).print(OS);
+ }
+
+ OS << "Machine instruction vectors:\n";
+ for (const MachineBasicBlock &MBB : MF) {
+ for (const MachineInstr &MI : MBB) {
+ // Skip debug instructions as they are not
+ // embedded
+ if (MI.isDebugInstr())
+ continue;
+
+ OS << "Machine instruction: ";
+ MI.print(OS);
+ Emb->getMInstVector(MI).print(OS);
+ }
+ }
+
+ return false;
+}
+
+MachineFunctionPass *llvm::createMIR2VecPrinterLegacyPass(raw_ostream &OS) {
+ return new MIR2VecPrinterLegacyPass(OS);
+}