diff options
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp')
| -rw-r--r-- | llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp | 130 |
1 files changed, 96 insertions, 34 deletions
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp index 9eebcc9..1e4797b 100644 --- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp +++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp @@ -7,8 +7,10 @@ //===----------------------------------------------------------------------===// #include "DXILTranslateMetadata.h" +#include "DXILRootSignature.h" #include "DXILShaderFlags.h" #include "DirectX.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/DXILMetadataAnalysis.h" @@ -204,9 +206,9 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags, return MDNode::get(Ctx, MDVals); } -MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures, - MDNode *Resources, MDTuple *Properties, - LLVMContext &Ctx) { +static MDTuple *constructEntryMetadata(const Function *EntryFn, + MDTuple *Signatures, MDNode *Resources, + MDTuple *Properties, LLVMContext &Ctx) { // Each entry point metadata record specifies: // * reference to the entry point function global symbol // * unmangled name @@ -290,42 +292,82 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx); } -// TODO: We might need to refactor this to be more generic, -// in case we need more metadata to be replaced. -static void translateBranchMetadata(Module &M) { - for (Function &F : M) { - for (BasicBlock &BB : F) { - Instruction *BBTerminatorInst = BB.getTerminator(); +static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) { + MDNode *HlslControlFlowMD = + BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); + + if (!HlslControlFlowMD) + return; - MDNode *HlslControlFlowMD = - BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); + assert(HlslControlFlowMD->getNumOperands() == 2 && + "invalid operands for hlsl.controlflow.hint"); - if (!HlslControlFlowMD) - continue; + MDBuilder MDHelper(M.getContext()); - assert(HlslControlFlowMD->getNumOperands() == 2 && - "invalid operands for hlsl.controlflow.hint"); + llvm::Metadata *HintsStr = MDHelper.createString("dx.controlflow.hints"); + llvm::Metadata *HintsValue = MDHelper.createConstant( + mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1))); - MDBuilder MDHelper(M.getContext()); - ConstantInt *Op1 = - mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1)); + MDNode *MDNode = llvm::MDNode::get(M.getContext(), {HintsStr, HintsValue}); - SmallVector<llvm::Metadata *, 2> Vals( - ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"), - MDHelper.createConstant(Op1)}); + BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode); + BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr); +} + +static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) { + return { + M.getMDKindID("dx.nonuniform"), M.getMDKindID("dx.controlflow.hints"), + M.getMDKindID("dx.precise"), llvm::LLVMContext::MD_range, + llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias}; +} - MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals); +static void translateInstructionMetadata(Module &M) { + // construct allowlist of valid metadata node kinds + std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M); - BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode); - BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr); + for (Function &F : M) { + for (BasicBlock &BB : F) { + // This needs to be done first so that "hlsl.controlflow.hints" isn't + // removed in the whitelist below + if (auto *I = BB.getTerminator()) + translateBranchMetadata(M, I); + + for (auto &I : make_early_inc_range(BB)) { + I.dropUnknownNonDebugMetadata(DXILCompatibleMDs); + } } } } -static void translateMetadata(Module &M, DXILResourceMap &DRM, - DXILResourceTypeMap &DRTM, - const ModuleShaderFlags &ShaderFlags, - const ModuleMetadataInfo &MMDI) { +static void cleanModuleFlags(Module &M) { + NamedMDNode *MDFlags = M.getModuleFlagsMetadata(); + if (!MDFlags) + return; + + SmallVector<llvm::Module::ModuleFlagEntry> FlagEntries; + M.getModuleFlagsMetadata(FlagEntries); + bool Updated = false; + for (auto &Flag : FlagEntries) { + // llvm 3.7 only supports behavior up to AppendUnique. + if (Flag.Behavior <= Module::ModFlagBehavior::AppendUnique) + continue; + Flag.Behavior = Module::ModFlagBehavior::Warning; + Updated = true; + } + + if (!Updated) + return; + + MDFlags->eraseFromParent(); + + for (auto &Flag : FlagEntries) + M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val); +} + +static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM, + DXILResourceTypeMap &DRTM, + const ModuleShaderFlags &ShaderFlags, + const ModuleMetadataInfo &MMDI) { LLVMContext &Ctx = M.getContext(); IRBuilder<> IRB(Ctx); SmallVector<MDNode *> EntryFnMDNodes; @@ -381,6 +423,22 @@ static void translateMetadata(Module &M, DXILResourceMap &DRM, M.getOrInsertNamedMetadata("dx.entryPoints"); for (auto *Entry : EntryFnMDNodes) EntryPointsNamedMD->addOperand(Entry); + + cleanModuleFlags(M); + + // dx.rootsignatures will have been parsed from its metadata form as its + // binary form as part of the RootSignatureAnalysisWrapper, so safely + // remove it as it is not recognized in DXIL + if (NamedMDNode *RootSignature = M.getNamedMetadata("dx.rootsignatures")) + RootSignature->eraseFromParent(); + + // llvm.errno.tbaa was recently added but is not supported in LLVM 3.7 and + // causes all tests using the DXIL Validator to fail. + // + // This is a temporary fix and should be replaced with a allowlist once + // we have determined all metadata that the DXIL Validator allows + if (NamedMDNode *ErrNo = M.getNamedMetadata("llvm.errno.tbaa")) + ErrNo->eraseFromParent(); } PreservedAnalyses DXILTranslateMetadata::run(Module &M, @@ -390,8 +448,8 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M, const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M); const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M); - translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI); - translateBranchMetadata(M); + translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI); + translateInstructionMetadata(M); return PreservedAnalyses::all(); } @@ -409,10 +467,13 @@ public: AU.addRequired<DXILResourceWrapperPass>(); AU.addRequired<ShaderFlagsAnalysisWrapper>(); AU.addRequired<DXILMetadataAnalysisWrapperPass>(); - AU.addPreserved<DXILResourceWrapperPass>(); + AU.addRequired<RootSignatureAnalysisWrapper>(); + AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); - AU.addPreserved<ShaderFlagsAnalysisWrapper>(); AU.addPreserved<DXILResourceBindingWrapperPass>(); + AU.addPreserved<DXILResourceWrapperPass>(); + AU.addPreserved<RootSignatureAnalysisWrapper>(); + AU.addPreserved<ShaderFlagsAnalysisWrapper>(); } bool runOnModule(Module &M) override { @@ -425,8 +486,8 @@ public: dxil::ModuleMetadataInfo MMDI = getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); - translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI); - translateBranchMetadata(M); + translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI); + translateInstructionMetadata(M); return true; } }; @@ -443,6 +504,7 @@ INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata", "DXIL Translate Metadata", false, false) INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) +INITIALIZE_PASS_DEPENDENCY(RootSignatureAnalysisWrapper) INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata", "DXIL Translate Metadata", false, false) |
