diff options
Diffstat (limited to 'llvm/lib/Analysis/IR2Vec.cpp')
-rw-r--r-- | llvm/lib/Analysis/IR2Vec.cpp | 274 |
1 files changed, 188 insertions, 86 deletions
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 99afc06..271f004 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Module.h" @@ -216,6 +217,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const { ArgEmb += Vocab[*Op]; auto InstVector = Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + if (const auto *IC = dyn_cast<CmpInst>(&I)) + InstVector += Vocab[IC->getPredicate()]; InstVecMap[&I] = InstVector; BBVector += InstVector; } @@ -250,6 +253,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { // embeddings auto InstVector = Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb; + // Add compare predicate embedding as an additional operand if applicable + if (const auto *IC = dyn_cast<CmpInst>(&I)) + InstVector += Vocab[IC->getPredicate()]; InstVecMap[&I] = InstVector; BBVector += InstVector; } @@ -257,41 +263,75 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const { } // ==----------------------------------------------------------------------===// -// Vocabulary +// VocabStorage //===----------------------------------------------------------------------===// -unsigned Vocabulary::getDimension() const { - assert(isValid() && "IR2Vec Vocabulary is invalid"); - return Vocab[0].size(); -} - -unsigned Vocabulary::getSlotIndex(unsigned Opcode) { - assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); - return Opcode - 1; // Convert to zero-based index -} - -unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) { - assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID"); - return MaxOpcodes + static_cast<unsigned>(getCanonicalTypeID(TypeID)); -} - -unsigned Vocabulary::getSlotIndex(const Value &Op) { - unsigned Index = static_cast<unsigned>(getOperandKind(&Op)); - assert(Index < MaxOperandKinds && "Invalid OperandKind"); - return MaxOpcodes + MaxCanonicalTypeIDs + Index; +VocabStorage::VocabStorage(std::vector<std::vector<Embedding>> &&SectionData) + : Sections(std::move(SectionData)), TotalSize([&] { + assert(!Sections.empty() && "Vocabulary has no sections"); + // Compute total size across all sections + size_t Size = 0; + for (const auto &Section : Sections) { + assert(!Section.empty() && "Vocabulary section is empty"); + Size += Section.size(); + } + return Size; + }()), + Dimension([&] { + // Get dimension from the first embedding in the first section - all + // embeddings must have the same dimension + assert(!Sections.empty() && "Vocabulary has no sections"); + assert(!Sections[0].empty() && "First section of vocabulary is empty"); + unsigned ExpectedDim = static_cast<unsigned>(Sections[0][0].size()); + + // Verify that all embeddings across all sections have the same + // dimension + auto allSameDim = [ExpectedDim](const std::vector<Embedding> &Section) { + return std::all_of(Section.begin(), Section.end(), + [ExpectedDim](const Embedding &Emb) { + return Emb.size() == ExpectedDim; + }); + }; + assert(std::all_of(Sections.begin(), Sections.end(), allSameDim) && + "All embeddings must have the same dimension"); + + return ExpectedDim; + }()) {} + +const Embedding &VocabStorage::const_iterator::operator*() const { + assert(SectionId < Storage->Sections.size() && "Invalid section ID"); + assert(LocalIndex < Storage->Sections[SectionId].size() && + "Local index out of range"); + return Storage->Sections[SectionId][LocalIndex]; +} + +VocabStorage::const_iterator &VocabStorage::const_iterator::operator++() { + ++LocalIndex; + // Check if we need to move to the next section + if (SectionId < Storage->getNumSections() && + LocalIndex >= Storage->Sections[SectionId].size()) { + assert(LocalIndex == Storage->Sections[SectionId].size() && + "Local index should be at the end of the current section"); + LocalIndex = 0; + ++SectionId; + } + return *this; } -const Embedding &Vocabulary::operator[](unsigned Opcode) const { - return Vocab[getSlotIndex(Opcode)]; +bool VocabStorage::const_iterator::operator==( + const const_iterator &Other) const { + return Storage == Other.Storage && SectionId == Other.SectionId && + LocalIndex == Other.LocalIndex; } -const Embedding &Vocabulary::operator[](Type::TypeID TypeID) const { - return Vocab[getSlotIndex(TypeID)]; +bool VocabStorage::const_iterator::operator!=( + const const_iterator &Other) const { + return !(*this == Other); } -const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const { - return Vocab[getSlotIndex(Arg)]; -} +// ==----------------------------------------------------------------------===// +// Vocabulary +//===----------------------------------------------------------------------===// StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); @@ -304,29 +344,6 @@ StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { return "UnknownOpcode"; } -StringRef Vocabulary::getVocabKeyForCanonicalTypeID(CanonicalTypeID CType) { - unsigned Index = static_cast<unsigned>(CType); - assert(Index < MaxCanonicalTypeIDs && "Invalid CanonicalTypeID"); - return CanonicalTypeNames[Index]; -} - -Vocabulary::CanonicalTypeID -Vocabulary::getCanonicalTypeID(Type::TypeID TypeID) { - unsigned Index = static_cast<unsigned>(TypeID); - assert(Index < MaxTypeIDs && "Invalid TypeID"); - return TypeIDMapping[Index]; -} - -StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { - return getVocabKeyForCanonicalTypeID(getCanonicalTypeID(TypeID)); -} - -StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { - unsigned Index = static_cast<unsigned>(Kind); - assert(Index < MaxOperandKinds && "Invalid OperandKind"); - return OperandKindNames[Index]; -} - // Helper function to classify an operand into OperandKind Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { if (isa<Function>(Op)) @@ -338,18 +355,50 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) { return OperandKind::VariableID; } +unsigned Vocabulary::getPredicateLocalIndex(CmpInst::Predicate P) { + if (P >= CmpInst::FIRST_FCMP_PREDICATE && P <= CmpInst::LAST_FCMP_PREDICATE) + return P - CmpInst::FIRST_FCMP_PREDICATE; + else + return P - CmpInst::FIRST_ICMP_PREDICATE + + (CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1); +} + +CmpInst::Predicate Vocabulary::getPredicateFromLocalIndex(unsigned LocalIndex) { + unsigned fcmpRange = + CmpInst::LAST_FCMP_PREDICATE - CmpInst::FIRST_FCMP_PREDICATE + 1; + if (LocalIndex < fcmpRange) + return static_cast<CmpInst::Predicate>(CmpInst::FIRST_FCMP_PREDICATE + + LocalIndex); + else + return static_cast<CmpInst::Predicate>(CmpInst::FIRST_ICMP_PREDICATE + + LocalIndex - fcmpRange); +} + +StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) { + static SmallString<16> PredNameBuffer; + if (Pred < CmpInst::FIRST_ICMP_PREDICATE) + PredNameBuffer = "FCMP_"; + else + PredNameBuffer = "ICMP_"; + PredNameBuffer += CmpInst::getPredicateName(Pred); + return PredNameBuffer; +} + StringRef Vocabulary::getStringKey(unsigned Pos) { assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary"); // Opcode if (Pos < MaxOpcodes) return getVocabKeyForOpcode(Pos + 1); // Type - if (Pos < MaxOpcodes + MaxCanonicalTypeIDs) + if (Pos < OperandBaseOffset) return getVocabKeyForCanonicalTypeID( static_cast<CanonicalTypeID>(Pos - MaxOpcodes)); // Operand - return getVocabKeyForOperandKind( - static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs)); + if (Pos < PredicateBaseOffset) + return getVocabKeyForOperandKind( + static_cast<OperandKind>(Pos - OperandBaseOffset)); + // Predicates + return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset)); } // For now, assume vocabulary is stable unless explicitly invalidated. @@ -359,19 +408,51 @@ bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA, return !(PAC.preservedWhenStateless()); } -Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) { - VocabVector DummyVocab; - DummyVocab.reserve(NumCanonicalEntries); +VocabStorage Vocabulary::createDummyVocabForTest(unsigned Dim) { float DummyVal = 0.1f; - // Create a dummy vocabulary with entries for all opcodes, types, and - // operands - for ([[maybe_unused]] unsigned _ : - seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs + - Vocabulary::MaxOperandKinds)) { - DummyVocab.push_back(Embedding(Dim, DummyVal)); + + // Create sections for opcodes, types, operands, and predicates + // Order must match Vocabulary::Section enum + std::vector<std::vector<Embedding>> Sections; + Sections.reserve(4); + + // Opcodes section + std::vector<Embedding> OpcodeSec; + OpcodeSec.reserve(MaxOpcodes); + for (unsigned I = 0; I < MaxOpcodes; ++I) { + OpcodeSec.emplace_back(Dim, DummyVal); DummyVal += 0.1f; } - return DummyVocab; + Sections.push_back(std::move(OpcodeSec)); + + // Types section + std::vector<Embedding> TypeSec; + TypeSec.reserve(MaxCanonicalTypeIDs); + for (unsigned I = 0; I < MaxCanonicalTypeIDs; ++I) { + TypeSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(TypeSec)); + + // Operands section + std::vector<Embedding> OperandSec; + OperandSec.reserve(MaxOperandKinds); + for (unsigned I = 0; I < MaxOperandKinds; ++I) { + OperandSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(OperandSec)); + + // Predicates section + std::vector<Embedding> PredicateSec; + PredicateSec.reserve(MaxPredicateKinds); + for (unsigned I = 0; I < MaxPredicateKinds; ++I) { + PredicateSec.emplace_back(Dim, DummyVal); + DummyVal += 0.1f; + } + Sections.push_back(std::move(PredicateSec)); + + return VocabStorage(std::move(Sections)); } // ==----------------------------------------------------------------------===// @@ -417,7 +498,9 @@ Error IR2VecVocabAnalysis::parseVocabSection( // FIXME: Make this optional. We can avoid file reads // by auto-generating a default vocabulary during the build time. -Error IR2VecVocabAnalysis::readVocabulary() { +Error IR2VecVocabAnalysis::readVocabulary(VocabMap &OpcVocab, + VocabMap &TypeVocab, + VocabMap &ArgVocab) { auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); if (!BufOrError) return createFileError(VocabFile, BufOrError.getError()); @@ -448,7 +531,9 @@ Error IR2VecVocabAnalysis::readVocabulary() { return Error::success(); } -void IR2VecVocabAnalysis::generateNumMappedVocab() { +void IR2VecVocabAnalysis::generateVocabStorage(VocabMap &OpcVocab, + VocabMap &TypeVocab, + VocabMap &ArgVocab) { // Helper for handling missing entities in the vocabulary. // Currently, we use a zero vector. In the future, we will throw an error to @@ -466,7 +551,6 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, Embedding(Dim)); - NumericOpcodeEmbeddings.reserve(Vocabulary::MaxOpcodes); for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); auto It = OpcVocab.find(VocabKey.str()); @@ -475,13 +559,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { else handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), - NumericOpcodeEmbeddings.end()); // Handle Types - only canonical types are present in vocabulary std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxCanonicalTypeIDs, Embedding(Dim)); - NumericTypeEmbeddings.reserve(Vocabulary::MaxCanonicalTypeIDs); for (unsigned CTypeID : seq(0u, Vocabulary::MaxCanonicalTypeIDs)) { StringRef VocabKey = Vocabulary::getVocabKeyForCanonicalTypeID( static_cast<Vocabulary::CanonicalTypeID>(CTypeID)); @@ -491,13 +572,10 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { } handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(), - NumericTypeEmbeddings.end()); // Handle Arguments/Operands std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds, Embedding(Dim)); - NumericArgEmbeddings.reserve(Vocabulary::MaxOperandKinds); for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) { Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind); StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind); @@ -508,15 +586,37 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { } handleMissingEntity(VocabKey.str()); } - Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(), - NumericArgEmbeddings.end()); -} -IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab) - : Vocab(Vocab) {} + // Handle Predicates: part of Operands section. We look up predicate keys + // in ArgVocab. + std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds, + Embedding(Dim, 0)); + for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) { + StringRef VocabKey = + Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK)); + auto It = ArgVocab.find(VocabKey.str()); + if (It != ArgVocab.end()) { + NumericPredEmbeddings[PK] = It->second; + continue; + } + handleMissingEntity(VocabKey.str()); + } + + // Create section-based storage instead of flat vocabulary + // Order must match Vocabulary::Section enum + std::vector<std::vector<Embedding>> Sections(4); + Sections[static_cast<unsigned>(Vocabulary::Section::Opcodes)] = + std::move(NumericOpcodeEmbeddings); // Section::Opcodes + Sections[static_cast<unsigned>(Vocabulary::Section::CanonicalTypes)] = + std::move(NumericTypeEmbeddings); // Section::CanonicalTypes + Sections[static_cast<unsigned>(Vocabulary::Section::Operands)] = + std::move(NumericArgEmbeddings); // Section::Operands + Sections[static_cast<unsigned>(Vocabulary::Section::Predicates)] = + std::move(NumericPredEmbeddings); // Section::Predicates -IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab) - : Vocab(std::move(Vocab)) {} + // Create VocabStorage from organized sections + Vocab.emplace(std::move(Sections)); +} void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) { handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) { @@ -528,8 +628,8 @@ IR2VecVocabAnalysis::Result IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { auto Ctx = &M.getContext(); // If vocabulary is already populated by the constructor, use it. - if (!Vocab.empty()) - return Vocabulary(std::move(Vocab)); + if (Vocab.has_value()) + return Vocabulary(std::move(Vocab.value())); // Otherwise, try to read from the vocabulary file. if (VocabFile.empty()) { @@ -538,7 +638,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { "set it using --ir2vec-vocab-path"); return Vocabulary(); // Return invalid result } - if (auto Err = readVocabulary()) { + + VocabMap OpcVocab, TypeVocab, ArgVocab; + if (auto Err = readVocabulary(OpcVocab, TypeVocab, ArgVocab)) { emitError(std::move(Err), *Ctx); return Vocabulary(); } @@ -553,9 +655,9 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { scaleVocabSection(ArgVocab, ArgWeight); // Generate the numeric lookup vocabulary - generateNumMappedVocab(); + generateVocabStorage(OpcVocab, TypeVocab, ArgVocab); - return Vocabulary(std::move(Vocab)); + return Vocabulary(std::move(Vocab.value())); } // ==----------------------------------------------------------------------===// @@ -564,7 +666,7 @@ IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) { PreservedAnalyses IR2VecPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) { - auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); + auto &Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid"); for (Function &F : M) { @@ -606,7 +708,7 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M, PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M, ModuleAnalysisManager &MAM) { - auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); + auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M); assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid"); // Print each entry |