about | summary | refs | log | tree | commit | diff
path: root/llvm/lib/Analysis
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r-- llvm/lib/Analysis/IR2Vec.cpp 180
1 file changed, 71 insertions, 109 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 1794a60..85b5372 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
-Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
- : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
- OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
- FuncVector(Embedding(Dimension)) {}
-
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
switch (Mode) {
@@ -169,110 +164,85 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
return nullptr;
}
-const InstEmbeddingsMap &Embedder::getInstVecMap() const {
- if (InstVecMap.empty())
- computeEmbeddings();
- return InstVecMap;
-}
-
-const BBEmbeddingsMap &Embedder::getBBVecMap() const {
- if (BBVecMap.empty())
- computeEmbeddings();
- return BBVecMap;
-}
-
-const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
- auto It = BBVecMap.find(&BB);
- if (It != BBVecMap.end())
- return It->second;
- computeEmbeddings(BB);
- return BBVecMap[&BB];
-}
+Embedding Embedder::computeEmbeddings() const {
+ Embedding FuncVector(Dimension, 0.0);
-const Embedding &Embedder::getFunctionVector() const {
- // Currently, we always (re)compute the embeddings for the function.
- // This is cheaper than caching the vector.
- computeEmbeddings();
- return FuncVector;
-}
-
-void Embedder::computeEmbeddings() const {
if (F.isDeclaration())
- return;
-
- FuncVector = Embedding(Dimension, 0.0);
+ return FuncVector;
// Consider only the basic blocks that are reachable from entry
- for (const BasicBlock *BB : depth_first(&F)) {
- computeEmbeddings(*BB);
- FuncVector += BBVecMap[BB];
- }
+ for (const BasicBlock *BB : depth_first(&F))
+ FuncVector += computeEmbeddings(*BB);
+ return FuncVector;
}
-void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
// We consider only the non-debug and non-pseudo instructions
- for (const auto &I : BB.instructionsWithoutDebug()) {
- Embedding ArgEmb(Dimension, 0);
- for (const auto &Op : I.operands())
- ArgEmb += Vocab[*Op];
- auto InstVector =
- Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
- if (const auto *IC = dyn_cast<CmpInst>(&I))
- InstVector += Vocab[IC->getPredicate()];
- InstVecMap[&I] = InstVector;
- BBVector += InstVector;
- }
- BBVecMap[&BB] = BBVector;
-}
-
-void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
- Embedding BBVector(Dimension, 0);
+ for (const auto &I : BB.instructionsWithoutDebug())
+ BBVector += computeEmbeddings(I);
+ return BBVector;
+}
+
+Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
+ // Currently, we always (re)compute the embeddings for symbolic embedder.
+ // This is cheaper than caching the vectors.
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands())
+ ArgEmb += Vocab[*Op];
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
+ return InstVector;
+}
+
+Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
+ // If we have already computed the embedding for this instruction, return it
+ auto It = InstVecMap.find(&I);
+ if (It != InstVecMap.end())
+ return It->second;
- // We consider only the non-debug and non-pseudo instructions
- for (const auto &I : BB.instructionsWithoutDebug()) {
- // TODO: Handle call instructions differently.
- // For now, we treat them like other instructions
- Embedding ArgEmb(Dimension, 0);
- for (const auto &Op : I.operands()) {
- // If the operand is defined elsewhere, we use its embedding
- if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
- auto DefIt = InstVecMap.find(DefInst);
- // Fixme (#159171): Ideally we should never miss an instruction
- // embedding here.
- // But when we have cyclic dependencies (e.g., phi
- // nodes), we might miss the embedding. In such cases, we fall back to
- // using the vocabulary embedding. This can be fixed by iterating to a
- // fixed-point, or by using a simple solver for the set of simultaneous
- // equations.
- // Another case when we might miss an instruction embedding is when
- // the operand instruction is in a different basic block that has not
- // been processed yet. This can be fixed by processing the basic blocks
- // in a topological order.
- if (DefIt != InstVecMap.end())
- ArgEmb += DefIt->second;
- else
- ArgEmb += Vocab[*Op];
- }
- // If the operand is not defined by an instruction, we use the vocabulary
- else {
- LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
- << *Op << "=" << Vocab[*Op][0] << "\n");
+ // TODO: Handle call instructions differently.
+ // For now, we treat them like other instructions
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands()) {
+ // If the operand is defined elsewhere, we use its embedding
+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
+ auto DefIt = InstVecMap.find(DefInst);
+ // Fixme (#159171): Ideally we should never miss an instruction
+ // embedding here.
+ // But when we have cyclic dependencies (e.g., phi
+ // nodes), we might miss the embedding. In such cases, we fall back to
+ // using the vocabulary embedding. This can be fixed by iterating to a
+ // fixed-point, or by using a simple solver for the set of simultaneous
+ // equations.
+ // Another case when we might miss an instruction embedding is when
+ // the operand instruction is in a different basic block that has not
+ // been processed yet. This can be fixed by processing the basic blocks
+ // in a topological order.
+ if (DefIt != InstVecMap.end())
+ ArgEmb += DefIt->second;
+ else
ArgEmb += Vocab[*Op];
- }
}
- // Create the instruction vector by combining opcode, type, and arguments
- // embeddings
- auto InstVector =
- Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
- // Add compare predicate embedding as an additional operand if applicable
- if (const auto *IC = dyn_cast<CmpInst>(&I))
- InstVector += Vocab[IC->getPredicate()];
- InstVecMap[&I] = InstVector;
- BBVector += InstVector;
+ // If the operand is not defined by an instruction, we use the
+ // vocabulary
+ else {
+ LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+ << *Op << "=" << Vocab[*Op][0] << "\n");
+ ArgEmb += Vocab[*Op];
+ }
}
- BBVecMap[&BB] = BBVector;
+ // Create the instruction vector by combining opcode, type, and arguments
+ // embeddings
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
+ InstVecMap[&I] = InstVector;
+ return InstVector;
}
// ==----------------------------------------------------------------------===//
@@ -695,25 +665,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
Emb->getFunctionVector().print(OS);
OS << "Basic block vectors:\n";
- const auto &BBMap = Emb->getBBVecMap();
for (const BasicBlock &BB : F) {
- auto It = BBMap.find(&BB);
- if (It != BBMap.end()) {
- OS << "Basic block: " << BB.getName() << ":\n";
- It->second.print(OS);
- }
+ OS << "Basic block: " << BB.getName() << ":\n";
+ Emb->getBBVector(BB).print(OS);
}
OS << "Instruction vectors:\n";
- const auto &InstMap = Emb->getInstVecMap();
for (const BasicBlock &BB : F) {
for (const Instruction &I : BB) {
- auto It = InstMap.find(&I);
- if (It != InstMap.end()) {
- OS << "Instruction: ";
- I.print(OS);
- It->second.print(OS);
- }
+ OS << "Instruction: ";
+ I.print(OS);
+ Emb->getInstVector(I).print(OS);
}
}
}