aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp202
1 files changed, 194 insertions, 8 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index c2a6e51..b765fec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -81,6 +81,7 @@ public:
void outputExecutionMode(const Module &M);
void outputAnnotations(const Module &M);
void outputModuleSections();
+ void outputFPFastMathDefaultInfo();
bool isHidden() {
return MF->getFunction()
.getFnAttribute(SPIRV_BACKEND_SERVICE_FUN_NAME)
@@ -498,11 +499,27 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
NamedMDNode *Node = M.getNamedMetadata("spirv.ExecutionMode");
if (Node) {
for (unsigned i = 0; i < Node->getNumOperands(); i++) {
+ // If SPV_KHR_float_controls2 is enabled and we find any of
+ // FPFastMathDefault, ContractionOff or SignedZeroInfNanPreserve execution
+ // modes, skip it, it'll be done somewhere else.
+ if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+ const auto EM =
+ cast<ConstantInt>(
+ cast<ConstantAsMetadata>((Node->getOperand(i))->getOperand(1))
+ ->getValue())
+ ->getZExtValue();
+ if (EM == SPIRV::ExecutionMode::FPFastMathDefault ||
+ EM == SPIRV::ExecutionMode::ContractionOff ||
+ EM == SPIRV::ExecutionMode::SignedZeroInfNanPreserve)
+ continue;
+ }
+
MCInst Inst;
Inst.setOpcode(SPIRV::OpExecutionMode);
addOpsFromMDNode(cast<MDNode>(Node->getOperand(i)), Inst, MAI);
outputMCInst(Inst);
}
+ outputFPFastMathDefaultInfo();
}
for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
const Function &F = *FI;
@@ -552,12 +569,84 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
}
if (ST->isKernel() && !M.getNamedMetadata("spirv.ExecutionMode") &&
!M.getNamedMetadata("opencl.enable.FP_CONTRACT")) {
- MCInst Inst;
- Inst.setOpcode(SPIRV::OpExecutionMode);
- Inst.addOperand(MCOperand::createReg(FReg));
- unsigned EM = static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
- Inst.addOperand(MCOperand::createImm(EM));
- outputMCInst(Inst);
+ if (ST->canUseExtension(SPIRV::Extension::SPV_KHR_float_controls2)) {
+ // When SPV_KHR_float_controls2 is enabled, ContractionOff is
+ // deprecated. We need to use FPFastMathDefault with the appropriate
+ // flags instead. Since FPFastMathDefault takes a target type, we need
+ // to emit it for each floating-point type that exists in the module
+ // to match the effect of ContractionOff. As of now, there are 3 FP
+ // types: fp16, fp32 and fp64.
+
+ // We only end up here because there is no "spirv.ExecutionMode"
+ // metadata, so that means no FPFastMathDefault. Therefore, we only
+ // need to make sure AllowContract is set to 0, as the rest of flags.
+ // We still need to emit the OpExecutionMode instruction, otherwise
+ // it's up to the client API to define the flags. Therefore, we need
+ // to find the constant with 0 value.
+
+ // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
+ // type int32 with 0 value to represent the FP Fast Math Mode.
+ std::vector<const MachineInstr *> SPIRVFloatTypes;
+ const MachineInstr *ConstZero = nullptr;
+ for (const MachineInstr *MI :
+ MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+ // Skip if the instruction is not OpTypeFloat or OpConstant.
+ unsigned OpCode = MI->getOpcode();
+ if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull)
+ continue;
+
+ // Collect the SPIRV type if it's a float.
+ if (OpCode == SPIRV::OpTypeFloat) {
+ // Skip if the target type is not fp16, fp32, fp64.
+ const unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
+ if (OpTypeFloatSize != 16 && OpTypeFloatSize != 32 &&
+ OpTypeFloatSize != 64) {
+ continue;
+ }
+ SPIRVFloatTypes.push_back(MI);
+ } else {
+ // Check if the constant is int32, if not skip it.
+ const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
+ MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
+ if (!TypeMI || TypeMI->getOperand(1).getImm() != 32)
+ continue;
+
+ ConstZero = MI;
+ }
+ }
+
+ // When SPV_KHR_float_controls2 is enabled, ContractionOff is
+ // deprecated. We need to use FPFastMathDefault with the appropriate
+ // flags instead. Since FPFastMathDefault takes a target type, we need
+ // to emit it for each floating-point type that exists in the module
+ // to match the effect of ContractionOff. As of now, there are 3 FP
+ // types: fp16, fp32 and fp64.
+ for (const MachineInstr *MI : SPIRVFloatTypes) {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionModeId);
+ Inst.addOperand(MCOperand::createReg(FReg));
+ unsigned EM =
+ static_cast<unsigned>(SPIRV::ExecutionMode::FPFastMathDefault);
+ Inst.addOperand(MCOperand::createImm(EM));
+ const MachineFunction *MF = MI->getMF();
+ MCRegister TypeReg =
+ MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
+ Inst.addOperand(MCOperand::createReg(TypeReg));
+ assert(ConstZero && "There should be a constant zero.");
+ MCRegister ConstReg = MAI->getRegisterAlias(
+ ConstZero->getMF(), ConstZero->getOperand(0).getReg());
+ Inst.addOperand(MCOperand::createReg(ConstReg));
+ outputMCInst(Inst);
+ }
+ } else {
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionMode);
+ Inst.addOperand(MCOperand::createReg(FReg));
+ unsigned EM =
+ static_cast<unsigned>(SPIRV::ExecutionMode::ContractionOff);
+ Inst.addOperand(MCOperand::createImm(EM));
+ outputMCInst(Inst);
+ }
}
}
}
@@ -606,6 +695,101 @@ void SPIRVAsmPrinter::outputAnnotations(const Module &M) {
}
}
+void SPIRVAsmPrinter::outputFPFastMathDefaultInfo() {
+ // Collect the SPIRVTypes that are OpTypeFloat and the constants of type
+ // int32, that might be used as FP Fast Math Mode.
+ std::vector<const MachineInstr *> SPIRVFloatTypes;
+ // Hashtable to associate immediate values with the constant holding them.
+ std::unordered_map<int, const MachineInstr *> ConstMap;
+ for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
+ // Skip if the instruction is not OpTypeFloat or OpConstant.
+ unsigned OpCode = MI->getOpcode();
+ if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantI &&
+ OpCode != SPIRV::OpConstantNull)
+ continue;
+
+ // Collect the SPIRV type if it's a float.
+ if (OpCode == SPIRV::OpTypeFloat) {
+ SPIRVFloatTypes.push_back(MI);
+ } else {
+ // Check if the constant is int32, if not skip it.
+ const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
+ MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
+ if (!TypeMI || TypeMI->getOpcode() != SPIRV::OpTypeInt ||
+ TypeMI->getOperand(1).getImm() != 32)
+ continue;
+
+ if (OpCode == SPIRV::OpConstantI)
+ ConstMap[MI->getOperand(2).getImm()] = MI;
+ else
+ ConstMap[0] = MI;
+ }
+ }
+
+ for (const auto &[Func, FPFastMathDefaultInfoVec] :
+ MAI->FPFastMathDefaultInfoMap) {
+ if (FPFastMathDefaultInfoVec.empty())
+ continue;
+
+ for (const MachineInstr *MI : SPIRVFloatTypes) {
+ unsigned OpTypeFloatSize = MI->getOperand(1).getImm();
+ unsigned Index = SPIRV::FPFastMathDefaultInfoVector::
+ computeFPFastMathDefaultInfoVecIndex(OpTypeFloatSize);
+ assert(Index < FPFastMathDefaultInfoVec.size() &&
+ "Index out of bounds for FPFastMathDefaultInfoVec");
+ const auto &FPFastMathDefaultInfo = FPFastMathDefaultInfoVec[Index];
+ assert(FPFastMathDefaultInfo.Ty &&
+ "Expected target type for FPFastMathDefaultInfo");
+ assert(FPFastMathDefaultInfo.Ty->getScalarSizeInBits() ==
+ OpTypeFloatSize &&
+ "Mismatched float type size");
+ MCInst Inst;
+ Inst.setOpcode(SPIRV::OpExecutionModeId);
+ MCRegister FuncReg = MAI->getFuncReg(Func);
+ assert(FuncReg.isValid());
+ Inst.addOperand(MCOperand::createReg(FuncReg));
+ Inst.addOperand(
+ MCOperand::createImm(SPIRV::ExecutionMode::FPFastMathDefault));
+ MCRegister TypeReg =
+ MAI->getRegisterAlias(MI->getMF(), MI->getOperand(0).getReg());
+ Inst.addOperand(MCOperand::createReg(TypeReg));
+ unsigned Flags = FPFastMathDefaultInfo.FastMathFlags;
+ if (FPFastMathDefaultInfo.ContractionOff &&
+ (Flags & SPIRV::FPFastMathMode::AllowContract))
+ report_fatal_error(
+ "Conflicting FPFastMathFlags: ContractionOff and AllowContract");
+
+ if (FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
+ !(Flags &
+ (SPIRV::FPFastMathMode::NotNaN | SPIRV::FPFastMathMode::NotInf |
+ SPIRV::FPFastMathMode::NSZ))) {
+ if (FPFastMathDefaultInfo.FPFastMathDefault)
+ report_fatal_error("Conflicting FPFastMathFlags: "
+ "SignedZeroInfNanPreserve but at least one of "
+ "NotNaN/NotInf/NSZ is enabled.");
+ }
+
+ // Don't emit if none of the execution modes was used.
+ if (Flags == SPIRV::FPFastMathMode::None &&
+ !FPFastMathDefaultInfo.ContractionOff &&
+ !FPFastMathDefaultInfo.SignedZeroInfNanPreserve &&
+ !FPFastMathDefaultInfo.FPFastMathDefault)
+ continue;
+
+ // Retrieve the constant instruction for the immediate value.
+ auto It = ConstMap.find(Flags);
+ if (It == ConstMap.end())
+ report_fatal_error("Expected constant instruction for FP Fast Math "
+ "Mode operand of FPFastMathDefault execution mode.");
+ const MachineInstr *ConstMI = It->second;
+ MCRegister ConstReg = MAI->getRegisterAlias(
+ ConstMI->getMF(), ConstMI->getOperand(0).getReg());
+ Inst.addOperand(MCOperand::createReg(ConstReg));
+ outputMCInst(Inst);
+ }
+ }
+}
+
void SPIRVAsmPrinter::outputModuleSections() {
const Module *M = MMI->getModule();
// Get the global subtarget to output module-level info.
@@ -614,7 +798,8 @@ void SPIRVAsmPrinter::outputModuleSections() {
MAI = &SPIRVModuleAnalysis::MAI;
assert(ST && TII && MAI && M && "Module analysis is required");
// Output instructions according to the Logical Layout of a Module:
- // 1,2. All OpCapability instructions, then optional OpExtension instructions.
+ // 1,2. All OpCapability instructions, then optional OpExtension
+ // instructions.
outputGlobalRequirements();
// 3. Optional OpExtInstImport instructions.
outputOpExtInstImports(*M);
@@ -622,7 +807,8 @@ void SPIRVAsmPrinter::outputModuleSections() {
outputOpMemoryModel();
// 5. All entry point declarations, using OpEntryPoint.
outputEntryPoints();
- // 6. Execution-mode declarations, using OpExecutionMode or OpExecutionModeId.
+ // 6. Execution-mode declarations, using OpExecutionMode or
+ // OpExecutionModeId.
outputExecutionMode(*M);
// 7a. Debug: all OpString, OpSourceExtension, OpSource, and
// OpSourceContinued, without forward references.