aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp46
1 files changed, 33 insertions, 13 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index af76016..fbb127d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1058,6 +1058,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
}
}
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+ return TypeDef && TypeDef->getNumOperands() == 3 &&
+ TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+ TypeDef->getOperand(1).getImm() == 16 &&
+ TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
"The atomic float instruction requires the following SPIR-V " \
@@ -1081,11 +1088,21 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
switch (BitWidth) {
case 16:
- if (!ST.canUseExtension(
- SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
- report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
- Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
+ } else {
+ if (!ST.canUseExtension(
+ SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+ report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
@@ -1104,7 +1121,17 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
switch (BitWidth) {
case 16:
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
@@ -1328,13 +1355,6 @@ void addPrintfRequirements(const MachineInstr &MI,
}
}
-static bool isBFloat16Type(const SPIRVType *TypeDef) {
- return TypeDef && TypeDef->getNumOperands() == 3 &&
- TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
- TypeDef->getOperand(1).getImm() == 16 &&
- TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
-}
-
void addInstrRequirements(const MachineInstr &MI,
SPIRV::ModuleAnalysisInfo &MAI,
const SPIRVSubtarget &ST) {