//===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// /// \file This file contains helper objects and APIs for working with DXIL /// Shader Flags. /// //===----------------------------------------------------------------------===// #include "DXILShaderFlags.h" #include "DirectX.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/DXILResource.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; using namespace llvm::dxil; static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM, const ModuleMetadataInfo &MMDI) { if (DRM.uavs().empty()) return false; switch (MMDI.ShaderProfile) { default: return false; case Triple::EnvironmentType::Compute: case Triple::EnvironmentType::Pixel: return false; case Triple::EnvironmentType::Vertex: case Triple::EnvironmentType::Geometry: case Triple::EnvironmentType::Hull: case Triple::EnvironmentType::Domain: return true; case Triple::EnvironmentType::Library: case Triple::EnvironmentType::RayGeneration: case Triple::EnvironmentType::Intersection: case Triple::EnvironmentType::AnyHit: case Triple::EnvironmentType::ClosestHit: case Triple::EnvironmentType::Miss: case Triple::EnvironmentType::Callable: case Triple::EnvironmentType::Mesh: case Triple::EnvironmentType::Amplification: return MMDI.ValidatorVersion < VersionTuple(1, 8); } } static bool checkWaveOps(Intrinsic::ID IID) { // Currently unsupported intrinsics // case Intrinsic::dx_wave_getlanecount: // case Intrinsic::dx_wave_allequal: // case Intrinsic::dx_wave_ballot: // case Intrinsic::dx_wave_readfirst: // case Intrinsic::dx_wave_reduce.and: // case Intrinsic::dx_wave_reduce.or: // case Intrinsic::dx_wave_reduce.xor: // case Intrinsic::dx_wave_prefixop: // case Intrinsic::dx_quad.readat: // case Intrinsic::dx_quad.readacrossx: // case Intrinsic::dx_quad.readacrossy: // case Intrinsic::dx_quad.readacrossdiagonal: // case Intrinsic::dx_wave_prefixballot: // case Intrinsic::dx_wave_match: // case Intrinsic::dx_wavemulti.*: // case Intrinsic::dx_wavemulti.ballot: // case Intrinsic::dx_quad.vote: switch (IID) { default: return false; case Intrinsic::dx_wave_is_first_lane: case Intrinsic::dx_wave_getlaneindex: case Intrinsic::dx_wave_any: case Intrinsic::dx_wave_all: case Intrinsic::dx_wave_readlane: case Intrinsic::dx_wave_active_countbits: // Wave Active Op Variants case Intrinsic::dx_wave_reduce_sum: case Intrinsic::dx_wave_reduce_usum: case Intrinsic::dx_wave_reduce_max: case Intrinsic::dx_wave_reduce_umax: return true; } } /// Update the shader flags mask based on the given instruction. /// \param CSF Shader flags mask to update. /// \param I Instruction to check. void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, const Instruction &I, DXILResourceTypeMap &DRTM, const ModuleMetadataInfo &MMDI) { if (!CSF.Doubles) CSF.Doubles = I.getType()->isDoubleTy(); if (!CSF.Doubles) { for (const Value *Op : I.operands()) { if (Op->getType()->isDoubleTy()) { CSF.Doubles = true; break; } } } if (CSF.Doubles) { switch (I.getOpcode()) { case Instruction::FDiv: case Instruction::UIToFP: case Instruction::SIToFP: case Instruction::FPToUI: case Instruction::FPToSI: CSF.DX11_1_DoubleExtensions = true; break; } } if (!CSF.LowPrecisionPresent) CSF.LowPrecisionPresent = I.getType()->isIntegerTy(16) || I.getType()->isHalfTy(); if (!CSF.LowPrecisionPresent) { for (const Value *Op : I.operands()) { if (Op->getType()->isIntegerTy(16) || Op->getType()->isHalfTy()) { CSF.LowPrecisionPresent = true; break; } } } if (CSF.LowPrecisionPresent) { if (CSF.NativeLowPrecisionMode) CSF.NativeLowPrecision = true; else CSF.MinimumPrecision = true; } if (!CSF.Int64Ops) CSF.Int64Ops = I.getType()->isIntegerTy(64); if (!CSF.Int64Ops && !isa(&I)) { for (const Value *Op : I.operands()) { if (Op->getType()->isIntegerTy(64)) { CSF.Int64Ops = true; break; } } } if (auto *II = dyn_cast(&I)) { switch (II->getIntrinsicID()) { default: break; case Intrinsic::dx_resource_handlefrombinding: { dxil::ResourceTypeInfo &RTI = DRTM[cast(II->getType())]; // Set ResMayNotAlias if DXIL validator version >= 1.8 and the function // uses UAVs if (!CSF.ResMayNotAlias && CanSetResMayNotAlias && MMDI.ValidatorVersion >= VersionTuple(1, 8) && RTI.isUAV()) CSF.ResMayNotAlias = true; switch (RTI.getResourceKind()) { case dxil::ResourceKind::StructuredBuffer: case dxil::ResourceKind::RawBuffer: CSF.EnableRawAndStructuredBuffers = true; break; default: break; } break; } case Intrinsic::dx_resource_load_typedbuffer: { dxil::ResourceTypeInfo &RTI = DRTM[cast(II->getArgOperand(0)->getType())]; if (RTI.isTyped()) CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1; break; } } } // Handle call instructions if (auto *CI = dyn_cast(&I)) { const Function *CF = CI->getCalledFunction(); // Merge-in shader flags mask of the called function in the current module if (FunctionFlags.contains(CF)) CSF.merge(FunctionFlags[CF]); // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID()); } } /// Set shader flags that apply to all functions within the module ComputedShaderFlags ModuleShaderFlags::gatherGlobalModuleFlags(const Module &M, const DXILResourceMap &DRM, const ModuleMetadataInfo &MMDI) { ComputedShaderFlags CSF; // Set DisableOptimizations flag based on the presence of OptimizeNone // attribute of entry functions. if (MMDI.EntryPropertyVec.size() > 0) { CSF.DisableOptimizations = MMDI.EntryPropertyVec[0].Entry->hasFnAttribute( llvm::Attribute::OptimizeNone); // Ensure all entry functions have the same optimization attribute for (const auto &EntryFunProps : MMDI.EntryPropertyVec) if (CSF.DisableOptimizations != EntryFunProps.Entry->hasFnAttribute(llvm::Attribute::OptimizeNone)) EntryFunProps.Entry->getContext().diagnose(DiagnosticInfoUnsupported( *(EntryFunProps.Entry), "Inconsistent optnone attribute ")); } CSF.UAVsAtEveryStage = hasUAVsAtEveryStage(DRM, MMDI); // Set the Max64UAVs flag if the number of UAVs is > 8 uint32_t NumUAVs = 0; for (auto &UAV : DRM.uavs()) if (MMDI.ValidatorVersion < VersionTuple(1, 6)) NumUAVs++; else // MMDI.ValidatorVersion >= VersionTuple(1, 6) NumUAVs += UAV.getBinding().Size; if (NumUAVs > 8) CSF.Max64UAVs = true; // Set the module flag that enables native low-precision execution mode. // NativeLowPrecisionMode can only be set when the command line option // -enable-16bit-types is provided. This is indicated by the dx.nativelowprec // module flag being set // This flag is needed even if the module does not use 16-bit types because a // corresponding debug module may include 16-bit types, and tools that use the // debug module may expect it to have the same flags as the original if (auto *NativeLowPrec = mdconst::extract_or_null( M.getModuleFlag("dx.nativelowprec"))) if (MMDI.ShaderModelVersion >= VersionTuple(6, 2)) CSF.NativeLowPrecisionMode = NativeLowPrec->getValue().getBoolValue(); // Set ResMayNotAlias to true if DXIL validator version < 1.8 and there // are UAVs present globally. if (CanSetResMayNotAlias && MMDI.ValidatorVersion < VersionTuple(1, 8)) CSF.ResMayNotAlias = !DRM.uavs().empty(); return CSF; } /// Construct ModuleShaderFlags for module Module M void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM, const DXILResourceMap &DRM, const ModuleMetadataInfo &MMDI) { CanSetResMayNotAlias = MMDI.DXILVersion >= VersionTuple(1, 7); // The command line option -res-may-alias will set the dx.resmayalias module // flag to 1, thereby disabling the ability to set the ResMayNotAlias flag if (auto *ResMayAlias = mdconst::extract_or_null( M.getModuleFlag("dx.resmayalias"))) if (ResMayAlias->getValue().getBoolValue()) CanSetResMayNotAlias = false; ComputedShaderFlags GlobalSFMask = gatherGlobalModuleFlags(M, DRM, MMDI); CallGraph CG(M); // Compute Shader Flags Mask for all functions using post-order visit of SCC // of the call graph. for (scc_iterator SCCI = scc_begin(&CG); !SCCI.isAtEnd(); ++SCCI) { const std::vector &CurSCC = *SCCI; // Union of shader masks of all functions in CurSCC ComputedShaderFlags SCCSF; // List of functions in CurSCC that are neither external nor declarations // and hence whose flags are collected SmallVector CurSCCFuncs; for (CallGraphNode *CGN : CurSCC) { Function *F = CGN->getFunction(); if (!F) continue; if (F->isDeclaration()) { assert(!F->getName().starts_with("dx.op.") && "DXIL Shader Flag analysis should not be run post-lowering."); continue; } ComputedShaderFlags CSF = GlobalSFMask; for (const auto &BB : *F) for (const auto &I : BB) updateFunctionFlags(CSF, I, DRTM, MMDI); // Update combined shader flags mask for all functions in this SCC SCCSF.merge(CSF); CurSCCFuncs.push_back(F); } // Update combined shader flags mask for all functions of the module CombinedSFMask.merge(SCCSF); // Shader flags mask of each of the functions in an SCC of the call graph is // the union of all functions in the SCC. Update shader flags masks of // functions in CurSCC accordingly. This is trivially true if SCC contains // one function. for (Function *F : CurSCCFuncs) // Merge SCCSF with that of F FunctionFlags[F].merge(SCCSF); } } void ComputedShaderFlags::print(raw_ostream &OS) const { uint64_t FlagVal = (uint64_t) * this; OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal); if (FlagVal == 0) return; OS << "; Note: shader requires additional functionality:\n"; #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \ if (FlagName) \ (OS << ";").indent(7) << Str << "\n"; #include "llvm/BinaryFormat/DXContainerConstants.def" OS << "; Note: extra DXIL module flags:\n"; #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ if (FlagName) \ (OS << ";").indent(7) << Str << "\n"; #include "llvm/BinaryFormat/DXContainerConstants.def" OS << ";\n"; } /// Return the shader flags mask of the specified function Func. const ComputedShaderFlags & ModuleShaderFlags::getFunctionFlags(const Function *Func) const { auto Iter = FunctionFlags.find(Func); assert((Iter != FunctionFlags.end() && Iter->first == Func) && "Get Shader Flags : No Shader Flags Mask exists for function"); return Iter->second; } //===----------------------------------------------------------------------===// // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass // Provide an explicit template instantiation for the static ID. AnalysisKey ShaderFlagsAnalysis::Key; ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M, ModuleAnalysisManager &AM) { DXILResourceTypeMap &DRTM = AM.getResult(M); DXILResourceMap &DRM = AM.getResult(M); const ModuleMetadataInfo MMDI = AM.getResult(M); ModuleShaderFlags MSFI; MSFI.initialize(M, DRTM, DRM, MMDI); return MSFI; } PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, ModuleAnalysisManager &AM) { const ModuleShaderFlags &FlagsInfo = AM.getResult(M); // Print description of combined shader flags for all module functions OS << "; Combined Shader Flags for Module\n"; FlagsInfo.getCombinedFlags().print(OS); // Print shader flags mask for each of the module functions OS << "; Shader Flags for Module Functions\n"; for (const auto &F : M.getFunctionList()) { if (F.isDeclaration()) continue; const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F); OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(), (uint64_t)(SFMask)); } return PreservedAnalyses::all(); } //===----------------------------------------------------------------------===// // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) { DXILResourceTypeMap &DRTM = getAnalysis().getResourceTypeMap(); DXILResourceMap &DRM = getAnalysis().getResourceMap(); const ModuleMetadataInfo MMDI = getAnalysis().getModuleMetadata(); MSFI.initialize(M, DRTM, DRM, MMDI); return false; } void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive(); AU.addRequiredTransitive(); AU.addRequired(); } char ShaderFlagsAnalysisWrapper::ID = 0; INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", "DXIL Shader Flag Analysis", true, true) INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", "DXIL Shader Flag Analysis", true, true)