Diffstat (limited to 'llvm/lib/Analysis')
-rw-r--r-- | llvm/lib/Analysis/IR2Vec.cpp | 180
1 file changed, 71 insertions, 109 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 1794a60..85b5372 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -153,11 +153,6 @@ void Embedding::print(raw_ostream &OS) const {
 // Embedder and its subclasses
 //===----------------------------------------------------------------------===//
 
-Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
-    : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
-      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
-      FuncVector(Embedding(Dimension)) {}
-
 std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
                                            const Vocabulary &Vocab) {
   switch (Mode) {
@@ -169,110 +164,85 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
   return nullptr;
 }
 
-const InstEmbeddingsMap &Embedder::getInstVecMap() const {
-  if (InstVecMap.empty())
-    computeEmbeddings();
-  return InstVecMap;
-}
-
-const BBEmbeddingsMap &Embedder::getBBVecMap() const {
-  if (BBVecMap.empty())
-    computeEmbeddings();
-  return BBVecMap;
-}
-
-const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
-  auto It = BBVecMap.find(&BB);
-  if (It != BBVecMap.end())
-    return It->second;
-  computeEmbeddings(BB);
-  return BBVecMap[&BB];
-}
+Embedding Embedder::computeEmbeddings() const {
+  Embedding FuncVector(Dimension, 0.0);
 
-const Embedding &Embedder::getFunctionVector() const {
-  // Currently, we always (re)compute the embeddings for the function.
-  // This is cheaper than caching the vector.
-  computeEmbeddings();
-  return FuncVector;
-}
-
-void Embedder::computeEmbeddings() const {
   if (F.isDeclaration())
-    return;
-
-  FuncVector = Embedding(Dimension, 0.0);
+    return FuncVector;
 
   // Consider only the basic blocks that are reachable from entry
-  for (const BasicBlock *BB : depth_first(&F)) {
-    computeEmbeddings(*BB);
-    FuncVector += BBVecMap[BB];
-  }
+  for (const BasicBlock *BB : depth_first(&F))
+    FuncVector += computeEmbeddings(*BB);
+  return FuncVector;
 }
 
-void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+Embedding Embedder::computeEmbeddings(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
   // We consider only the non-debug and non-pseudo instructions
-  for (const auto &I : BB.instructionsWithoutDebug()) {
-    Embedding ArgEmb(Dimension, 0);
-    for (const auto &Op : I.operands())
-      ArgEmb += Vocab[*Op];
-    auto InstVector =
-        Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
-    if (const auto *IC = dyn_cast<CmpInst>(&I))
-      InstVector += Vocab[IC->getPredicate()];
-    InstVecMap[&I] = InstVector;
-    BBVector += InstVector;
-  }
-  BBVecMap[&BB] = BBVector;
-}
-
-void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
-  Embedding BBVector(Dimension, 0);
+  for (const auto &I : BB.instructionsWithoutDebug())
+    BBVector += computeEmbeddings(I);
+  return BBVector;
+}
+
+Embedding SymbolicEmbedder::computeEmbeddings(const Instruction &I) const {
+  // Currently, we always (re)compute the embeddings for symbolic embedder.
+  // This is cheaper than caching the vectors.
+  Embedding ArgEmb(Dimension, 0);
+  for (const auto &Op : I.operands())
+    ArgEmb += Vocab[*Op];
+  auto InstVector =
+      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+  if (const auto *IC = dyn_cast<CmpInst>(&I))
+    InstVector += Vocab[IC->getPredicate()];
+  return InstVector;
+}
+
+Embedding FlowAwareEmbedder::computeEmbeddings(const Instruction &I) const {
+  // If we have already computed the embedding for this instruction, return it
+  auto It = InstVecMap.find(&I);
+  if (It != InstVecMap.end())
+    return It->second;
 
-  // We consider only the non-debug and non-pseudo instructions
-  for (const auto &I : BB.instructionsWithoutDebug()) {
-    // TODO: Handle call instructions differently.
-    // For now, we treat them like other instructions
-    Embedding ArgEmb(Dimension, 0);
-    for (const auto &Op : I.operands()) {
-      // If the operand is defined elsewhere, we use its embedding
-      if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
-        auto DefIt = InstVecMap.find(DefInst);
-        // Fixme (#159171): Ideally we should never miss an instruction
-        // embedding here.
-        // But when we have cyclic dependencies (e.g., phi
-        // nodes), we might miss the embedding. In such cases, we fall back to
-        // using the vocabulary embedding. This can be fixed by iterating to a
-        // fixed-point, or by using a simple solver for the set of simultaneous
-        // equations.
-        // Another case when we might miss an instruction embedding is when
-        // the operand instruction is in a different basic block that has not
-        // been processed yet. This can be fixed by processing the basic blocks
-        // in a topological order.
-        if (DefIt != InstVecMap.end())
-          ArgEmb += DefIt->second;
-        else
-          ArgEmb += Vocab[*Op];
-      }
-      // If the operand is not defined by an instruction, we use the vocabulary
-      else {
-        LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
-                          << *Op << "=" << Vocab[*Op][0] << "\n");
+  // TODO: Handle call instructions differently.
+  // For now, we treat them like other instructions
+  Embedding ArgEmb(Dimension, 0);
+  for (const auto &Op : I.operands()) {
+    // If the operand is defined elsewhere, we use its embedding
+    if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
+      auto DefIt = InstVecMap.find(DefInst);
+      // Fixme (#159171): Ideally we should never miss an instruction
+      // embedding here.
+      // But when we have cyclic dependencies (e.g., phi
+      // nodes), we might miss the embedding. In such cases, we fall back to
+      // using the vocabulary embedding. This can be fixed by iterating to a
+      // fixed-point, or by using a simple solver for the set of simultaneous
+      // equations.
+      // Another case when we might miss an instruction embedding is when
+      // the operand instruction is in a different basic block that has not
+      // been processed yet. This can be fixed by processing the basic blocks
+      // in a topological order.
+      if (DefIt != InstVecMap.end())
+        ArgEmb += DefIt->second;
+      else
         ArgEmb += Vocab[*Op];
-      }
     }
-    // Create the instruction vector by combining opcode, type, and arguments
-    // embeddings
-    auto InstVector =
-        Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
-    // Add compare predicate embedding as an additional operand if applicable
-    if (const auto *IC = dyn_cast<CmpInst>(&I))
-      InstVector += Vocab[IC->getPredicate()];
-    InstVecMap[&I] = InstVector;
-    BBVector += InstVector;
+    // If the operand is not defined by an instruction, we use the
+    // vocabulary
+    else {
+      LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+                        << *Op << "=" << Vocab[*Op][0] << "\n");
+      ArgEmb += Vocab[*Op];
+    }
   }
-  BBVecMap[&BB] = BBVector;
+  // Create the instruction vector by combining opcode, type, and arguments
+  // embeddings
+  auto InstVector =
+      Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+  if (const auto *IC = dyn_cast<CmpInst>(&I))
+    InstVector += Vocab[IC->getPredicate()];
+  InstVecMap[&I] = InstVector;
+  return InstVector;
 }
 
 // ==----------------------------------------------------------------------===//
@@ -695,25 +665,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
     Emb->getFunctionVector().print(OS);
 
     OS << "Basic block vectors:\n";
-    const auto &BBMap = Emb->getBBVecMap();
     for (const BasicBlock &BB : F) {
-      auto It = BBMap.find(&BB);
-      if (It != BBMap.end()) {
-        OS << "Basic block: " << BB.getName() << ":\n";
-        It->second.print(OS);
-      }
+      OS << "Basic block: " << BB.getName() << ":\n";
+      Emb->getBBVector(BB).print(OS);
     }
 
     OS << "Instruction vectors:\n";
-    const auto &InstMap = Emb->getInstVecMap();
     for (const BasicBlock &BB : F) {
      for (const Instruction &I : BB) {
-        auto It = InstMap.find(&I);
-        if (It != InstMap.end()) {
-          OS << "Instruction: ";
-          I.print(OS);
-          It->second.print(OS);
-        }
+        OS << "Instruction: ";
+        I.print(OS);
+        Emb->getInstVector(I).print(OS);
      }
    }
  }
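For orientation, the sketch below shows how a caller might use the embedder interface after this refactoring: vectors are requested directly from the Embedder rather than read out of the old getInstVecMap()/getBBVecMap() caches, mirroring the updated IR2VecPrinterPass hunk above. The include paths, the ir2vec namespace, and the IR2VecKind::Symbolic enumerator are assumptions based on the surrounding headers; they are not part of this diff.

// Sketch only (not part of this commit). Assumes the public interface from
// llvm/Analysis/IR2Vec.h: ir2vec::Embedder, ir2vec::Vocabulary, and the
// IR2VecKind::Symbolic enumerator.
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/Function.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>

using namespace llvm;
using namespace llvm::ir2vec;

static void printEmbeddings(const Function &F, const Vocabulary &Vocab) {
  std::unique_ptr<Embedder> Emb =
      Embedder::create(IR2VecKind::Symbolic, F, Vocab);
  if (!Emb)
    return;

  // Function vector: sum of the vectors of the basic blocks reachable from
  // the entry block, computed on demand rather than cached.
  Emb->getFunctionVector().print(errs());

  for (const BasicBlock &BB : F) {
    // Basic-block vector: sum of the per-instruction vectors.
    Emb->getBBVector(BB).print(errs());
    for (const Instruction &I : BB)
      // Instruction vector: opcode + type + operand embeddings (plus the
      // compare predicate for CmpInst).
      Emb->getInstVector(I).print(errs());
  }
}

With the symbolic embedder, per-instruction vectors are recomputed on each query (see the comment in SymbolicEmbedder::computeEmbeddings), while the flow-aware embedder still memoizes them in InstVecMap so that operand embeddings defined elsewhere can be reused.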