diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp')
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 46 |
1 files changed, 33 insertions, 13 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index af76016..fbb127d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1058,6 +1058,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI, } } +static bool isBFloat16Type(const SPIRVType *TypeDef) { + return TypeDef && TypeDef->getNumOperands() == 3 && + TypeDef->getOpcode() == SPIRV::OpTypeFloat && + TypeDef->getOperand(1).getImm() == 16 && + TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR; +} + // Add requirements for handling atomic float instructions #define ATOM_FLT_REQ_EXT_MSG(ExtName) \ "The atomic float instruction requires the following SPIR-V " \ @@ -1081,11 +1088,21 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI, Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add); switch (BitWidth) { case 16: - if (!ST.canUseExtension( - SPIRV::Extension::SPV_EXT_shader_atomic_float16_add)) - report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false); - Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add); - Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT); + if (isBFloat16Type(TypeDef)) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics)) + report_fatal_error( + "The atomic bfloat16 instruction requires the following SPIR-V " + "extension: SPV_INTEL_16bit_atomics", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics); + Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL); + } else { + if (!ST.canUseExtension( + SPIRV::Extension::SPV_EXT_shader_atomic_float16_add)) + report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false); + Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add); + Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT); + } break; case 32: Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT); @@ -1104,7 +1121,17 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI, Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max); switch (BitWidth) { case 16: - Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT); + if (isBFloat16Type(TypeDef)) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics)) + report_fatal_error( + "The atomic bfloat16 instruction requires the following SPIR-V " + "extension: SPV_INTEL_16bit_atomics", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics); + Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL); + } else { + Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT); + } break; case 32: Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT); @@ -1328,13 +1355,6 @@ void addPrintfRequirements(const MachineInstr &MI, } } -static bool isBFloat16Type(const SPIRVType *TypeDef) { - return TypeDef && TypeDef->getNumOperands() == 3 && - TypeDef->getOpcode() == SPIRV::OpTypeFloat && - TypeDef->getOperand(1).getImm() == 16 && - TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR; -} - void addInstrRequirements(const MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, const SPIRVSubtarget &ST) { |
