//===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Emits the module-level DXIL metadata:
//   dx.valver, dx.shaderModel, dx.version, dx.resources, dx.entryPoints.
// It also rewrites per-terminator `hlsl.controlflow.hint` metadata into
// `dx.controlflow.hints`.
//
// NOTE(review): In this copy of the file, every angle-bracketed template
// argument list appears to be missing. Examples: `SmallVector SRVs`,
// `static_cast(Tag)`, `const_cast(EntryFn)`, `mdconst::extract(...)`,
// `MAM.getResult(M)`, `getAnalysis()`, `AU.addRequired()`. The last
// `#include` below also has no header name. As written, those lines are not
// valid C++. Restore the original `<...>` arguments from version control
// before building.
//
//===----------------------------------------------------------------------===//

#include "DXILTranslateMetadata.h"
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/VersionTuple.h"
#include "llvm/TargetParser/Triple.h"
// NOTE(review): the header name of the following #include is missing in this
// copy; restore it from version control.
#include

using namespace llvm;
using namespace llvm::dxil;

namespace {
/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
/// for TranslateMetadata pass.
/// It is reported with kind DK_Unsupported and severity DS_Error by default.
class DiagnosticInfoTranslateMD : public DiagnosticInfo {
private:
  // Held by reference, not copied; see the constructor comment.
  const Twine &Msg;
  const Module &Mod;

public:
  /// \p M is the module for which the diagnostic is being emitted. \p Msg is
  /// the message to show. Note that this class does not copy this message, so
  /// this reference must be valid for the whole life time of the diagnostic.
  /// (Callers in this file build the diagnostic inside the same
  /// full-expression as the `diagnose()` call, so Twine temporaries live
  /// long enough.)
  DiagnosticInfoTranslateMD(const Module &M,
                            const Twine &Msg LLVM_LIFETIME_BOUND,
                            DiagnosticSeverity Severity = DS_Error)
      : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}

  /// Prints "<module name>: <message>" followed by a newline.
  void print(DiagnosticPrinter &DP) const override {
    DP << Mod.getName() << ": " << Msg << '\n';
  }
};

/// Tags for the tag-value pairs in an entry's property list.
/// The enumerator's integer value is emitted as the i32 tag (see
/// getTagValueAsMetadata), so declaration order is significant.
/// NOTE(review): presumably mirrors the DXIL entry-property tag numbering;
/// confirm against the DXIL specification before reordering or inserting.
enum class EntryPropsTag {
  ShaderFlags = 0,
  GSState,
  DSState,
  HSState,
  NumThreads,
  AutoBindingSpace,
  RayPayloadSize,
  RayAttribSize,
  ShaderKind,
  MSState,
  ASStateTag,
  WaveSize,
  EntryRootSig,
};

} // namespace

/// Builds the `dx.resources` named metadata from the resource map.
///
/// First, every resource without a symbol gets one created, using an element
/// struct from the type map. Then the SRV, UAV, CBuffer and Sampler lists are
/// built. A class with no resources is represented by a null operand.
/// The tuple !{SRVs, UAVs, CBuffers, Samplers} is appended as an operand of
/// `dx.resources`.
///
/// \returns the `dx.resources` node, or nullptr (with nothing added) when
/// \p DRM is empty.
static NamedMDNode *emitResourceMetadata(Module &M, DXILResourceMap &DRM,
                                         DXILResourceTypeMap &DRTM) {
  LLVMContext &Context = M.getContext();

  for (ResourceInfo &RI : DRM)
    if (!RI.hasSymbol())
      RI.createSymbol(M,
                      DRTM[RI.getHandleTy()].createElementStruct(RI.getName()));

  SmallVector SRVs, UAVs, CBufs, Smps;
  for (const ResourceInfo &RI : DRM.srvs())
    SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
  for (const ResourceInfo &RI : DRM.uavs())
    UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
  for (const ResourceInfo &RI : DRM.cbuffers())
    CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));
  for (const ResourceInfo &RI : DRM.samplers())
    Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()]));

  Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs);
  Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs);
  Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs);
  Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps);

  // NOTE(review): when DRM is empty, the loops above iterate nothing, so this
  // late check is harmless. It could be hoisted to the top of the function.
  if (DRM.empty())
    return nullptr;

  NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources");
  // NOTE(review): `Context` already holds M.getContext().
  ResourceMD->addOperand(
      MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD}));

  return ResourceMD;
}

/// Returns the short profile prefix for a shader stage ("ps", "vs", "gs",
/// "hs", "ds", "cs", "lib", "ms", "as").
/// Any other environment reaches llvm_unreachable.
static StringRef getShortShaderStage(Triple::EnvironmentType Env) {
  switch (Env) {
  case Triple::Pixel:
    return "ps";
  case Triple::Vertex:
    return "vs";
  case Triple::Geometry:
    return "gs";
  case Triple::Hull:
    return "hs";
  case Triple::Domain:
    return "ds";
  case Triple::Compute:
    return "cs";
  case Triple::Library:
    return "lib";
  case Triple::Mesh:
    return "ms";
  case Triple::Amplification:
    return "as";
  default:
    break;
  }
  llvm_unreachable("Unsupported environment for DXIL generation.");
}

/// Maps a Triple environment to a numeric shader kind by offsetting it from
/// Triple::Pixel.
/// NOTE(review): relies on the Triple environment enumerators from Pixel
/// onward being contiguous and in DXIL shader-kind order. Confirm in
/// Triple.h.
static uint32_t getShaderStage(Triple::EnvironmentType Env) {
  return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel;
}

/// Returns the two metadata operands {i32 Tag, Value} for one tag-value pair.
/// ShaderFlags values are emitted as i64 and ShaderKind values as i32.
/// Every other tag is not yet implemented and reaches llvm_unreachable.
static SmallVector
getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
  SmallVector MDVals;
  MDVals.emplace_back(ConstantAsMetadata::get(
      ConstantInt::get(Type::getInt32Ty(Ctx), static_cast(Tag))));
  switch (Tag) {
  case EntryPropsTag::ShaderFlags:
    MDVals.emplace_back(ConstantAsMetadata::get(
        ConstantInt::get(Type::getInt64Ty(Ctx), Value)));
    break;
  case EntryPropsTag::ShaderKind:
    MDVals.emplace_back(ConstantAsMetadata::get(
        ConstantInt::get(Type::getInt32Ty(Ctx), Value)));
    break;
  case EntryPropsTag::GSState:
  case EntryPropsTag::DSState:
  case EntryPropsTag::HSState:
  case EntryPropsTag::NumThreads:
  case EntryPropsTag::AutoBindingSpace:
  case EntryPropsTag::RayPayloadSize:
  case EntryPropsTag::RayAttribSize:
  case EntryPropsTag::MSState:
  case EntryPropsTag::ASStateTag:
  case EntryPropsTag::WaveSize:
  case EntryPropsTag::EntryRootSig:
    llvm_unreachable("NYI: Unhandled entry property tag");
  }
  return MDVals;
}

/// Builds the tag-value property list for one entry point. The list contains:
///  * ShaderFlags, when \p EntryShaderFlags is non-zero;
///  * ShaderKind, only when compiling a library and the entry's stage is not
///    Library;
///  * NumThreads {X, Y, Z}, for compute entries.
/// \returns nullptr when no property was added.
static MDTuple *
getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
                       const Triple::EnvironmentType ShaderProfile) {
  SmallVector MDVals;
  // NOTE(review): EP.Entry is dereferenced here, before the
  // `EP.Entry != nullptr` check below. That check therefore cannot protect
  // anything: either Entry is never null and the check is redundant, or this
  // line is a null dereference.
  LLVMContext &Ctx = EP.Entry->getContext();
  if (EntryShaderFlags != 0)
    MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags,
                                        EntryShaderFlags, Ctx));

  if (EP.Entry != nullptr) {
    // FIXME: support more props.
    // See https://github.com/llvm/llvm-project/issues/57948.
    // Add shader kind for lib entries.
    if (ShaderProfile == Triple::EnvironmentType::Library &&
        EP.ShaderStage != Triple::EnvironmentType::Library)
      MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
                                          getShaderStage(EP.ShaderStage), Ctx));

    if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
      // NumThreads is emitted inline as {i32 tag, !{X, Y, Z}}.
      // getTagValueAsMetadata treats NumThreads as unimplemented.
      MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
          Type::getInt32Ty(Ctx), static_cast(EntryPropsTag::NumThreads))));
      Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
                                       Type::getInt32Ty(Ctx), EP.NumThreadsX)),
                                   ConstantAsMetadata::get(ConstantInt::get(
                                       Type::getInt32Ty(Ctx), EP.NumThreadsY)),
                                   ConstantAsMetadata::get(ConstantInt::get(
                                       Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
      MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
    }
  }
  if (MDVals.empty())
    return nullptr;
  return MDNode::get(Ctx, MDVals);
}

/// Assembles one 5-operand entry record:
///   !{function, name, signatures, resources, properties}
/// \p EntryFn may be null, as for the library top-level node. In that case
/// the first operand is null and the name is the empty string.
/// NOTE(review): unlike its sibling helpers, this function is not `static`,
/// so it has external linkage. Confirm whether another translation unit uses
/// it.
MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures,
                                MDNode *Resources, MDTuple *Properties,
                                LLVMContext &Ctx) {
  // Each entry point metadata record specifies:
  //  * reference to the entry point function global symbol
  //  * unmangled name
  //  * list of signatures
  //  * list of resources
  //  * list of tag-value pairs of shader capabilities and other properties
  Metadata *MDVals[5];
  MDVals[0] =
      EntryFn ? ValueAsMetadata::get(const_cast(EntryFn)) : nullptr;
  MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : "");
  MDVals[2] = Signatures;
  MDVals[3] = Resources;
  MDVals[4] = Properties;
  return MDNode::get(Ctx, MDVals);
}

/// Builds the entry record for one entry point: its property list plus the
/// given signatures and resources. Requires EP.Entry to be non-null, since its
/// context is taken directly.
static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
                            MDNode *MDResources,
                            const uint64_t EntryShaderFlags,
                            const Triple::EnvironmentType ShaderProfile) {
  MDTuple *Properties =
      getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
  return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
                                EP.Entry->getContext());
}

/// Writes `dx.valver` as !{i32 major, i32 minor}, with the minor version
/// defaulting to 0. Any existing operands are replaced. Does nothing when no
/// validator version is set.
static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) {
  if (MMDI.ValidatorVersion.empty())
    return;

  LLVMContext &Ctx = M.getContext();
  IRBuilder<> IRB(Ctx);
  Metadata *MDVals[2];
  MDVals[0] =
      ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor()));
  MDVals[1] = ConstantAsMetadata::get(
      IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0)));
  NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver");
  // Set validator version obtained from DXIL Metadata Analysis pass
  ValVerNode->clearOperands();
  ValVerNode->addOperand(MDNode::get(Ctx, MDVals));
}

/// Appends !{"<stage>", i32 major, i32 minor} to `dx.shaderModel`, with the
/// minor version defaulting to 0.
/// NOTE(review): unlike emitValidatorVersionMD, this does not clear existing
/// operands, so running the pass twice appends a second tuple.
static void emitShaderModelVersionMD(Module &M,
                                     const ModuleMetadataInfo &MMDI) {
  LLVMContext &Ctx = M.getContext();
  IRBuilder<> IRB(Ctx);
  Metadata *SMVals[3];
  VersionTuple SM = MMDI.ShaderModelVersion;
  SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile));
  SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor()));
  SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0)));
  NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel");
  SMMDNode->addOperand(MDNode::get(Ctx, SMVals));
}

/// Appends !{i32 major, i32 minor} for the DXIL version to `dx.version`, with
/// the minor version defaulting to 0.
/// NOTE(review): as with dx.shaderModel, existing operands are not cleared.
static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) {
  LLVMContext &Ctx = M.getContext();
  IRBuilder<> IRB(Ctx);
  VersionTuple DXILVer = MMDI.DXILVersion;
  Metadata *DXILVals[2];
  DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor()));
  DXILVals[1] =
      ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0)));
  NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version");
  DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals));
}

/// Builds the top-level entry record used for library profiles. It has no
/// function, name or signatures. It carries the resource table \p RMD and,
/// when \p ShaderFlags is non-zero, a ShaderFlags property.
static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
                                        uint64_t ShaderFlags) {
  LLVMContext &Ctx = M.getContext();
  MDTuple *Properties = nullptr;
  if (ShaderFlags != 0) {
    SmallVector MDVals;
    MDVals.append(
        getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx));
    Properties = MDNode::get(Ctx, MDVals);
  }
  // Library has an entry metadata with resource table metadata and all other
  // MDNodes as null.
  return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
}

/// Rewrites `hlsl.controlflow.hint` on block terminators into
/// `dx.controlflow.hints`.
/// For each terminator that carries the HLSL hint (expected to have exactly
/// two operands), it:
///  * takes operand 1 as a ConstantInt;
///  * attaches !{!"dx.controlflow.hints", <that constant>};
///  * removes the HLSL hint.
/// Operand 0 of the HLSL hint is not used.
// TODO: We might need to refactor this to be more generic,
// in case we need more metadata to be replaced.
static void translateBranchMetadata(Module &M) {
  for (Function &F : M) {
    for (BasicBlock &BB : F) {
      // NOTE(review): the result is dereferenced without a null check; this
      // assumes every block is well-formed (ends in a terminator).
      Instruction *BBTerminatorInst = BB.getTerminator();

      MDNode *HlslControlFlowMD =
          BBTerminatorInst->getMetadata("hlsl.controlflow.hint");

      if (!HlslControlFlowMD)
        continue;

      assert(HlslControlFlowMD->getNumOperands() == 2 &&
             "invalid operands for hlsl.controlflow.hint");

      MDBuilder MDHelper(M.getContext());
      ConstantInt *Op1 =
          mdconst::extract(HlslControlFlowMD->getOperand(1));

      SmallVector Vals(
          ArrayRef{MDHelper.createString("dx.controlflow.hints"),
                   MDHelper.createConstant(Op1)});

      // The local below is named `MDNode` and shadows the type, hence the
      // explicit `llvm::` qualification on the right-hand side.
      MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals);

      BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
      BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
    }
  }
}

/// Emits all module-level DXIL metadata, in this order:
///   dx.valver, dx.shaderModel, dx.version, dx.resources, dx.entryPoints.
///
/// Library profile:
///   * a top-level node carrying the combined shader flags is emitted first;
///   * each entry record then omits per-entry flags.
///
/// Non-library profile:
///   * per-entry flags are emitted on each entry record;
///   * an error is diagnosed when more than one entry exists, or when an
///     entry's stage differs from the target profile. Emission still
///     continues for all entries after either error.
static void translateMetadata(Module &M, DXILResourceMap &DRM,
                              DXILResourceTypeMap &DRTM,
                              const ModuleShaderFlags &ShaderFlags,
                              const ModuleMetadataInfo &MMDI) {
  // NOTE(review): Ctx is used only to construct IRB, and IRB is never used in
  // this function; diagnostics below go through M.getContext(). Both locals
  // can be removed.
  LLVMContext &Ctx = M.getContext();
  IRBuilder<> IRB(Ctx);
  SmallVector EntryFnMDNodes;

  emitValidatorVersionMD(M, MMDI);
  emitShaderModelVersionMD(M, MMDI);
  emitDXILVersionTupleMD(M, MMDI);
  NamedMDNode *NamedResourceMD = emitResourceMetadata(M, DRM, DRTM);
  auto *ResourceMD =
      (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr;
  // FIXME: Add support to construct Signatures
  // See https://github.com/llvm/llvm-project/issues/57928
  MDTuple *Signatures = nullptr;

  if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) {
    // Get the combined shader flag mask of all functions in the library to be
    // used as shader flags mask value associated with top-level library entry
    // metadata.
    uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
    EntryFnMDNodes.emplace_back(
        emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
  } else if (MMDI.EntryPropertyVec.size() > 1) {
    // Reported, but execution continues and all entries are still emitted.
    M.getContext().diagnose(DiagnosticInfoTranslateMD(
        M, "Non-library shader: One and only one entry expected"));
  }

  for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
    const ComputedShaderFlags &EntrySFMask =
        ShaderFlags.getFunctionFlags(EntryProp.Entry);

    // If ShaderProfile is Library, mask is already consolidated in the
    // top-level library node. Hence it is not emitted.
    uint64_t EntryShaderFlags = 0;
    if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
      EntryShaderFlags = EntrySFMask;
      if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
        // NOTE(review): getShortShaderStage() reaches llvm_unreachable for
        // stages outside its switch. An entry with such a stage would hit
        // that instead of producing this diagnostic.
        M.getContext().diagnose(DiagnosticInfoTranslateMD(
            M,
            "Shader stage '" +
                Twine(getShortShaderStage(EntryProp.ShaderStage) +
                      "' for entry '" + Twine(EntryProp.Entry->getName()) +
                      "' different from specified target profile '" +
                      Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
                            "'"))));
      }
    }
    EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
                                            EntryShaderFlags,
                                            MMDI.ShaderProfile));
  }

  NamedMDNode *EntryPointsNamedMD =
      M.getOrInsertNamedMetadata("dx.entryPoints");
  for (auto *Entry : EntryFnMDNodes)
    EntryPointsNamedMD->addOperand(Entry);
}

/// New-pass-manager entry point. Fetches the resource map, resource type map,
/// shader flags and module metadata. It then emits the DXIL metadata and
/// translates control-flow hints.
PreservedAnalyses DXILTranslateMetadata::run(Module &M,
                                             ModuleAnalysisManager &MAM) {
  DXILResourceMap &DRM = MAM.getResult(M);
  DXILResourceTypeMap &DRTM = MAM.getResult(M);
  const ModuleShaderFlags &ShaderFlags = MAM.getResult(M);
  // NOTE(review): this declares a value, not a reference, so the analysis
  // result (including its entry-property vector) is copied. It could bind
  // `const dxil::ModuleMetadataInfo &` instead.
  const dxil::ModuleMetadataInfo MMDI =
      MAM.getResult(M);

  translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
  translateBranchMetadata(M);

  // NOTE(review): the module is modified above:
  //   * named metadata is added;
  //   * resource symbols may be created;
  //   * instruction metadata is rewritten.
  // Yet all analyses are reported as preserved, while the legacy pass below
  // returns `true` (modified). Confirm this is intended.
  return PreservedAnalyses::all();
}

namespace {
/// Legacy-pass-manager wrapper around translateMetadata() and
/// translateBranchMetadata().
class DXILTranslateMetadataLegacy : public ModulePass {
public:
  static char ID; // Pass identification, replacement for typeid
  explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}

  StringRef getPassName() const override { return "DXIL Translate Metadata"; }

  // NOTE(review): the pass names of all eight calls below are missing in this
  // copy (stripped template arguments). The four addRequired calls should
  // cover the four getAnalysis() uses in runOnModule.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired();
    AU.addRequired();
    AU.addRequired();
    AU.addRequired();
    AU.addPreserved();
    AU.addPreserved();
    AU.addPreserved();
    AU.addPreserved();
  }

  /// Runs the translation and reports the module as modified.
  bool runOnModule(Module &M) override {
    DXILResourceMap &DRM =
        getAnalysis().getResourceMap();
    DXILResourceTypeMap &DRTM =
        getAnalysis().getResourceTypeMap();
    const ModuleShaderFlags &ShaderFlags =
        getAnalysis().getShaderFlags();
    // NOTE(review): declared by value; this may copy the module metadata.
    // Consider `const dxil::ModuleMetadataInfo &`.
    dxil::ModuleMetadataInfo MMDI =
        getAnalysis().getModuleMetadata();

    translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
    translateBranchMetadata(M);
    return true;
  }
};

} // namespace

char DXILTranslateMetadataLegacy::ID = 0;

ModulePass *llvm::createDXILTranslateMetadataLegacyPass() {
  return new DXILTranslateMetadataLegacy();
}

// NOTE(review): three dependencies are listed here, while getAnalysisUsage
// requests four analyses. The provider of getResourceTypeMap() is not named
// below; confirm whether it needs an INITIALIZE_PASS_DEPENDENCY entry.
INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
                      "DXIL Translate Metadata", false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper)
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata",
                    "DXIL Translate Metadata", false, false)