diff options
Diffstat (limited to 'llvm/lib/CodeGen/MIR2Vec.cpp')
-rw-r--r-- | llvm/lib/CodeGen/MIR2Vec.cpp | 366 |
1 files changed, 295 insertions, 71 deletions
diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp index 99be1fc0..00b37e7 100644 --- a/llvm/lib/CodeGen/MIR2Vec.cpp +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -42,6 +42,13 @@ static cl::opt<std::string> cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), cl::desc("Weight for machine opcode embeddings"), cl::cat(MIR2VecCategory)); +cl::opt<float> CommonOperandWeight( + "mir2vec-common-operand-weight", cl::Optional, cl::init(1.0), + cl::desc("Weight for common operand embeddings"), cl::cat(MIR2VecCategory)); +cl::opt<float> + RegOperandWeight("mir2vec-reg-operand-weight", cl::Optional, cl::init(1.0), + cl::desc("Weight for register operand embeddings"), + cl::cat(MIR2VecCategory)); cl::opt<MIR2VecKind> MIR2VecEmbeddingKind( "mir2vec-kind", cl::Optional, cl::values(clEnumValN(MIR2VecKind::Symbolic, "symbolic", @@ -56,26 +63,52 @@ cl::opt<MIR2VecKind> MIR2VecEmbeddingKind( // Vocabulary //===----------------------------------------------------------------------===// -MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, - const TargetInstrInfo &TII) - : TII(TII) { +MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap, + VocabMap &&PhysicalRegisterMap, + VocabMap &&VirtualRegisterMap, + const TargetInstrInfo &TII, + const TargetRegisterInfo &TRI, + const MachineRegisterInfo &MRI) + : TII(TII), TRI(TRI), MRI(MRI) { buildCanonicalOpcodeMapping(); - unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size(); assert(CanonicalOpcodeCount > 0 && "No canonical opcodes found for target - invalid vocabulary"); - Layout.OperandBase = CanonicalOpcodeCount; - generateStorage(OpcodeEntries); + + buildRegisterOperandMapping(); + + // Define layout of vocabulary sections + Layout.OpcodeBase = 0; + Layout.CommonOperandBase = CanonicalOpcodeCount; + // We expect same classes for physical and virtual registers + Layout.PhyRegBase = Layout.CommonOperandBase + std::size(CommonOperandNames); + Layout.VirtRegBase = Layout.PhyRegBase + RegisterOperandNames.size(); + + generateStorage(OpcodeMap, CommonOperandMap, PhysicalRegisterMap, + VirtualRegisterMap); Layout.TotalEntries = Storage.size(); } -Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries, - const TargetInstrInfo &TII) { - if (Entries.empty()) +Expected<MIRVocabulary> +MIRVocabulary::create(VocabMap &&OpcodeMap, VocabMap &&CommonOperandMap, + VocabMap &&PhyRegMap, VocabMap &&VirtRegMap, + const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, + const MachineRegisterInfo &MRI) { + if (OpcodeMap.empty() || CommonOperandMap.empty() || PhyRegMap.empty() || + VirtRegMap.empty()) return createStringError(errc::invalid_argument, "Empty vocabulary entries provided"); - return MIRVocabulary(std::move(Entries), TII); + MIRVocabulary Vocab(std::move(OpcodeMap), std::move(CommonOperandMap), + std::move(PhyRegMap), std::move(VirtRegMap), TII, TRI, + MRI); + + // Validate Storage after construction + if (!Vocab.Storage.isValid()) + return createStringError(errc::invalid_argument, + "Failed to create valid vocabulary storage"); + Vocab.ZeroEmbedding = Embedding(Vocab.Storage.getDimension(), 0.0); + return std::move(Vocab); } std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) { @@ -122,22 +155,74 @@ unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const { return getCanonicalIndexForBaseName(BaseOpcode); } +unsigned +MIRVocabulary::getCanonicalIndexForOperandName(StringRef OperandName) const { + auto It = std::find(std::begin(CommonOperandNames), + std::end(CommonOperandNames), OperandName); + assert(It != std::end(CommonOperandNames) && + "Operand name not found in common operands"); + return Layout.CommonOperandBase + + std::distance(std::begin(CommonOperandNames), It); +} + +unsigned +MIRVocabulary::getCanonicalIndexForRegisterClass(StringRef RegName, + bool IsPhysical) const { + auto It = std::find(RegisterOperandNames.begin(), RegisterOperandNames.end(), + RegName); + assert(It != RegisterOperandNames.end() && + "Register name not found in register operands"); + unsigned LocalIndex = std::distance(RegisterOperandNames.begin(), It); + return (IsPhysical ? Layout.PhyRegBase : Layout.VirtRegBase) + LocalIndex; +} + std::string MIRVocabulary::getStringKey(unsigned Pos) const { assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary"); - // For now, all entries are opcodes since we only have one section - if (Pos < Layout.OperandBase && Pos < UniqueBaseOpcodeNames.size()) { + // Handle opcodes section + if (Pos < Layout.CommonOperandBase) { // Convert canonical index back to base opcode name auto It = UniqueBaseOpcodeNames.begin(); std::advance(It, Pos); + assert(It != UniqueBaseOpcodeNames.end() && + "Canonical index out of bounds in opcode section"); return *It; } - llvm_unreachable("Invalid position in vocabulary"); - return ""; + auto getLocalIndex = [](unsigned Pos, size_t BaseOffset, size_t Bound, + const char *Msg) { + unsigned LocalIndex = Pos - BaseOffset; + assert(LocalIndex < Bound && Msg); + return LocalIndex; + }; + + // Handle common operands section + if (Pos < Layout.PhyRegBase) { + unsigned LocalIndex = getLocalIndex( + Pos, Layout.CommonOperandBase, std::size(CommonOperandNames), + "Local index out of bounds in common operands"); + return CommonOperandNames[LocalIndex].str(); + } + + // Handle physical registers section + if (Pos < Layout.VirtRegBase) { + unsigned LocalIndex = + getLocalIndex(Pos, Layout.PhyRegBase, RegisterOperandNames.size(), + "Local index out of bounds in physical registers"); + return "PhyReg_" + RegisterOperandNames[LocalIndex]; + } + + // Handle virtual registers section + unsigned LocalIndex = + getLocalIndex(Pos, Layout.VirtRegBase, RegisterOperandNames.size(), + "Local index out of bounds in virtual registers"); + return "VirtReg_" + RegisterOperandNames[LocalIndex]; } -void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) { +void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap, + const VocabMap &CommonOperandsMap, + const VocabMap &PhyRegMap, + const VocabMap &VirtRegMap) { // Helper for handling missing entities in the vocabulary. // Currently, we use a zero vector. In the future, we will throw an error to @@ -151,14 +236,14 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) { // Initialize opcode embeddings section unsigned EmbeddingDim = OpcodeMap.begin()->second.size(); - std::vector<Embedding> OpcodeEmbeddings(Layout.OperandBase, + std::vector<Embedding> OpcodeEmbeddings(Layout.CommonOperandBase, Embedding(EmbeddingDim)); // Populate opcode embeddings using canonical mapping for (auto COpcodeName : UniqueBaseOpcodeNames) { if (auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) { auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName); - assert(COpcodeIndex < Layout.OperandBase && + assert(COpcodeIndex < Layout.CommonOperandBase && "Canonical index out of bounds"); OpcodeEmbeddings[COpcodeIndex] = It->second; } else { @@ -166,8 +251,39 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) { } } - // TODO: Add operand/argument embeddings as additional sections - // This will require extending the vocabulary format and layout + // Initialize common operand embeddings section + std::vector<Embedding> CommonOperandEmbeddings(std::size(CommonOperandNames), + Embedding(EmbeddingDim)); + unsigned OperandIndex = 0; + for (const auto &CommonOperandName : CommonOperandNames) { + if (auto It = CommonOperandsMap.find(CommonOperandName.str()); + It != CommonOperandsMap.end()) { + CommonOperandEmbeddings[OperandIndex] = It->second; + } else { + handleMissingEntity(CommonOperandName); + } + ++OperandIndex; + } + + // Helper lambda for creating register operand embeddings + auto createRegisterEmbeddings = [&](const VocabMap &RegMap) { + std::vector<Embedding> RegEmbeddings(TRI.getNumRegClasses(), + Embedding(EmbeddingDim)); + unsigned RegOperandIndex = 0; + for (const auto &RegOperandName : RegisterOperandNames) { + if (auto It = RegMap.find(RegOperandName); It != RegMap.end()) + RegEmbeddings[RegOperandIndex] = It->second; + else + handleMissingEntity(RegOperandName); + ++RegOperandIndex; + } + return RegEmbeddings; + }; + + // Initialize register operand embeddings sections + std::vector<Embedding> PhyRegEmbeddings = createRegisterEmbeddings(PhyRegMap); + std::vector<Embedding> VirtRegEmbeddings = + createRegisterEmbeddings(VirtRegMap); // Scale the vocabulary sections based on the provided weights auto scaleVocabSection = [](std::vector<Embedding> &Embeddings, @@ -176,9 +292,20 @@ void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) { Embedding *= Weight; }; scaleVocabSection(OpcodeEmbeddings, OpcWeight); - - std::vector<std::vector<Embedding>> Sections(1); - Sections[0] = std::move(OpcodeEmbeddings); + scaleVocabSection(CommonOperandEmbeddings, CommonOperandWeight); + scaleVocabSection(PhyRegEmbeddings, RegOperandWeight); + scaleVocabSection(VirtRegEmbeddings, RegOperandWeight); + + std::vector<std::vector<Embedding>> Sections( + static_cast<unsigned>(Section::MaxSections)); + Sections[static_cast<unsigned>(Section::Opcodes)] = + std::move(OpcodeEmbeddings); + Sections[static_cast<unsigned>(Section::CommonOperands)] = + std::move(CommonOperandEmbeddings); + Sections[static_cast<unsigned>(Section::PhyRegisters)] = + std::move(PhyRegEmbeddings); + Sections[static_cast<unsigned>(Section::VirtRegisters)] = + std::move(VirtRegEmbeddings); Storage = ir2vec::VocabStorage(std::move(Sections)); } @@ -199,46 +326,130 @@ void MIRVocabulary::buildCanonicalOpcodeMapping() { << " unique base opcodes\n"); } -Expected<MIRVocabulary> -MIRVocabulary::createDummyVocabForTest(const TargetInstrInfo &TII, - unsigned Dim) { +void MIRVocabulary::buildRegisterOperandMapping() { + // Check if already built + if (!RegisterOperandNames.empty()) + return; + + for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) { + const TargetRegisterClass *RegClass = TRI.getRegClass(RC); + if (!RegClass) + continue; + + // Get the register class name + StringRef ClassName = TRI.getRegClassName(RegClass); + RegisterOperandNames.push_back(ClassName.str()); + } +} + +unsigned MIRVocabulary::getCommonOperandIndex( + MachineOperand::MachineOperandType OperandType) const { + assert(OperandType != MachineOperand::MO_Register && + "Expected non-register operand type"); + assert(OperandType > MachineOperand::MO_Register && + OperandType < MachineOperand::MO_Last && "Operand type out of bounds"); + return static_cast<unsigned>(OperandType) - 1; +} + +unsigned MIRVocabulary::getRegisterOperandIndex(Register Reg) const { + assert(!RegisterOperandNames.empty() && "Register operand mapping not built"); + assert(Reg.isValid() && "Invalid register; not expected here"); + assert((Reg.isPhysical() || Reg.isVirtual()) && + "Expected a physical or virtual register"); + + const TargetRegisterClass *RegClass = nullptr; + + // For physical registers, use TRI to get minimal register class as a + // physical register can belong to multiple classes. For virtual + // registers, use MRI to uniquely identify the assigned register class. + if (Reg.isPhysical()) + RegClass = TRI.getMinimalPhysRegClass(Reg); + else + RegClass = MRI.getRegClass(Reg); + + if (RegClass) + return RegClass->getID(); + // Fallback for registers without a class (shouldn't happen) + llvm_unreachable("Register operand without a valid register class"); + return 0; +} + +Expected<MIRVocabulary> MIRVocabulary::createDummyVocabForTest( + const TargetInstrInfo &TII, const TargetRegisterInfo &TRI, + const MachineRegisterInfo &MRI, unsigned Dim) { assert(Dim > 0 && "Dimension must be greater than zero"); float DummyVal = 0.1f; - // Create dummy embeddings for all canonical opcode names - VocabMap DummyVocabMap; + VocabMap DummyOpcMap, DummyOperandMap, DummyPhyRegMap, DummyVirtRegMap; + + // Process opcodes directly without creating temporary vocabulary for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) { std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); - if (DummyVocabMap.count(BaseOpcode) == 0) { - // Only add if not already present - DummyVocabMap[BaseOpcode] = Embedding(Dim, DummyVal); + if (DummyOpcMap.count(BaseOpcode) == 0) { // Only add if not already present + DummyOpcMap[BaseOpcode] = Embedding(Dim, DummyVal); DummyVal += 0.1f; } } - // Create and return vocabulary with dummy embeddings - return MIRVocabulary::create(std::move(DummyVocabMap), TII); + // Add common operands + for (const auto &CommonOperandName : CommonOperandNames) { + DummyOperandMap[CommonOperandName.str()] = Embedding(Dim, DummyVal); + DummyVal += 0.1f; + } + + // Process register classes directly + for (unsigned RC = 0; RC < TRI.getNumRegClasses(); ++RC) { + const TargetRegisterClass *RegClass = TRI.getRegClass(RC); + if (!RegClass) + continue; + + std::string ClassName = TRI.getRegClassName(RegClass); + DummyPhyRegMap[ClassName] = Embedding(Dim, DummyVal); + DummyVirtRegMap[ClassName] = Embedding(Dim, DummyVal); + DummyVal += 0.1f; + } + + // Create vocabulary directly without temporary instance + return MIRVocabulary::create( + std::move(DummyOpcMap), std::move(DummyOperandMap), + std::move(DummyPhyRegMap), std::move(DummyVirtRegMap), TII, TRI, MRI); } //===----------------------------------------------------------------------===// -// MIR2VecVocabLegacyAnalysis Implementation +// MIR2VecVocabProvider and MIR2VecVocabLegacyAnalysis //===----------------------------------------------------------------------===// -char MIR2VecVocabLegacyAnalysis::ID = 0; -INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", - "MIR2Vec Vocabulary Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) -INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", - "MIR2Vec Vocabulary Analysis", false, true) +Expected<mir2vec::MIRVocabulary> +MIR2VecVocabProvider::getVocabulary(const Module &M) { + VocabMap OpcVocab, CommonOperandVocab, PhyRegVocabMap, VirtRegVocabMap; -StringRef MIR2VecVocabLegacyAnalysis::getPassName() const { - return "MIR2Vec Vocabulary Analysis"; + if (Error Err = readVocabulary(OpcVocab, CommonOperandVocab, PhyRegVocabMap, + VirtRegVocabMap)) + return std::move(Err); + + for (const auto &F : M) { + if (F.isDeclaration()) + continue; + + if (auto *MF = MMI.getMachineFunction(F)) { + auto &Subtarget = MF->getSubtarget(); + if (const auto *TII = Subtarget.getInstrInfo()) + if (const auto *TRI = Subtarget.getRegisterInfo()) + return mir2vec::MIRVocabulary::create( + std::move(OpcVocab), std::move(CommonOperandVocab), + std::move(PhyRegVocabMap), std::move(VirtRegVocabMap), *TII, *TRI, + MF->getRegInfo()); + } + } + return createStringError(errc::invalid_argument, + "No machine functions found in module"); } -Error MIR2VecVocabLegacyAnalysis::readVocabulary() { - // TODO: Extend vocabulary format to support multiple sections - // (opcodes, operands, etc.) similar to IR2Vec structure +Error MIR2VecVocabProvider::readVocabulary(VocabMap &OpcodeVocab, + VocabMap &CommonOperandVocab, + VocabMap &PhyRegVocabMap, + VocabMap &VirtRegVocabMap) { if (VocabFile.empty()) return createStringError( errc::invalid_argument, @@ -255,39 +466,47 @@ Error MIR2VecVocabLegacyAnalysis::readVocabulary() { if (!ParsedVocabValue) return ParsedVocabValue.takeError(); - unsigned Dim = 0; + unsigned OpcodeDim = 0, CommonOperandDim = 0, PhyRegOperandDim = 0, + VirtRegOperandDim = 0; if (auto Err = ir2vec::VocabStorage::parseVocabSection( - "entities", *ParsedVocabValue, StrVocabMap, Dim)) + "Opcodes", *ParsedVocabValue, OpcodeVocab, OpcodeDim)) return Err; - return Error::success(); -} - -Expected<mir2vec::MIRVocabulary> -MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { - if (StrVocabMap.empty()) { - if (Error Err = readVocabulary()) { - return std::move(Err); - } - } + if (auto Err = ir2vec::VocabStorage::parseVocabSection( + "CommonOperands", *ParsedVocabValue, CommonOperandVocab, + CommonOperandDim)) + return Err; - // Get machine module info to access machine functions and target info - MachineModuleInfo &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI(); + if (auto Err = ir2vec::VocabStorage::parseVocabSection( + "PhysicalRegisters", *ParsedVocabValue, PhyRegVocabMap, + PhyRegOperandDim)) + return Err; - // Find first available machine function to get target instruction info - for (const auto &F : M) { - if (F.isDeclaration()) - continue; + if (auto Err = ir2vec::VocabStorage::parseVocabSection( + "VirtualRegisters", *ParsedVocabValue, VirtRegVocabMap, + VirtRegOperandDim)) + return Err; - if (auto *MF = MMI.getMachineFunction(F)) { - const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo(); - return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), *TII); - } + // All sections must have the same embedding dimension + if (!(OpcodeDim == CommonOperandDim && CommonOperandDim == PhyRegOperandDim && + PhyRegOperandDim == VirtRegOperandDim)) { + return createStringError( + errc::illegal_byte_sequence, + "MIR2Vec vocabulary sections have different dimensions"); } - // No machine functions available - return error - return createStringError(errc::invalid_argument, - "No machine functions found in module"); + return Error::success(); +} + +char MIR2VecVocabLegacyAnalysis::ID = 0; +INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", + "MIR2Vec Vocabulary Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) +INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", + "MIR2Vec Vocabulary Analysis", false, true) + +StringRef MIR2VecVocabLegacyAnalysis::getPassName() const { + return "MIR2Vec Vocabulary Analysis"; } //===----------------------------------------------------------------------===// @@ -351,9 +570,14 @@ Embedding SymbolicMIREmbedder::computeEmbeddings(const MachineInstr &MI) const { if (MI.isDebugInstr()) return Embedding(Dimension, 0); - // Todo: Add operand/argument contributions + // Opcode embedding + Embedding InstructionEmbedding = Vocab[MI.getOpcode()]; + + // Add operand contributions + for (const MachineOperand &MO : MI.operands()) + InstructionEmbedding += Vocab[MO]; - return Vocab[MI.getOpcode()]; + return InstructionEmbedding; } //===----------------------------------------------------------------------===// |