diff options
author | S. VenkataKeerthy <31350914+svkeerthy@users.noreply.github.com> | 2025-06-13 10:43:22 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-13 10:43:22 -0700 |
commit | 09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc (patch) | |
tree | 152e52509f30a78bfd40ee0dea1732198479a948 /llvm/lib | |
parent | ecdb549e6de60b3211cfa860eec498270e3980f1 (diff) | |
download | llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.zip llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.gz llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.bz2 |
[IR2Vec] Minor vocab changes and exposing weights (#143200)
This PR changes some asserts in Vocab to hard checks that emit errors, and exposes the weight flags and a constructor to help in unit tests.
(Tracking issue - #141817)
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Analysis/IR2Vec.cpp | 82 |
1 files changed, 51 insertions, 31 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 25ce35d..0f7303c 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -16,13 +16,11 @@ #include "llvm/ADT/Statistic.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/Format.h" -#include "llvm/Support/JSON.h" #include "llvm/Support/MemoryBuffer.h" using namespace llvm; @@ -33,6 +31,8 @@ using namespace ir2vec; STATISTIC(VocabMissCounter, "Number of lookups to entites not present in the vocabulary"); +namespace llvm { +namespace ir2vec { static cl::OptionCategory IR2VecCategory("IR2Vec Options"); // FIXME: Use a default vocab when not specified @@ -40,18 +40,17 @@ static cl::opt<std::string> VocabFile("ir2vec-vocab-path", cl::Optional, cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""), cl::cat(IR2VecCategory)); -static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, - cl::init(1.0), - cl::desc("Weight for opcode embeddings"), - cl::cat(IR2VecCategory)); -static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, - cl::init(0.5), - cl::desc("Weight for type embeddings"), - cl::cat(IR2VecCategory)); -static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, - cl::init(0.2), - cl::desc("Weight for argument embeddings"), - cl::cat(IR2VecCategory)); +cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0), + cl::desc("Weight for opcode embeddings"), + cl::cat(IR2VecCategory)); +cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5), + cl::desc("Weight for type embeddings"), + cl::cat(IR2VecCategory)); +cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2), + cl::desc("Weight for argument embeddings"), + cl::cat(IR2VecCategory)); +} // namespace ir2vec +} // 
namespace llvm AnalysisKey IR2VecVocabAnalysis::Key; @@ -251,9 +250,9 @@ bool IR2VecVocabResult::invalidate( // by auto-generating a default vocabulary during the build time. Error IR2VecVocabAnalysis::readVocabulary() { auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); - if (!BufOrError) { + if (!BufOrError) return createFileError(VocabFile, BufOrError.getError()); - } + auto Content = BufOrError.get()->getBuffer(); json::Path::Root Path(""); Expected<json::Value> ParsedVocabValue = json::parse(Content); @@ -261,39 +260,60 @@ Error IR2VecVocabAnalysis::readVocabulary() { return ParsedVocabValue.takeError(); bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path); - if (!Res) { + if (!Res) return createStringError(errc::illegal_byte_sequence, "Unable to parse the vocabulary"); - } - assert(Vocabulary.size() > 0 && "Vocabulary is empty"); + + if (Vocabulary.empty()) + return createStringError(errc::illegal_byte_sequence, + "Vocabulary is empty"); unsigned Dim = Vocabulary.begin()->second.size(); - assert(Dim > 0 && "Dimension of vocabulary is zero"); - (void)Dim; - assert(std::all_of(Vocabulary.begin(), Vocabulary.end(), - [Dim](const std::pair<StringRef, Embedding> &Entry) { - return Entry.second.size() == Dim; - }) && - "All vectors in the vocabulary are not of the same dimension"); + if (Dim == 0) + return createStringError(errc::illegal_byte_sequence, + "Dimension of vocabulary is zero"); + + if (!std::all_of(Vocabulary.begin(), Vocabulary.end(), + [Dim](const std::pair<StringRef, Embedding> &Entry) { + return Entry.second.size() == Dim; + })) + return createStringError( + errc::illegal_byte_sequence, + "All vectors in the vocabulary are not of the same dimension"); + return Error::success(); } +IR2VecVocabAnalysis::IR2VecVocabAnalysis(const Vocab &Vocabulary) + : Vocabulary(Vocabulary) {} + +IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary) + : Vocabulary(std::move(Vocabulary)) {} + +void 
IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { + handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { + Ctx.emitError("Error reading vocabulary: " + EI.message()); + }); +} + IR2VecVocabAnalysis::Result IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { auto Ctx = &M.getContext(); + // FIXME: Scale the vocabulary once. This would avoid scaling per use later. + // If vocabulary is already populated by the constructor, use it. + if (!Vocabulary.empty()) + return IR2VecVocabResult(std::move(Vocabulary)); + + // Otherwise, try to read from the vocabulary file. if (VocabFile.empty()) { // FIXME: Use default vocabulary Ctx->emitError("IR2Vec vocabulary file path not specified"); return IR2VecVocabResult(); // Return invalid result } if (auto Err = readVocabulary()) { - handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { - Ctx->emitError("Error reading vocabulary: " + EI.message()); - }); + emitError(std::move(Err), *Ctx); return IR2VecVocabResult(); } - // FIXME: Scale the vocabulary here once. This would avoid scaling per use - // later. return IR2VecVocabResult(std::move(Vocabulary)); } |