diff options
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/CMakeLists.txt | 1 | ||||
-rw-r--r-- | llvm/lib/CodeGen/CodeGen.cpp | 2 | ||||
-rw-r--r-- | llvm/lib/CodeGen/MIR2Vec.cpp | 306 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp | 8 | ||||
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 2 |
5 files changed, 316 insertions, 3 deletions
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt index f8f9bbb..b6872605 100644 --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -155,6 +155,7 @@ add_llvm_component_library(LLVMCodeGen MIRFSDiscriminator.cpp MIRSampleProfile.cpp MIRYamlMapping.cpp + MIR2Vec.cpp MLRegAllocEvictAdvisor.cpp MLRegAllocPriorityAdvisor.cpp ModuloSchedule.cpp diff --git a/llvm/lib/CodeGen/CodeGen.cpp b/llvm/lib/CodeGen/CodeGen.cpp index 9e0cb3b..c438eae 100644 --- a/llvm/lib/CodeGen/CodeGen.cpp +++ b/llvm/lib/CodeGen/CodeGen.cpp @@ -96,6 +96,8 @@ void llvm::initializeCodeGen(PassRegistry &Registry) { initializeMachineSchedulerLegacyPass(Registry); initializeMachineSinkingLegacyPass(Registry); initializeMachineUniformityAnalysisPassPass(Registry); + initializeMIR2VecVocabLegacyAnalysisPass(Registry); + initializeMIR2VecVocabPrinterLegacyPassPass(Registry); initializeMachineUniformityInfoPrinterPassPass(Registry); initializeMachineVerifierLegacyPassPass(Registry); initializeObjCARCContractLegacyPassPass(Registry); diff --git a/llvm/lib/CodeGen/MIR2Vec.cpp b/llvm/lib/CodeGen/MIR2Vec.cpp new file mode 100644 index 0000000..87565c0 --- /dev/null +++ b/llvm/lib/CodeGen/MIR2Vec.cpp @@ -0,0 +1,306 @@ +//===- MIR2Vec.cpp - Implementation of MIR2Vec ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. See the LICENSE file for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the MIR2Vec algorithm for Machine IR embeddings. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/MIR2Vec.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/IR/Module.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Regex.h" + +using namespace llvm; +using namespace mir2vec; + +#define DEBUG_TYPE "mir2vec" + +STATISTIC(MIRVocabMissCounter, + "Number of lookups to MIR entities not present in the vocabulary"); + +namespace llvm { +namespace mir2vec { +cl::OptionCategory MIR2VecCategory("MIR2Vec Options"); + +// FIXME: Use a default vocab when not specified +static cl::opt<std::string> + VocabFile("mir2vec-vocab-path", cl::Optional, + cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""), + cl::cat(MIR2VecCategory)); +cl::opt<float> OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0), + cl::desc("Weight for machine opcode embeddings"), + cl::cat(MIR2VecCategory)); +} // namespace mir2vec +} // namespace llvm + +//===----------------------------------------------------------------------===// +// Vocabulary Implementation +//===----------------------------------------------------------------------===// + +MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries, + const TargetInstrInfo *TII) + : TII(*TII) { + // Fixme: Use static factory methods for creating vocabularies instead of + // public constructors + // Early return for invalid inputs - creates empty/invalid vocabulary + if (!TII || OpcodeEntries.empty()) + return; + + buildCanonicalOpcodeMapping(); + + unsigned CanonicalOpcodeCount = UniqueBaseOpcodeNames.size(); + assert(CanonicalOpcodeCount > 0 && + "No canonical opcodes found for target - invalid vocabulary"); + Layout.OperandBase = CanonicalOpcodeCount; + generateStorage(OpcodeEntries); + Layout.TotalEntries = Storage.size(); +} + +std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) { + // Extract base instruction name using regex to capture letters and + // underscores Examples: "ADD32rr" -> "ADD", "ARITH_FENCE" -> "ARITH_FENCE" + // + // TODO: Consider more sophisticated extraction: + // - Handle complex prefixes like "AVX1_SETALLONES" correctly (Currently, it + // would naively map to "AVX") + // - Extract width suffixes (8,16,32,64) as separate features + // - Capture addressing mode suffixes (r,i,m,ri,etc.) for better analysis + // (Currently, instances like "MOV32mi" map to "MOV", but "ADDPDrr" would map + // to "ADDPDrr") + + assert(!InstrName.empty() && "Instruction name should not be empty"); + + // Use regex to extract initial sequence of letters and underscores + static const Regex BaseOpcodeRegex("([a-zA-Z_]+)"); + SmallVector<StringRef, 2> Matches; + + if (BaseOpcodeRegex.match(InstrName, &Matches) && Matches.size() > 1) { + StringRef Match = Matches[1]; + // Trim trailing underscores + while (!Match.empty() && Match.back() == '_') + Match = Match.drop_back(); + return Match.str(); + } + + // Fallback to original name if no pattern matches + return InstrName.str(); +} + +unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const { + assert(!UniqueBaseOpcodeNames.empty() && "Canonical mapping not built"); + auto It = std::find(UniqueBaseOpcodeNames.begin(), + UniqueBaseOpcodeNames.end(), BaseName.str()); + assert(It != UniqueBaseOpcodeNames.end() && + "Base name not found in unique opcodes"); + return std::distance(UniqueBaseOpcodeNames.begin(), It); +} + +unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const { + assert(isValid() && "MIR2Vec Vocabulary is invalid"); + auto BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); + return getCanonicalIndexForBaseName(BaseOpcode); +} + +std::string MIRVocabulary::getStringKey(unsigned Pos) const { + assert(isValid() && "MIR2Vec Vocabulary is invalid"); + assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary"); + + // For now, all entries are opcodes since we only have one section + if (Pos < Layout.OperandBase && Pos < UniqueBaseOpcodeNames.size()) { + // Convert canonical index back to base opcode name + auto It = UniqueBaseOpcodeNames.begin(); + std::advance(It, Pos); + return *It; + } + + llvm_unreachable("Invalid position in vocabulary"); + return ""; +} + +void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) { + + // Helper for handling missing entities in the vocabulary. + // Currently, we use a zero vector. In the future, we will throw an error to + // ensure that *all* known entities are present in the vocabulary. + auto handleMissingEntity = [](StringRef Key) { + LLVM_DEBUG(errs() << "MIR2Vec: Missing vocabulary entry for " << Key + << "; using zero vector. This will result in an error " + "in the future.\n"); + ++MIRVocabMissCounter; + }; + + // Initialize opcode embeddings section + unsigned EmbeddingDim = OpcodeMap.begin()->second.size(); + std::vector<Embedding> OpcodeEmbeddings(Layout.OperandBase, + Embedding(EmbeddingDim)); + + // Populate opcode embeddings using canonical mapping + for (auto COpcodeName : UniqueBaseOpcodeNames) { + if (auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) { + auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName); + assert(COpcodeIndex < Layout.OperandBase && + "Canonical index out of bounds"); + OpcodeEmbeddings[COpcodeIndex] = It->second; + } else { + handleMissingEntity(COpcodeName); + } + } + + // TODO: Add operand/argument embeddings as additional sections + // This will require extending the vocabulary format and layout + + // Scale the vocabulary sections based on the provided weights + auto scaleVocabSection = [](std::vector<Embedding> &Embeddings, + double Weight) { + for (auto &Embedding : Embeddings) + Embedding *= Weight; + }; + scaleVocabSection(OpcodeEmbeddings, OpcWeight); + + std::vector<std::vector<Embedding>> Sections(1); + Sections[0] = std::move(OpcodeEmbeddings); + + Storage = ir2vec::VocabStorage(std::move(Sections)); +} + +void MIRVocabulary::buildCanonicalOpcodeMapping() { + // Check if already built + if (!UniqueBaseOpcodeNames.empty()) + return; + + // Build mapping from opcodes to canonical base opcode indices + for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) { + std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode)); + UniqueBaseOpcodeNames.insert(BaseOpcode); + } + + LLVM_DEBUG(dbgs() << "MIR2Vec: Built canonical mapping for target with " + << UniqueBaseOpcodeNames.size() + << " unique base opcodes\n"); +} + +//===----------------------------------------------------------------------===// +// MIR2VecVocabLegacyAnalysis Implementation +//===----------------------------------------------------------------------===// + +char MIR2VecVocabLegacyAnalysis::ID = 0; +INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", + "MIR2Vec Vocabulary Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) +INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis", + "MIR2Vec Vocabulary Analysis", false, true) + +StringRef MIR2VecVocabLegacyAnalysis::getPassName() const { + return "MIR2Vec Vocabulary Analysis"; +} + +Error MIR2VecVocabLegacyAnalysis::readVocabulary() { + // TODO: Extend vocabulary format to support multiple sections + // (opcodes, operands, etc.) similar to IR2Vec structure + if (VocabFile.empty()) + return createStringError( + errc::invalid_argument, + "MIR2Vec vocabulary file path not specified; set it " + "using --mir2vec-vocab-path"); + + auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true); + if (!BufOrError) + return createFileError(VocabFile, BufOrError.getError()); + + auto Content = BufOrError.get()->getBuffer(); + + Expected<json::Value> ParsedVocabValue = json::parse(Content); + if (!ParsedVocabValue) + return ParsedVocabValue.takeError(); + + unsigned Dim = 0; + if (auto Err = ir2vec::VocabStorage::parseVocabSection( + "entities", *ParsedVocabValue, StrVocabMap, Dim)) + return Err; + + return Error::success(); +} + +void MIR2VecVocabLegacyAnalysis::emitError(Error Err, LLVMContext &Ctx) { + Ctx.emitError(toString(std::move(Err))); +} + +mir2vec::MIRVocabulary +MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) { + if (StrVocabMap.empty()) { + if (Error Err = readVocabulary()) { + emitError(std::move(Err), M.getContext()); + return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr); + } + } + + // Get machine module info to access machine functions and target info + MachineModuleInfo &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI(); + + // Find first available machine function to get target instruction info + for (const auto &F : M) { + if (F.isDeclaration()) + continue; + + if (auto *MF = MMI.getMachineFunction(F)) { + const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo(); + return mir2vec::MIRVocabulary(std::move(StrVocabMap), TII); + } + } + + // No machine functions available - return invalid vocabulary + emitError(make_error<StringError>("No machine functions found in module", + inconvertibleErrorCode()), + M.getContext()); + return mir2vec::MIRVocabulary(std::move(StrVocabMap), nullptr); +} + +//===----------------------------------------------------------------------===// +// Printer Passes Implementation +//===----------------------------------------------------------------------===// + +char MIR2VecVocabPrinterLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab", + "MIR2Vec Vocabulary Printer Pass", false, true) +INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass) +INITIALIZE_PASS_END(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab", + "MIR2Vec Vocabulary Printer Pass", false, true) + +bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) { + return false; +} + +bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) { + auto &Analysis = getAnalysis<MIR2VecVocabLegacyAnalysis>(); + auto MIR2VecVocab = Analysis.getMIR2VecVocabulary(M); + + if (!MIR2VecVocab.isValid()) { + OS << "MIR2Vec Vocabulary Printer: Invalid vocabulary\n"; + return false; + } + + unsigned Pos = 0; + for (const auto &Entry : MIR2VecVocab) { + OS << "Key: " << MIR2VecVocab.getStringKey(Pos++) << ": "; + Entry.print(OS); + } + + return false; +} + +MachineFunctionPass * +llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) { + return new MIR2VecVocabPrinterLegacyPass(OS); +} diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp index 83bb1df..b5f8a61 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp @@ -3740,7 +3740,11 @@ bool DAGTypeLegalizer::SoftPromoteHalfOperand(SDNode *N, unsigned OpNo) { case ISD::STRICT_FP_TO_SINT: case ISD::STRICT_FP_TO_UINT: case ISD::FP_TO_SINT: - case ISD::FP_TO_UINT: Res = SoftPromoteHalfOp_FP_TO_XINT(N); break; + case ISD::FP_TO_UINT: + case ISD::LRINT: + case ISD::LLRINT: + Res = SoftPromoteHalfOp_Op0WithStrict(N); + break; case ISD::FP_TO_SINT_SAT: case ISD::FP_TO_UINT_SAT: Res = SoftPromoteHalfOp_FP_TO_XINT_SAT(N); break; @@ -3819,7 +3823,7 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) { return DAG.getNode(GetPromotionOpcode(SVT, RVT), SDLoc(N), RVT, Op); } -SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_TO_XINT(SDNode *N) { +SDValue DAGTypeLegalizer::SoftPromoteHalfOp_Op0WithStrict(SDNode *N) { EVT RVT = N->getValueType(0); bool IsStrict = N->isStrictFPOpcode(); SDValue Op = N->getOperand(IsStrict ? 1 : 0); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h index 586c341..d580ce0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -843,7 +843,7 @@ private: SDValue SoftPromoteHalfOp_FAKE_USE(SDNode *N, unsigned OpNo); SDValue SoftPromoteHalfOp_FCOPYSIGN(SDNode *N, unsigned OpNo); SDValue SoftPromoteHalfOp_FP_EXTEND(SDNode *N); - SDValue SoftPromoteHalfOp_FP_TO_XINT(SDNode *N); + SDValue SoftPromoteHalfOp_Op0WithStrict(SDNode *N); SDValue SoftPromoteHalfOp_FP_TO_XINT_SAT(SDNode *N); SDValue SoftPromoteHalfOp_SETCC(SDNode *N); SDValue SoftPromoteHalfOp_SELECT_CC(SDNode *N, unsigned OpNo); |