aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp23
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp2
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp2
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp63
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td3
5 files changed, 60 insertions, 33 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
index 0175f2f..970b83d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
@@ -612,13 +612,10 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
// Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of
// type int32 with 0 value to represent the FP Fast Math Mode.
std::vector<const MachineInstr *> SPIRVFloatTypes;
- const MachineInstr *ConstZero = nullptr;
+ const MachineInstr *ConstZeroInt32 = nullptr;
for (const MachineInstr *MI :
MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) {
- // Skip if the instruction is not OpTypeFloat or OpConstant.
unsigned OpCode = MI->getOpcode();
- if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull)
- continue;
// Collect the SPIRV type if it's a float.
if (OpCode == SPIRV::OpTypeFloat) {
@@ -629,14 +626,18 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
continue;
}
SPIRVFloatTypes.push_back(MI);
- } else {
+ continue;
+ }
+
+ if (OpCode == SPIRV::OpConstantNull) {
// Check if the constant is int32, if not skip it.
const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo();
MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg());
- if (!TypeMI || TypeMI->getOperand(1).getImm() != 32)
- continue;
-
- ConstZero = MI;
+ bool IsInt32Ty = TypeMI &&
+ TypeMI->getOpcode() == SPIRV::OpTypeInt &&
+ TypeMI->getOperand(1).getImm() == 32;
+ if (IsInt32Ty)
+ ConstZeroInt32 = MI;
}
}
@@ -657,9 +658,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) {
MCRegister TypeReg =
MAI->getRegisterAlias(MF, MI->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(TypeReg));
- assert(ConstZero && "There should be a constant zero.");
+ assert(ConstZeroInt32 && "There should be a constant zero.");
MCRegister ConstReg = MAI->getRegisterAlias(
- ConstZero->getMF(), ConstZero->getOperand(0).getReg());
+ ConstZeroInt32->getMF(), ConstZeroInt32->getOperand(0).getReg());
Inst.addOperand(MCOperand::createReg(ConstReg));
outputMCInst(Inst);
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index f681b0d..ac09b93 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -29,6 +29,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add},
{"SPV_EXT_shader_atomic_float_min_max",
SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max},
+ {"SPV_INTEL_16bit_atomics",
+ SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics},
{"SPV_EXT_arithmetic_fence",
SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence},
{"SPV_EXT_demote_to_helper_invocation",
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index ba95ad8..4f8bf43 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -24,7 +24,7 @@
using namespace llvm;
SPIRVInstrInfo::SPIRVInstrInfo(const SPIRVSubtarget &STI)
- : SPIRVGenInstrInfo(STI) {}
+ : SPIRVGenInstrInfo(STI, RI) {}
bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const {
switch (MI.getOpcode()) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index af76016..b8cd9c1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -249,17 +249,18 @@ static InstrSignature instrToSignature(const MachineInstr &MI,
InstrSignature Signature{MI.getOpcode()};
for (unsigned i = 0; i < MI.getNumOperands(); ++i) {
// The only decorations that can be applied more than once to a given <id>
- // or structure member are UserSemantic(5635), CacheControlLoadINTEL (6442),
- // and CacheControlStoreINTEL (6443). For all the rest of decorations, we
- // will only add to the signature the Opcode, the id to which it applies,
- // and the decoration id, disregarding any decoration flags. This will
- // ensure that any subsequent decoration with the same id will be deemed as
- // a duplicate. Then, at the call site, we will be able to handle duplicates
- // in the best way.
+ // or structure member are FuncParamAttr (38), UserSemantic (5635),
+ // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all
+ // the rest of decorations, we will only add to the signature the Opcode,
+ // the id to which it applies, and the decoration id, disregarding any
+ // decoration flags. This will ensure that any subsequent decoration with
+ // the same id will be deemed as a duplicate. Then, at the call site, we
+ // will be able to handle duplicates in the best way.
unsigned Opcode = MI.getOpcode();
if ((Opcode == SPIRV::OpDecorate) && i >= 2) {
unsigned DecorationID = MI.getOperand(1).getImm();
- if (DecorationID != SPIRV::Decoration::UserSemantic &&
+ if (DecorationID != SPIRV::Decoration::FuncParamAttr &&
+ DecorationID != SPIRV::Decoration::UserSemantic &&
DecorationID != SPIRV::Decoration::CacheControlLoadINTEL &&
DecorationID != SPIRV::Decoration::CacheControlStoreINTEL)
continue;
@@ -1058,6 +1059,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI,
}
}
+static bool isBFloat16Type(const SPIRVType *TypeDef) {
+ return TypeDef && TypeDef->getNumOperands() == 3 &&
+ TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+ TypeDef->getOperand(1).getImm() == 16 &&
+ TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
+}
+
// Add requirements for handling atomic float instructions
#define ATOM_FLT_REQ_EXT_MSG(ExtName) \
"The atomic float instruction requires the following SPIR-V " \
@@ -1081,11 +1089,21 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add);
switch (BitWidth) {
case 16:
- if (!ST.canUseExtension(
- SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
- report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
- Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL);
+ } else {
+ if (!ST.canUseExtension(
+ SPIRV::Extension::SPV_EXT_shader_atomic_float16_add))
+ report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false);
+ Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add);
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT);
@@ -1104,7 +1122,17 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI,
Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max);
switch (BitWidth) {
case 16:
- Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ if (isBFloat16Type(TypeDef)) {
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics))
+ report_fatal_error(
+ "The atomic bfloat16 instruction requires the following SPIR-V "
+ "extension: SPV_INTEL_16bit_atomics",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics);
+ Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL);
+ } else {
+ Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT);
+ }
break;
case 32:
Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT);
@@ -1328,13 +1356,6 @@ void addPrintfRequirements(const MachineInstr &MI,
}
}
-static bool isBFloat16Type(const SPIRVType *TypeDef) {
- return TypeDef && TypeDef->getNumOperands() == 3 &&
- TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
- TypeDef->getOperand(1).getImm() == 16 &&
- TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR;
-}
-
void addInstrRequirements(const MachineInstr &MI,
SPIRV::ModuleAnalysisInfo &MAI,
const SPIRVSubtarget &ST) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 65a8885..f02a587 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -389,6 +389,7 @@ defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>;
defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>;
defm SPV_INTEL_bfloat16_arithmetic
: ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>;
+defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -566,9 +567,11 @@ defm FloatControls2
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
+defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
+defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>;
defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;