diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index b765fec..640b014 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -78,6 +78,8 @@ public: void outputExecutionModeFromNumthreadsAttribute( const MCRegister &Reg, const Attribute &Attr, SPIRV::ExecutionMode::ExecutionMode EM); + void outputExecutionModeFromEnableMaximalReconvergenceAttr( + const MCRegister &Reg, const SPIRVSubtarget &ST); void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); @@ -139,8 +141,8 @@ void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) { // anymore. void SPIRVAsmPrinter::cleanUp(Module &M) { // Verifier disallows uses of intrinsic global variables. - for (StringRef GVName : {"llvm.global_ctors", "llvm.global_dtors", - "llvm.used", "llvm.compiler.used"}) { + for (StringRef GVName : + {"llvm.global_ctors", "llvm.global_dtors", "llvm.used"}) { if (GlobalVariable *GV = M.getNamedGlobal(GVName)) GV->setName(""); } @@ -495,6 +497,20 @@ void SPIRVAsmPrinter::outputExecutionModeFromNumthreadsAttribute( outputMCInst(Inst); } +void SPIRVAsmPrinter::outputExecutionModeFromEnableMaximalReconvergenceAttr( + const MCRegister &Reg, const SPIRVSubtarget &ST) { + assert(ST.canUseExtension(SPIRV::Extension::SPV_KHR_maximal_reconvergence) && + "Function called when SPV_KHR_maximal_reconvergence is not enabled."); + + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(Reg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::MaximallyReconvergesKHR); + Inst.addOperand(MCOperand::createImm(EM)); + outputMCInst(Inst); +} + void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { @@ -551,6 +567,10 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { if (Attribute Attr = F.getFnAttribute("hlsl.numthreads"); Attr.isValid()) outputExecutionModeFromNumthreadsAttribute( FReg, Attr, SPIRV::ExecutionMode::LocalSize); + if (Attribute Attr = F.getFnAttribute("enable-maximal-reconvergence"); + Attr.getValueAsBool()) { + outputExecutionModeFromEnableMaximalReconvergenceAttr(FReg, *ST); + } if (MDNode *Node = F.getMetadata("work_group_size_hint")) outputExecutionModeFromMDNode(FReg, Node, SPIRV::ExecutionMode::LocalSizeHint, 3, 1); |