diff options
Diffstat (limited to 'llvm/lib/Analysis/IR2Vec.cpp')
-rw-r--r-- | llvm/lib/Analysis/IR2Vec.cpp | 86 |
1 files changed, 43 insertions, 43 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index af30422..295b6d3 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -330,6 +330,43 @@ bool VocabStorage::const_iterator::operator!=( return !(*this == Other); } +Error VocabStorage::parseVocabSection(StringRef Key, + const json::Value &ParsedVocabValue, + VocabMap &TargetVocab, unsigned &Dim) { + json::Path::Root Path(""); + const json::Object *RootObj = ParsedVocabValue.getAsObject(); + if (!RootObj) + return createStringError(errc::invalid_argument, + "JSON root is not an object"); + + const json::Value *SectionValue = RootObj->get(Key); + if (!SectionValue) + return createStringError(errc::invalid_argument, + "Missing '" + std::string(Key) + + "' section in vocabulary file"); + if (!json::fromJSON(*SectionValue, TargetVocab, Path)) + return createStringError(errc::illegal_byte_sequence, + "Unable to parse '" + std::string(Key) + + "' section from vocabulary"); + + Dim = TargetVocab.begin()->second.size(); + if (Dim == 0) + return createStringError(errc::illegal_byte_sequence, + "Dimension of '" + std::string(Key) + + "' section of the vocabulary is zero"); + + if (!std::all_of(TargetVocab.begin(), TargetVocab.end(), + [Dim](const std::pair<StringRef, Embedding> &Entry) { + return Entry.second.size() == Dim; + })) + return createStringError( + errc::illegal_byte_sequence, + "All vectors in the '" + std::string(Key) + + "' section of the vocabulary are not of the same dimension"); + + return Error::success(); +} + // ==----------------------------------------------------------------------===// // Vocabulary //===----------------------------------------------------------------------===// @@ -460,43 +497,6 @@ VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) { // IR2VecVocabAnalysis //===----------------------------------------------------------------------===// -Error IR2VecVocabAnalysis::parseVocabSection( - StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab, - unsigned &Dim) { - json::Path::Root Path(""); - const json::Object *RootObj = ParsedVocabValue.getAsObject(); - if (!RootObj) - return createStringError(errc::invalid_argument, - "JSON root is not an object"); - - const json::Value *SectionValue = RootObj->get(Key); - if (!SectionValue) - return createStringError(errc::invalid_argument, - "Missing '" + std::string(Key) + - "' section in vocabulary file"); - if (!json::fromJSON(*SectionValue, TargetVocab, Path)) - return createStringError(errc::illegal_byte_sequence, - "Unable to parse '" + std::string(Key) + - "' section from vocabulary"); - - Dim = TargetVocab.begin()->second.size(); - if (Dim == 0) - return createStringError(errc::illegal_byte_sequence, - "Dimension of '" + std::string(Key) + - "' section of the vocabulary is zero"); - - if (!std::all_of(TargetVocab.begin(), TargetVocab.end(), - [Dim](const std::pair<StringRef, Embedding> &Entry) { - return Entry.second.size() == Dim; - })) - return createStringError( - errc::illegal_byte_sequence, - "All vectors in the '" + std::string(Key) + - "' section of the vocabulary are not of the same dimension"); - - return Error::success(); -} - // FIXME: Make this optional. We can avoid file reads // by auto-generating a default vocabulary during the build time. Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab, @@ -513,16 +513,16 @@ Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab, return ParsedVocabValue.takeError(); unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0; - if (auto Err = - parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim)) + if (auto Err = VocabStorage::parseVocabSection("Opcodes", *ParsedVocabValue, + OpcVocab, OpcodeDim)) return Err; - if (auto Err = - parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim)) + if (auto Err = VocabStorage::parseVocabSection("Types", *ParsedVocabValue, + TypeVocab, TypeDim)) return Err; - if (auto Err = - parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim)) + if (auto Err = VocabStorage::parseVocabSection("Arguments", *ParsedVocabValue, + ArgVocab, ArgDim)) return Err; if (!(OpcodeDim == TypeDim && TypeDim == ArgDim)) |