//===-- SPIRVAsmPrinter.cpp - SPIR-V LLVM assembly writer ------*- C++ -*--===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains a printer that converts from our internal representation // of machine-dependent LLVM code to the SPIR-V assembly language. // //===----------------------------------------------------------------------===// #include "MCTargetDesc/SPIRVInstPrinter.h" #include "SPIRV.h" #include "SPIRVInstrInfo.h" #include "SPIRVMCInstLower.h" #include "SPIRVModuleAnalysis.h" #include "SPIRVSubtarget.h" #include "SPIRVTargetMachine.h" #include "SPIRVUtils.h" #include "TargetInfo/SPIRVTargetInfo.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/AsmPrinter.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" #include "llvm/MC/MCAsmInfo.h" #include "llvm/MC/MCAssembler.h" #include "llvm/MC/MCInst.h" #include "llvm/MC/MCObjectStreamer.h" #include "llvm/MC/MCSPIRVObjectWriter.h" #include "llvm/MC/MCStreamer.h" #include "llvm/MC/MCSymbol.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; #define DEBUG_TYPE "asm-printer" namespace { class SPIRVAsmPrinter : public AsmPrinter { unsigned NLabels = 0; SmallPtrSet LabeledMBB; public: explicit SPIRVAsmPrinter(TargetMachine &TM, std::unique_ptr Streamer) : AsmPrinter(TM, std::move(Streamer), ID), ModuleSectionsEmitted(false), ST(nullptr), TII(nullptr), MAI(nullptr) {} static char ID; bool ModuleSectionsEmitted; const SPIRVSubtarget *ST; const SPIRVInstrInfo *TII; StringRef getPassName() const override { return "SPIRV Assembly Printer"; } void printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O); bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, const char *ExtraCode, raw_ostream &O) override; void outputMCInst(MCInst &Inst); void outputInstruction(const MachineInstr *MI); void outputModuleSection(SPIRV::ModuleSectionType MSType); void outputGlobalRequirements(); void outputEntryPoints(); void outputDebugSourceAndStrings(const Module &M); void outputOpExtInstImports(const Module &M); void outputOpMemoryModel(); void outputOpFunctionEnd(); void outputExtFuncDecls(); void outputExecutionModeFromMDNode(MCRegister Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM, unsigned ExpectMDOps, int64_t DefVal); void outputExecutionModeFromNumthreadsAttribute( const MCRegister &Reg, const Attribute &Attr, SPIRV::ExecutionMode::ExecutionMode EM); void outputExecutionModeFromEnableMaximalReconvergenceAttr( const MCRegister &Reg, const SPIRVSubtarget &ST); void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); void outputFPFastMathDefaultInfo(); bool isHidden() { return MF->getFunction() .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) .isValid(); } void emitInstruction(const MachineInstr *MI) override; void emitFunctionEntryLabel() override {} void emitFunctionHeader() override; void emitFunctionBodyStart() override {} void emitFunctionBodyEnd() override; void emitBasicBlockStart(const MachineBasicBlock &MBB) override; void emitBasicBlockEnd(const MachineBasicBlock &MBB) override {} void emitGlobalVariable(const GlobalVariable *GV) override {} void emitOpLabel(const MachineBasicBlock &MBB); void emitEndOfAsmFile(Module &M) override; bool doInitialization(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override; SPIRV::ModuleAnalysisInfo *MAI; protected: void cleanUp(Module &M); }; } // namespace void SPIRVAsmPrinter::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired(); AU.addPreserved(); AsmPrinter::getAnalysisUsage(AU); } // If the module has no functions, we need output global info anyway. void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) { if (ModuleSectionsEmitted == false) { outputModuleSections(); ModuleSectionsEmitted = true; } ST = static_cast(TM).getSubtargetImpl(); VersionTuple SPIRVVersion = ST->getSPIRVVersion(); uint32_t Major = SPIRVVersion.getMajor(); uint32_t Minor = SPIRVVersion.getMinor().value_or(0); // Bound is an approximation that accounts for the maximum used register // number and number of generated OpLabels unsigned Bound = 2 * (ST->getBound() + 1) + NLabels; if (MCAssembler *Asm = OutStreamer->getAssemblerPtr()) static_cast(Asm->getWriter()) .setBuildVersion(Major, Minor, Bound); cleanUp(M); } // Any cleanup actions with the Module after we don't care about its content // anymore. void SPIRVAsmPrinter::cleanUp(Module &M) { // Verifier disallows uses of intrinsic global variables. for (StringRef GVName : {"llvm.global_ctors", "llvm.global_dtors", "llvm.used"}) { if (GlobalVariable *GV = M.getNamedGlobal(GVName)) GV->setName(""); } } void SPIRVAsmPrinter::emitFunctionHeader() { if (ModuleSectionsEmitted == false) { outputModuleSections(); ModuleSectionsEmitted = true; } // Get the subtarget from the current MachineFunction. ST = &MF->getSubtarget(); TII = ST->getInstrInfo(); const Function &F = MF->getFunction(); if (isVerbose() && !isHidden()) { OutStreamer->getCommentOS() << "-- Begin function " << GlobalValue::dropLLVMManglingEscape(F.getName()) << '\n'; } auto Section = getObjFileLowering().SectionForGlobal(&F, TM); MF->setSection(Section); } void SPIRVAsmPrinter::outputOpFunctionEnd() { MCInst FunctionEndInst; FunctionEndInst.setOpcode(SPIRV::OpFunctionEnd); outputMCInst(FunctionEndInst); } void SPIRVAsmPrinter::emitFunctionBodyEnd() { if (!isHidden()) outputOpFunctionEnd(); } void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) { // Do not emit anything if it's an internal service function. if (isHidden()) return; MCInst LabelInst; LabelInst.setOpcode(SPIRV::OpLabel); LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB))); outputMCInst(LabelInst); ++NLabels; LabeledMBB.insert(&MBB); } void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) { // Do not emit anything if it's an internal service function. if (MBB.empty()) return; // If it's the first MBB in MF, it has OpFunction and OpFunctionParameter, so // OpLabel should be output after them. if (MBB.getNumber() == MF->front().getNumber()) { for (const MachineInstr &MI : MBB) if (MI.getOpcode() == SPIRV::OpFunction) return; // TODO: this case should be checked by the verifier. report_fatal_error("OpFunction is expected in the front MBB of MF"); } emitOpLabel(MBB); } void SPIRVAsmPrinter::printOperand(const MachineInstr *MI, int OpNum, raw_ostream &O) { const MachineOperand &MO = MI->getOperand(OpNum); switch (MO.getType()) { case MachineOperand::MO_Register: O << SPIRVInstPrinter::getRegisterName(MO.getReg()); break; case MachineOperand::MO_Immediate: O << MO.getImm(); break; case MachineOperand::MO_FPImmediate: O << MO.getFPImm(); break; case MachineOperand::MO_MachineBasicBlock: O << *MO.getMBB()->getSymbol(); break; case MachineOperand::MO_GlobalAddress: O << *getSymbol(MO.getGlobal()); break; case MachineOperand::MO_BlockAddress: { MCSymbol *BA = GetBlockAddressSymbol(MO.getBlockAddress()); O << BA->getName(); break; } case MachineOperand::MO_ExternalSymbol: O << *GetExternalSymbolSymbol(MO.getSymbolName()); break; case MachineOperand::MO_JumpTableIndex: case MachineOperand::MO_ConstantPoolIndex: default: llvm_unreachable(""); } } bool SPIRVAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, const char *ExtraCode, raw_ostream &O) { if (ExtraCode && ExtraCode[0]) return true; // Invalid instruction - SPIR-V does not have special modifiers printOperand(MI, OpNo, O); return false; } static bool isFuncOrHeaderInstr(const MachineInstr *MI, const SPIRVInstrInfo *TII) { return TII->isHeaderInstr(*MI) || MI->getOpcode() == SPIRV::OpFunction || MI->getOpcode() == SPIRV::OpFunctionParameter; } void SPIRVAsmPrinter::outputMCInst(MCInst &Inst) { OutStreamer->emitInstruction(Inst, *OutContext.getSubtargetInfo()); } void SPIRVAsmPrinter::outputInstruction(const MachineInstr *MI) { SPIRVMCInstLower MCInstLowering; MCInst TmpInst; MCInstLowering.lower(MI, TmpInst, MAI); outputMCInst(TmpInst); } void SPIRVAsmPrinter::emitInstruction(const MachineInstr *MI) { SPIRV_MC::verifyInstructionPredicates(MI->getOpcode(), getSubtargetInfo().getFeatureBits()); if (!MAI->getSkipEmission(MI)) outputInstruction(MI); // Output OpLabel after OpFunction and OpFunctionParameter in the first MBB. const MachineInstr *NextMI = MI->getNextNode(); if (!LabeledMBB.contains(MI->getParent()) && isFuncOrHeaderInstr(MI, TII) && (!NextMI || !isFuncOrHeaderInstr(NextMI, TII))) { assert(MI->getParent()->getNumber() == MF->front().getNumber() && "OpFunction is not in the front MBB of MF"); emitOpLabel(*MI->getParent()); } } void SPIRVAsmPrinter::outputModuleSection(SPIRV::ModuleSectionType MSType) { for (const MachineInstr *MI : MAI->getMSInstrs(MSType)) outputInstruction(MI); } void SPIRVAsmPrinter::outputDebugSourceAndStrings(const Module &M) { // Output OpSourceExtensions. for (auto &Str : MAI->SrcExt) { MCInst Inst; Inst.setOpcode(SPIRV::OpSourceExtension); addStringImm(Str.first(), Inst); outputMCInst(Inst); } // Output OpString. outputModuleSection(SPIRV::MB_DebugStrings); // Output OpSource. MCInst Inst; Inst.setOpcode(SPIRV::OpSource); Inst.addOperand(MCOperand::createImm(static_cast(MAI->SrcLang))); Inst.addOperand( MCOperand::createImm(static_cast(MAI->SrcLangVersion))); outputMCInst(Inst); } void SPIRVAsmPrinter::outputOpExtInstImports(const Module &M) { for (auto &CU : MAI->ExtInstSetMap) { unsigned Set = CU.first; MCRegister Reg = CU.second; MCInst Inst; Inst.setOpcode(SPIRV::OpExtInstImport); Inst.addOperand(MCOperand::createReg(Reg)); addStringImm(getExtInstSetName( static_cast(Set)), Inst); outputMCInst(Inst); } } void SPIRVAsmPrinter::outputOpMemoryModel() { MCInst Inst; Inst.setOpcode(SPIRV::OpMemoryModel); Inst.addOperand(MCOperand::createImm(static_cast(MAI->Addr))); Inst.addOperand(MCOperand::createImm(static_cast(MAI->Mem))); outputMCInst(Inst); } // Before the OpEntryPoints' output, we need to add the entry point's // interfaces. The interface is a list of IDs of global OpVariable instructions. // These declare the set of global variables from a module that form // the interface of this entry point. void SPIRVAsmPrinter::outputEntryPoints() { // Find all OpVariable IDs with required StorageClass. DenseSet InterfaceIDs; for (const MachineInstr *MI : MAI->GlobalVarList) { assert(MI->getOpcode() == SPIRV::OpVariable); auto SC = static_cast( MI->getOperand(2).getImm()); // Before version 1.4, the interface's storage classes are limited to // the Input and Output storage classes. Starting with version 1.4, // the interface's storage classes are all storage classes used in // declaring all global variables referenced by the entry point call tree. if (ST->isAtLeastSPIRVVer(VersionTuple(1, 4)) || SC == SPIRV::StorageClass::Input || SC == SPIRV::StorageClass::Output) { const MachineFunction *MF = MI->getMF(); MCRegister Reg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); InterfaceIDs.insert(Reg); } } // Output OpEntryPoints adding interface args to all of them. for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_EntryPoints)) { SPIRVMCInstLower MCInstLowering; MCInst TmpInst; MCInstLowering.lower(MI, TmpInst, MAI); for (MCRegister Reg : InterfaceIDs) { assert(Reg.isValid()); TmpInst.addOperand(MCOperand::createReg(Reg)); } outputMCInst(TmpInst); } } // Create global OpCapability instructions for the required capabilities. void SPIRVAsmPrinter::outputGlobalRequirements() { // Abort here if not all requirements can be satisfied. MAI->Reqs.checkSatisfiable(*ST); for (const auto &Cap : MAI->Reqs.getMinimalCapabilities()) { MCInst Inst; Inst.setOpcode(SPIRV::OpCapability); Inst.addOperand(MCOperand::createImm(Cap)); outputMCInst(Inst); } // Generate the final OpExtensions with strings instead of enums. for (const auto &Ext : MAI->Reqs.getExtensions()) { MCInst Inst; Inst.setOpcode(SPIRV::OpExtension); addStringImm(getSymbolicOperandMnemonic( SPIRV::OperandCategory::ExtensionOperand, Ext), Inst); outputMCInst(Inst); } // TODO add a pseudo instr for version number. } void SPIRVAsmPrinter::outputExtFuncDecls() { // Insert OpFunctionEnd after each declaration. auto I = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).begin(), E = MAI->getMSInstrs(SPIRV::MB_ExtFuncDecls).end(); for (; I != E; ++I) { outputInstruction(*I); if ((I + 1) == E || (*(I + 1))->getOpcode() == SPIRV::OpFunction) outputOpFunctionEnd(); } } // Encode LLVM type by SPIR-V execution mode VecTypeHint. static unsigned encodeVecTypeHint(Type *Ty) { if (Ty->isHalfTy()) return 4; if (Ty->isFloatTy()) return 5; if (Ty->isDoubleTy()) return 6; if (IntegerType *IntTy = dyn_cast(Ty)) { switch (IntTy->getIntegerBitWidth()) { case 8: return 0; case 16: return 1; case 32: return 2; case 64: return 3; default: llvm_unreachable("invalid integer type"); } } if (FixedVectorType *VecTy = dyn_cast(Ty)) { Type *EleTy = VecTy->getElementType(); unsigned Size = VecTy->getNumElements(); return Size << 16 | encodeVecTypeHint(EleTy); } llvm_unreachable("invalid type"); } static void addOpsFromMDNode(MDNode *MDN, MCInst &Inst, SPIRV::ModuleAnalysisInfo *MAI) { for (const MDOperand &MDOp : MDN->operands()) { if (auto *CMeta = dyn_cast(MDOp)) { Constant *C = CMeta->getValue(); if (ConstantInt *Const = dyn_cast(C)) { Inst.addOperand(MCOperand::createImm(Const->getZExtValue())); } else if (auto *CE = dyn_cast(C)) { MCRegister FuncReg = MAI->getFuncReg(CE); assert(FuncReg.isValid()); Inst.addOperand(MCOperand::createReg(FuncReg)); } } } } void SPIRVAsmPrinter::outputExecutionModeFromMDNode( MCRegister Reg, MDNode *Node, SPIRV::ExecutionMode::ExecutionMode EM, unsigned ExpectMDOps, int64_t DefVal) { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); Inst.addOperand(MCOperand::createReg(Reg)); Inst.addOperand(MCOperand::createImm(static_cast(EM))); addOpsFromMDNode(Node, Inst, MAI); // reqd_work_group_size and work_group_size_hint require 3 operands, // if metadata contains less operands, just add a default value unsigned NodeSz = Node->getNumOperands(); if (ExpectMDOps > 0 && NodeSz < ExpectMDOps) for (unsigned i = NodeSz; i < ExpectMDOps; ++i) Inst.addOperand(MCOperand::createImm(DefVal)); outputMCInst(Inst); } void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute( const MCRegister &Reg, const Attribute &Attr, SPIRV::ExecutionMode::ExecutionMode EM) { assert(Attr.isValid() && "Function called with an invalid attribute."); MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); Inst.addOperand(MCOperand::createReg(Reg)); Inst.addOperand(MCOperand::createImm(static_cast(EM))); SmallVector NumThreads; Attr.getValueAsString().split(NumThreads, ','); assert(NumThreads.size() == 3 && "invalid numthreads"); for (uint32_t i = 0; i < 3; ++i) { uint32_t V; [[maybe_unused]] bool Result = NumThreads[i].getAsInteger(10, V); assert(!Result && "Failed to parse numthreads"); Inst.addOperand(MCOperand::createImm(V)); } outputMCInst(Inst); } void SPIRVAsmPrinter::outputExecutionModeFromEnableMaximalReconvergenceAttr( const MCRegister &Reg, const SPIRVSubtarget &ST) { assert(ST.canUseExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence) && "Function called when SPV_KHR_maximal_reconvergence is not enabled."); MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); Inst.addOperand(MCOperand::createReg(Reg)); unsigned EM = static_cast(SPIRV::ExecutionMode::MaximallyReconvergesKHR); Inst.addOperand(MCOperand::createImm(EM)); outputMCInst(Inst); } void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { for (unsigned i = 0; i < Node->getNumOperands(); i++) { // If SPV_KHR_float_controls2 is enabled and we find any of // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution // modes, skip it, it'll be done somewhere else. if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { const auto EM = cast( cast((Node->getOperand(i))->getOperand(1)) ->getValue()) ->getZExtValue(); if (EM == SPIRV::ExecutionMode::FPFastMathDefault || EM == SPIRV::ExecutionMode::ContractionOff || EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) continue; } MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); addOpsFromMDNode(cast(Node->getOperand(i)), Inst, MAI); outputMCInst(Inst); } outputFPFastMathDefaultInfo(); } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; // Only operands of OpEntryPoint instructions are allowed to be // operands of OpExecutionMode if (F.isDeclaration() || !isEntryPoint(F)) continue; MCRegister FReg = MAI->getFuncReg(&F); assert(FReg.isValid()); if (Attribute Attr = F.getFnAttribute("hlsl.shader"); Attr.isValid()) { // SPIR-V common validation: Fragment requires OriginUpperLeft or // OriginLowerLeft. // VUID-StandaloneSpirv-OriginLowerLeft-04653: Fragment must declare // OriginUpperLeft. if (Attr.getValueAsString() == "pixel") { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); Inst.addOperand(MCOperand::createReg(FReg)); unsigned EM = static_cast(SPIRV::ExecutionMode::OriginUpperLeft); Inst.addOperand(MCOperand::createImm(EM)); outputMCInst(Inst); } } if (MDNode *Node = F.getMetadata("reqd_work_group_size")) outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSize, 3, 1); if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid()) outputExecutionModeFromNumthreadsAttribute( FReg, Attr, SPIRV::ExecutionMode::LocalSize); if (Attribute Attr = F.getFnAttribute("enable-maximal-reconvergence"); Attr.getValueAsBool()) { outputExecutionModeFromEnableMaximalReconvergenceAttr(FReg, *ST); } if (MDNode *Node = F.getMetadata("work_group_size_hint")) outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSizeHint, 3, 1); if (MDNode *Node = F.getMetadata("intel_reqd_sub_group_size")) outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::SubgroupSize, 0, 0); if (MDNode *Node = F.getMetadata("max_work_group_size")) { if (ST->canUseExtension(SPIRV::Extension::SPV_INTEL_kernel_attributes)) outputExecutionModeFromMDNode( FReg, Node, SPIRV::ExecutionMode::MaxWorkgroupSizeINTEL, 3, 1); } if (MDNode *Node = F.getMetadata("vec_type_hint")) { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); Inst.addOperand(MCOperand::createReg(FReg)); unsigned EM = static_cast(SPIRV::ExecutionMode::VecTypeHint); Inst.addOperand(MCOperand::createImm(EM)); unsigned TypeCode = encodeVecTypeHint(getMDOperandAsType(Node, 0)); Inst.addOperand(MCOperand::createImm(TypeCode)); outputMCInst(Inst); } if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") && !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { // When SPV_KHR_float_controls2 is enabled, ContractionOff is // deprecated. We need to use FPFastMathDefault with the appropriate // flags instead. Since FPFastMathDefault takes a target type, we need // to emit it for each floating-point type that exists in the module // to match the effect of ContractionOff. As of now, there are 3 FP // types: fp16, fp32 and fp64. // We only end up here because there is no "spirv.ExecutionMode" // metadata, so that means no FPFastMathDefault. Therefore, we only // need to make sure AllowContract is set to 0, as the rest of flags. // We still need to emit the OpExecutionMode instruction, otherwise // it's up to the client API to define the flags. Therefore, we need // to find the constant with 0 value. // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of // type int32 with 0 value to represent the FP Fast Math Mode. std::vector SPIRVFloatTypes; const MachineInstr *ConstZeroInt32 = nullptr; for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { unsigned OpCode = MI->getOpcode(); // Collect the SPIRV type if it's a float. if (OpCode == SPIRV::OpTypeFloat) { // Skip if the target type is not fp16, fp32, fp64. const unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 && OpTypeFloatSize != 64) { continue; } SPIRVFloatTypes.push_back(MI); continue; } if (OpCode == SPIRV::OpConstantNull) { // Check if the constant is int32, if not skip it. const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); bool IsInt32Ty = TypeMI && TypeMI->getOpcode() == SPIRV::OpTypeInt && TypeMI->getOperand(1).getImm() == 32; if (IsInt32Ty) ConstZeroInt32 = MI; } } // When SPV_KHR_float_controls2 is enabled, ContractionOff is // deprecated. We need to use FPFastMathDefault with the appropriate // flags instead. Since FPFastMathDefault takes a target type, we need // to emit it for each floating-point type that exists in the module // to match the effect of ContractionOff. As of now, there are 3 FP // types: fp16, fp32 and fp64. for (const MachineInstr *MI : SPIRVFloatTypes) { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionModeId); Inst.addOperand(MCOperand::createReg(FReg)); unsigned EM = static_cast(SPIRV::ExecutionMode::FPFastMathDefault); Inst.addOperand(MCOperand::createImm(EM)); const MachineFunction *MF = MI->getMF(); MCRegister TypeReg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(TypeReg)); assert(ConstZeroInt32 && "There should be a constant zero."); MCRegister ConstReg = MAI->getRegisterAlias( ConstZeroInt32->getMF(), ConstZeroInt32->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(ConstReg)); outputMCInst(Inst); } } else { MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); Inst.addOperand(MCOperand::createReg(FReg)); unsigned EM = static_cast(SPIRV::ExecutionMode::ContractionOff); Inst.addOperand(MCOperand::createImm(EM)); outputMCInst(Inst); } } } } void SPIRVAsmPrinter::outputAnnotations(const Module &M) { outputModuleSection(SPIRV::MB_Annotations); // Process llvm.global.annotations special global variable. for (auto F = M.global_begin(), E = M.global_end(); F != E; ++F) { if ((*F).getName() != "llvm.global.annotations") continue; const GlobalVariable *V = &(*F); const ConstantArray *CA = cast(V->getOperand(0)); for (Value *Op : CA->operands()) { ConstantStruct *CS = cast(Op); // The first field of the struct contains a pointer to // the annotated variable. Value *AnnotatedVar = CS->getOperand(0)->stripPointerCasts(); if (!isa(AnnotatedVar)) report_fatal_error("Unsupported value in llvm.global.annotations"); Function *Func = cast(AnnotatedVar); MCRegister Reg = MAI->getFuncReg(Func); if (!Reg.isValid()) { std::string DiagMsg; raw_string_ostream OS(DiagMsg); AnnotatedVar->print(OS); DiagMsg = "Unknown function in llvm.global.annotations: " + DiagMsg; report_fatal_error(DiagMsg.c_str()); } // The second field contains a pointer to a global annotation string. GlobalVariable *GV = cast(CS->getOperand(1)->stripPointerCasts()); StringRef AnnotationString; [[maybe_unused]] bool Success = getConstantStringInfo(GV, AnnotationString); assert(Success && "Failed to get annotation string"); MCInst Inst; Inst.setOpcode(SPIRV::OpDecorate); Inst.addOperand(MCOperand::createReg(Reg)); unsigned Dec = static_cast(SPIRV::Decoration::UserSemantic); Inst.addOperand(MCOperand::createImm(Dec)); addStringImm(AnnotationString, Inst); outputMCInst(Inst); } } } void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() { // Collect the SPIRVTypes that are OpTypeFloat and the constants of type // int32, that might be used as FP Fast Math Mode. std::vector SPIRVFloatTypes; // Hashtable to associate immediate values with the constant holding them. std::unordered_map ConstMap; for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { // Skip if the instruction is not OpTypeFloat or OpConstant. unsigned OpCode = MI->getOpcode(); if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI && OpCode != SPIRV::OpConstantNull) continue; // Collect the SPIRV type if it's a float. if (OpCode == SPIRV::OpTypeFloat) { SPIRVFloatTypes.push_back(MI); } else { // Check if the constant is int32, if not skip it. const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt || TypeMI->getOperand(1).getImm() != 32) continue; if (OpCode == SPIRV::OpConstantI) ConstMap[MI->getOperand(2).getImm()] = MI; else ConstMap[0] = MI; } } for (const auto &[Func, FPFastMathDefaultInfoVec] : MAI->FPFastMathDefaultInfoMap) { if (FPFastMathDefaultInfoVec.empty()) continue; for (const MachineInstr *MI : SPIRVFloatTypes) { unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); unsigned Index = SPIRV::FPFastMathDefaultInfoVector:: computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize); assert(Index < FPFastMathDefaultInfoVec.size() && "Index out of bounds for FPFastMathDefaultInfoVec"); const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index]; assert(FPFastMathDefaultInfo.Ty && "Expected target type for FPFastMathDefaultInfo"); assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() == OpTypeFloatSize && "Mismatched float type size"); MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionModeId); MCRegister FuncReg = MAI->getFuncReg(Func); assert(FuncReg.isValid()); Inst.addOperand(MCOperand::createReg(FuncReg)); Inst.addOperand( MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault)); MCRegister TypeReg = MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(TypeReg)); unsigned Flags = FPFastMathDefaultInfo.FastMathFlags; if (FPFastMathDefaultInfo.ContractionOff && (Flags & SPIRV::FPFastMathMode::AllowContract)) report_fatal_error( "Conflicting FPFastMathFlags: ContractionOff and AllowContract"); if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve && !(Flags & (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | SPIRV::FPFastMathMode::NSZ))) { if (FPFastMathDefaultInfo.FPFastMathDefault) report_fatal_error("Conflicting FPFastMathFlags: " "SignedZeroInfNanPreserve but at least one of " "NotNaN/NotInf/NSZ is enabled."); } // Don't emit if none of the execution modes was used. if (Flags == SPIRV::FPFastMathMode::None && !FPFastMathDefaultInfo.ContractionOff && !FPFastMathDefaultInfo.SignedZeroInfNanPreserve && !FPFastMathDefaultInfo.FPFastMathDefault) continue; // Retrieve the constant instruction for the immediate value. auto It = ConstMap.find(Flags); if (It == ConstMap.end()) report_fatal_error("Expected constant instruction for FP Fast Math " "Mode operand of FPFastMathDefault execution mode."); const MachineInstr *ConstMI = It->second; MCRegister ConstReg = MAI->getRegisterAlias( ConstMI->getMF(), ConstMI->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(ConstReg)); outputMCInst(Inst); } } } void SPIRVAsmPrinter::outputModuleSections() { const Module *M = MMI->getModule(); // Get the global subtarget to output module-level info. ST = static_cast(TM).getSubtargetImpl(); TII = ST->getInstrInfo(); MAI = &SPIRVModuleAnalysis::MAI; assert(ST && TII && MAI && M && "Module analysis is required"); // Output instructions according to the Logical Layout of a Module: // 1,2. All OpCapability instructions, then optional OpExtension // instructions. outputGlobalRequirements(); // 3. Optional OpExtInstImport instructions. outputOpExtInstImports(*M); // 4. The single required OpMemoryModel instruction. outputOpMemoryModel(); // 5. All entry point declarations, using OpEntryPoint. outputEntryPoints(); // 6. Execution-mode declarations, using OpExecutionMode or // OpExecutionModeId. outputExecutionMode(*M); // 7a. Debug: all OpString, OpSourceExtension, OpSource, and // OpSourceContinued, without forward references. outputDebugSourceAndStrings(*M); // 7b. Debug: all OpName and all OpMemberName. outputModuleSection(SPIRV::MB_DebugNames); // 7c. Debug: all OpModuleProcessed instructions. outputModuleSection(SPIRV::MB_DebugModuleProcessed); // xxx. SPV_INTEL_memory_access_aliasing instructions go before 8. // "All annotation instructions" outputModuleSection(SPIRV::MB_AliasingInsts); // 8. All annotation instructions (all decorations). outputAnnotations(*M); // 9. All type declarations (OpTypeXXX instructions), all constant // instructions, and all global variable declarations. This section is // the first section to allow use of: OpLine and OpNoLine debug information; // non-semantic instructions with OpExtInst. outputModuleSection(SPIRV::MB_TypeConstVars); // 10. All global NonSemantic.Shader.DebugInfo.100 instructions. outputModuleSection(SPIRV::MB_NonSemanticGlobalDI); // 11. All function declarations (functions without a body). outputExtFuncDecls(); // 12. All function definitions (functions with a body). // This is done in regular function output. } bool SPIRVAsmPrinter::doInitialization(Module &M) { ModuleSectionsEmitted = false; // We need to call the parent's one explicitly. return AsmPrinter::doInitialization(M); } char SPIRVAsmPrinter::ID = 0; INITIALIZE_PASS(SPIRVAsmPrinter, "spirv-asm-printer", "SPIRV Assembly Printer", false, false) // Force static initialization. extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVAsmPrinter() { RegisterAsmPrinter X(getTheSPIRV32Target()); RegisterAsmPrinter Y(getTheSPIRV64Target()); RegisterAsmPrinter Z(getTheSPIRVLogicalTarget()); }