diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 202 |
1 files changed, 194 insertions, 8 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index c2a6e51..b765fec 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -81,6 +81,7 @@ public: void outputExecutionMode(const Module &M); void outputAnnotations(const Module &M); void outputModuleSections(); + void outputFPFastMathDefaultInfo(); bool isHidden() { return MF->getFunction() .getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME) @@ -498,11 +499,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode"); if (Node) { for (unsigned i = 0; i < Node->getNumOperands(); i++) { + // If SPV_KHR_float_controls2 is enabled and we find any of + // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution + // modes, skip it, it'll be done somewhere else. + if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + const auto EM = + cast<ConstantInt>( + cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1)) + ->getValue()) + ->getZExtValue(); + if (EM == SPIRV::ExecutionMode::FPFastMathDefault || + EM == SPIRV::ExecutionMode::ContractionOff || + EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve) + continue; + } + MCInst Inst; Inst.setOpcode(SPIRV::OpExecutionMode); addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI); outputMCInst(Inst); } + outputFPFastMathDefaultInfo(); } for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) { const Function &F = *FI; @@ -552,12 +569,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { } if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") && !M.getNamedMetadata("opencl.enable.FP_CONTRACT")) { - MCInst Inst; - Inst.setOpcode(SPIRV::OpExecutionMode); - Inst.addOperand(MCOperand::createReg(FReg)); - unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); - Inst.addOperand(MCOperand::createImm(EM)); - outputMCInst(Inst); + if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) { + // When SPV_KHR_float_controls2 is enabled, ContractionOff is + // deprecated. We need to use FPFastMathDefault with the appropriate + // flags instead. Since FPFastMathDefault takes a target type, we need + // to emit it for each floating-point type that exists in the module + // to match the effect of ContractionOff. As of now, there are 3 FP + // types: fp16, fp32 and fp64. + + // We only end up here because there is no "spirv.ExecutionMode" + // metadata, so that means no FPFastMathDefault. Therefore, we only + // need to make sure AllowContract is set to 0, as the rest of flags. + // We still need to emit the OpExecutionMode instruction, otherwise + // it's up to the client API to define the flags. Therefore, we need + // to find the constant with 0 value. + + // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of + // type int32 with 0 value to represent the FP Fast Math Mode. + std::vector<const MachineInstr *> SPIRVFloatTypes; + const MachineInstr *ConstZero = nullptr; + for (const MachineInstr *MI : + MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { + // Skip if the instruction is not OpTypeFloat or OpConstant. + unsigned OpCode = MI->getOpcode(); + if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull) + continue; + + // Collect the SPIRV type if it's a float. + if (OpCode == SPIRV::OpTypeFloat) { + // Skip if the target type is not fp16, fp32, fp64. + const unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); + if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 && + OpTypeFloatSize != 64) { + continue; + } + SPIRVFloatTypes.push_back(MI); + } else { + // Check if the constant is int32, if not skip it. + const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); + MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); + if (!TypeMI || TypeMI->getOperand(1).getImm() != 32) + continue; + + ConstZero = MI; + } + } + + // When SPV_KHR_float_controls2 is enabled, ContractionOff is + // deprecated. We need to use FPFastMathDefault with the appropriate + // flags instead. Since FPFastMathDefault takes a target type, we need + // to emit it for each floating-point type that exists in the module + // to match the effect of ContractionOff. As of now, there are 3 FP + // types: fp16, fp32 and fp64. + for (const MachineInstr *MI : SPIRVFloatTypes) { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionModeId); + Inst.addOperand(MCOperand::createReg(FReg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault); + Inst.addOperand(MCOperand::createImm(EM)); + const MachineFunction *MF = MI->getMF(); + MCRegister TypeReg = + MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(TypeReg)); + assert(ConstZero && "There should be a constant zero."); + MCRegister ConstReg = MAI->getRegisterAlias( + ConstZero->getMF(), ConstZero->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(ConstReg)); + outputMCInst(Inst); + } + } else { + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionMode); + Inst.addOperand(MCOperand::createReg(FReg)); + unsigned EM = + static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff); + Inst.addOperand(MCOperand::createImm(EM)); + outputMCInst(Inst); + } } } } @@ -606,6 +695,101 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) { } } +void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() { + // Collect the SPIRVTypes that are OpTypeFloat and the constants of type + // int32, that might be used as FP Fast Math Mode. + std::vector<const MachineInstr *> SPIRVFloatTypes; + // Hashtable to associate immediate values with the constant holding them. + std::unordered_map<int, const MachineInstr *> ConstMap; + for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { + // Skip if the instruction is not OpTypeFloat or OpConstant. + unsigned OpCode = MI->getOpcode(); + if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI && + OpCode != SPIRV::OpConstantNull) + continue; + + // Collect the SPIRV type if it's a float. + if (OpCode == SPIRV::OpTypeFloat) { + SPIRVFloatTypes.push_back(MI); + } else { + // Check if the constant is int32, if not skip it. + const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); + MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); + if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt || + TypeMI->getOperand(1).getImm() != 32) + continue; + + if (OpCode == SPIRV::OpConstantI) + ConstMap[MI->getOperand(2).getImm()] = MI; + else + ConstMap[0] = MI; + } + } + + for (const auto &[Func, FPFastMathDefaultInfoVec] : + MAI->FPFastMathDefaultInfoMap) { + if (FPFastMathDefaultInfoVec.empty()) + continue; + + for (const MachineInstr *MI : SPIRVFloatTypes) { + unsigned OpTypeFloatSize = MI->getOperand(1).getImm(); + unsigned Index = SPIRV::FPFastMathDefaultInfoVector:: + computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize); + assert(Index < FPFastMathDefaultInfoVec.size() && + "Index out of bounds for FPFastMathDefaultInfoVec"); + const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index]; + assert(FPFastMathDefaultInfo.Ty && + "Expected target type for FPFastMathDefaultInfo"); + assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() == + OpTypeFloatSize && + "Mismatched float type size"); + MCInst Inst; + Inst.setOpcode(SPIRV::OpExecutionModeId); + MCRegister FuncReg = MAI->getFuncReg(Func); + assert(FuncReg.isValid()); + Inst.addOperand(MCOperand::createReg(FuncReg)); + Inst.addOperand( + MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault)); + MCRegister TypeReg = + MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(TypeReg)); + unsigned Flags = FPFastMathDefaultInfo.FastMathFlags; + if (FPFastMathDefaultInfo.ContractionOff && + (Flags & SPIRV::FPFastMathMode::AllowContract)) + report_fatal_error( + "Conflicting FPFastMathFlags: ContractionOff and AllowContract"); + + if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve && + !(Flags & + (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf | + SPIRV::FPFastMathMode::NSZ))) { + if (FPFastMathDefaultInfo.FPFastMathDefault) + report_fatal_error("Conflicting FPFastMathFlags: " + "SignedZeroInfNanPreserve but at least one of " + "NotNaN/NotInf/NSZ is enabled."); + } + + // Don't emit if none of the execution modes was used. + if (Flags == SPIRV::FPFastMathMode::None && + !FPFastMathDefaultInfo.ContractionOff && + !FPFastMathDefaultInfo.SignedZeroInfNanPreserve && + !FPFastMathDefaultInfo.FPFastMathDefault) + continue; + + // Retrieve the constant instruction for the immediate value. + auto It = ConstMap.find(Flags); + if (It == ConstMap.end()) + report_fatal_error("Expected constant instruction for FP Fast Math " + "Mode operand of FPFastMathDefault execution mode."); + const MachineInstr *ConstMI = It->second; + MCRegister ConstReg = MAI->getRegisterAlias( + ConstMI->getMF(), ConstMI->getOperand(0).getReg()); + Inst.addOperand(MCOperand::createReg(ConstReg)); + outputMCInst(Inst); + } + } +} + void SPIRVAsmPrinter::outputModuleSections() { const Module *M = MMI->getModule(); // Get the global subtarget to output module-level info. @@ -614,7 +798,8 @@ void SPIRVAsmPrinter::outputModuleSections() { MAI = &SPIRVModuleAnalysis::MAI; assert(ST && TII && MAI && M && "Module analysis is required"); // Output instructions according to the Logical Layout of a Module: - // 1,2. All OpCapability instructions, then optional OpExtension instructions. + // 1,2. All OpCapability instructions, then optional OpExtension + // instructions. outputGlobalRequirements(); // 3. Optional OpExtInstImport instructions. outputOpExtInstImports(*M); @@ -622,7 +807,8 @@ void SPIRVAsmPrinter::outputModuleSections() { outputOpMemoryModel(); // 5. All entry point declarations, using OpEntryPoint. outputEntryPoints(); - // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId. + // 6. Execution-mode declarations, using OpExecutionMode or + // OpExecutionModeId. outputExecutionMode(*M); // 7a. Debug: all OpString, OpSourceExtension, OpSource, and // OpSourceContinued, without forward references. |