//===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // DXContainerGlobalsPass implementation. // //===----------------------------------------------------------------------===// #include "DXILRootSignature.h" #include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" #include "llvm/Analysis/DXILResource.h" #include "llvm/BinaryFormat/DXContainer.h" #include "llvm/CodeGen/Passes.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/MC/DXContainerPSVInfo.h" #include "llvm/Pass.h" #include "llvm/Support/MD5.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include using namespace llvm; using namespace llvm::dxil; using namespace llvm::mcdxbc; namespace { class DXContainerGlobals : public llvm::ModulePass { GlobalVariable *buildContainerGlobal(Module &M, Constant *Content, StringRef Name, StringRef SectionName); GlobalVariable *getFeatureFlags(Module &M); GlobalVariable *computeShaderHash(Module &M); GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name, StringRef SectionName); void addSignature(Module &M, SmallVector &Globals); void addRootSignature(Module &M, SmallVector &Globals); void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV); void addPipelineStateValidationInfo(Module &M, SmallVector &Globals); public: static char ID; // Pass identification, replacement for typeid DXContainerGlobals() : ModulePass(ID) {} StringRef getPassName() const override { return "DXContainer Global Emitter"; } bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); } }; } // namespace bool DXContainerGlobals::runOnModule(Module &M) { llvm::SmallVector Globals; Globals.push_back(getFeatureFlags(M)); Globals.push_back(computeShaderHash(M)); addSignature(M, Globals); addRootSignature(M, Globals); addPipelineStateValidationInfo(M, Globals); appendToCompilerUsed(M, Globals); return true; } GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) { uint64_t CombinedFeatureFlags = getAnalysis() .getShaderFlags() .getCombinedFlags() .getFeatureFlags(); Constant *FeatureFlagsConstant = ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags)); return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0"); } GlobalVariable *DXContainerGlobals::computeShaderHash(Module &M) { auto *DXILConstant = cast(M.getNamedGlobal("dx.dxil")->getInitializer()); MD5 Digest; Digest.update(DXILConstant->getRawDataValues()); MD5::MD5Result Result = Digest.final(); dxbc::ShaderHash HashData = {0, {0}}; // The Hash's IncludesSource flag gets set whenever the hashed shader includes // debug information. if (M.debug_compile_units_begin() != M.debug_compile_units_end()) HashData.Flags = static_cast(dxbc::HashFlags::IncludesSource); memcpy(reinterpret_cast(&HashData.Digest), Result.data(), 16); if (sys::IsBigEndianHost) HashData.swapBytes(); StringRef Data(reinterpret_cast(&HashData), sizeof(dxbc::ShaderHash)); Constant *ModuleConstant = ConstantDataArray::get(M.getContext(), arrayRefFromStringRef(Data)); return buildContainerGlobal(M, ModuleConstant, "dx.hash", "HASH"); } GlobalVariable *DXContainerGlobals::buildContainerGlobal( Module &M, Constant *Content, StringRef Name, StringRef SectionName) { auto *GV = new llvm::GlobalVariable( M, Content->getType(), true, GlobalValue::PrivateLinkage, Content, Name); GV->setSection(SectionName); GV->setAlignment(Align(4)); return GV; } GlobalVariable *DXContainerGlobals::buildSignature(Module &M, Signature &Sig, StringRef Name, StringRef SectionName) { SmallString<256> Data; raw_svector_ostream OS(Data); Sig.write(OS); Constant *Constant = ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); return buildContainerGlobal(M, Constant, Name, SectionName); } void DXContainerGlobals::addSignature(Module &M, SmallVector &Globals) { // FIXME: support graphics shader. // see issue https://github.com/llvm/llvm-project/issues/90504. Signature InputSig; Globals.emplace_back(buildSignature(M, InputSig, "dx.isg1", "ISG1")); Signature OutputSig; Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1")); } void DXContainerGlobals::addRootSignature(Module &M, SmallVector &Globals) { dxil::ModuleMetadataInfo &MMI = getAnalysis().getModuleMetadata(); // Root Signature in Library don't compile to DXContainer. if (MMI.ShaderProfile == llvm::Triple::Library) return; assert(MMI.EntryPropertyVec.size() == 1); auto &RSA = getAnalysis().getRSInfo(); const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry; const std::optional &RS = RSA.getDescForFunction(EntryFunction); if (!RS) return; SmallString<256> Data; raw_svector_ostream OS(Data); RS->write(OS); Constant *Constant = ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0")); } void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { const DXILResourceMap &DRM = getAnalysis().getResourceMap(); DXILResourceTypeMap &DRTM = getAnalysis().getResourceTypeMap(); auto MakeBinding = [](const dxil::ResourceInfo::ResourceBinding &Binding, const dxbc::PSV::ResourceType Type, const dxil::ResourceKind Kind, const dxbc::PSV::ResourceFlags Flags = dxbc::PSV::ResourceFlags()) { dxbc::PSV::v2::ResourceBindInfo BindInfo; BindInfo.Type = Type; BindInfo.LowerBound = Binding.LowerBound; BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1; BindInfo.Space = Binding.Space; BindInfo.Kind = static_cast(Kind); BindInfo.Flags = Flags; return BindInfo; }; for (const dxil::ResourceInfo &RI : DRM.cbuffers()) { const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); PSV.Resources.push_back(MakeBinding(Binding, dxbc::PSV::ResourceType::CBV, dxil::ResourceKind::CBuffer)); } for (const dxil::ResourceInfo &RI : DRM.samplers()) { const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); PSV.Resources.push_back(MakeBinding(Binding, dxbc::PSV::ResourceType::Sampler, dxil::ResourceKind::Sampler)); } for (const dxil::ResourceInfo &RI : DRM.srvs()) { const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()]; dxbc::PSV::ResourceType ResType; if (TypeInfo.isStruct()) ResType = dxbc::PSV::ResourceType::SRVStructured; else if (TypeInfo.isTyped()) ResType = dxbc::PSV::ResourceType::SRVTyped; else ResType = dxbc::PSV::ResourceType::SRVRaw; PSV.Resources.push_back( MakeBinding(Binding, ResType, TypeInfo.getResourceKind())); } for (const dxil::ResourceInfo &RI : DRM.uavs()) { const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()]; dxbc::PSV::ResourceType ResType; if (RI.hasCounter()) ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter; else if (TypeInfo.isStruct()) ResType = dxbc::PSV::ResourceType::UAVStructured; else if (TypeInfo.isTyped()) ResType = dxbc::PSV::ResourceType::UAVTyped; else ResType = dxbc::PSV::ResourceType::UAVRaw; dxbc::PSV::ResourceFlags Flags; // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking // with https://github.com/llvm/llvm-project/issues/104392 Flags.Flags = 0u; PSV.Resources.push_back( MakeBinding(Binding, ResType, TypeInfo.getResourceKind(), Flags)); } } void DXContainerGlobals::addPipelineStateValidationInfo( Module &M, SmallVector &Globals) { SmallString<256> Data; raw_svector_ostream OS(Data); PSVRuntimeInfo PSV; PSV.BaseData.MinimumWaveLaneCount = 0; PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits::max(); dxil::ModuleMetadataInfo &MMI = getAnalysis().getModuleMetadata(); assert(MMI.EntryPropertyVec.size() == 1 || MMI.ShaderProfile == Triple::Library); PSV.BaseData.ShaderStage = static_cast(MMI.ShaderProfile - Triple::Pixel); addResourcesForPSV(M, PSV); // Hardcoded values here to unblock loading the shader into D3D. // // TODO: Lots more stuff to do here! // // See issue https://github.com/llvm/llvm-project/issues/96674. switch (MMI.ShaderProfile) { case Triple::Compute: PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX; PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY; PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ; break; default: break; } if (MMI.ShaderProfile != Triple::Library) PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName(); PSV.finalize(MMI.ShaderProfile); PSV.write(OS); Constant *Constant = ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.psv0", "PSV0")); } char DXContainerGlobals::ID = 0; INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals", "DXContainer Global Emitter", false, true) INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals", "DXContainer Global Emitter", false, true) ModulePass *llvm::createDXContainerGlobalsPass() { return new DXContainerGlobals(); }