aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
authorS. VenkataKeerthy <31350914+svkeerthy@users.noreply.github.com>2025-06-13 10:43:22 -0700
committerGitHub <noreply@github.com>2025-06-13 10:43:22 -0700
commit09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc (patch)
tree152e52509f30a78bfd40ee0dea1732198479a948 /llvm/lib
parentecdb549e6de60b3211cfa860eec498270e3980f1 (diff)
downloadllvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.zip
llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.gz
llvm-09c54c2e9e044fa0857831e6ce1bf77c8ce16ecc.tar.bz2
[IR2Vec] Minor vocab changes and exposing weights (#143200)
This PR changes some asserts in the vocabulary reader to hard checks that emit errors, and exposes the opcode/type/argument weight flags and a vocabulary-taking constructor to help in unit tests. (Tracking issue: #141817)
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Analysis/IR2Vec.cpp82
1 files changed, 51 insertions, 31 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 25ce35d..0f7303c 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -16,13 +16,11 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
-#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"
using namespace llvm;
@@ -33,6 +31,8 @@ using namespace ir2vec;
STATISTIC(VocabMissCounter,
"Number of lookups to entites not present in the vocabulary");
+namespace llvm {
+namespace ir2vec {
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
// FIXME: Use a default vocab when not specified
@@ -40,18 +40,17 @@ static cl::opt<std::string>
VocabFile("ir2vec-vocab-path", cl::Optional,
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
cl::cat(IR2VecCategory));
-static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
- cl::init(1.0),
- cl::desc("Weight for opcode embeddings"),
- cl::cat(IR2VecCategory));
-static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
- cl::init(0.5),
- cl::desc("Weight for type embeddings"),
- cl::cat(IR2VecCategory));
-static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
- cl::init(0.2),
- cl::desc("Weight for argument embeddings"),
- cl::cat(IR2VecCategory));
+cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
+ cl::desc("Weight for opcode embeddings"),
+ cl::cat(IR2VecCategory));
+cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
+ cl::desc("Weight for type embeddings"),
+ cl::cat(IR2VecCategory));
+cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
+ cl::desc("Weight for argument embeddings"),
+ cl::cat(IR2VecCategory));
+} // namespace ir2vec
+} // namespace llvm
AnalysisKey IR2VecVocabAnalysis::Key;
@@ -251,9 +250,9 @@ bool IR2VecVocabResult::invalidate(
// by auto-generating a default vocabulary during the build time.
Error IR2VecVocabAnalysis::readVocabulary() {
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
- if (!BufOrError) {
+ if (!BufOrError)
return createFileError(VocabFile, BufOrError.getError());
- }
+
auto Content = BufOrError.get()->getBuffer();
json::Path::Root Path("");
Expected<json::Value> ParsedVocabValue = json::parse(Content);
@@ -261,39 +260,60 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return ParsedVocabValue.takeError();
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
- if (!Res) {
+ if (!Res)
return createStringError(errc::illegal_byte_sequence,
"Unable to parse the vocabulary");
- }
- assert(Vocabulary.size() > 0 && "Vocabulary is empty");
+
+ if (Vocabulary.empty())
+ return createStringError(errc::illegal_byte_sequence,
+ "Vocabulary is empty");
unsigned Dim = Vocabulary.begin()->second.size();
- assert(Dim > 0 && "Dimension of vocabulary is zero");
- (void)Dim;
- assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
- [Dim](const std::pair<StringRef, Embedding> &Entry) {
- return Entry.second.size() == Dim;
- }) &&
- "All vectors in the vocabulary are not of the same dimension");
+ if (Dim == 0)
+ return createStringError(errc::illegal_byte_sequence,
+ "Dimension of vocabulary is zero");
+
+ if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
+ return Entry.second.size() == Dim;
+ }))
+ return createStringError(
+ errc::illegal_byte_sequence,
+ "All vectors in the vocabulary are not of the same dimension");
+
return Error::success();
}
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(const Vocab &Vocabulary)
+ : Vocabulary(Vocabulary) {}
+
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
+ : Vocabulary(std::move(Vocabulary)) {}
+
+void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
+ handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+ Ctx.emitError("Error reading vocabulary: " + EI.message());
+ });
+}
+
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
auto Ctx = &M.getContext();
+ // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
+ // If vocabulary is already populated by the constructor, use it.
+ if (!Vocabulary.empty())
+ return IR2VecVocabResult(std::move(Vocabulary));
+
+ // Otherwise, try to read from the vocabulary file.
if (VocabFile.empty()) {
// FIXME: Use default vocabulary
Ctx->emitError("IR2Vec vocabulary file path not specified");
return IR2VecVocabResult(); // Return invalid result
}
if (auto Err = readVocabulary()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- Ctx->emitError("Error reading vocabulary: " + EI.message());
- });
+ emitError(std::move(Err), *Ctx);
return IR2VecVocabResult();
}
- // FIXME: Scale the vocabulary here once. This would avoid scaling per use
- // later.
return IR2VecVocabResult(std::move(Vocabulary));
}