diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV')
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp | 23 | ||||
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp | 2 | ||||
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 63 | ||||
| -rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td | 3 |
5 files changed, 60 insertions, 33 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp index 0175f2f..970b83d 100644 --- a/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp @@ -612,13 +612,10 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { // Collect the SPIRVTypes for fp16, fp32, and fp64 and the constant of // type int32 with 0 value to represent the FP Fast Math Mode. std::vector<const MachineInstr *> SPIRVFloatTypes; - const MachineInstr *ConstZero = nullptr; + const MachineInstr *ConstZeroInt32 = nullptr; for (const MachineInstr *MI : MAI->getMSInstrs(SPIRV::MB_TypeConstVars)) { - // Skip if the instruction is not OpTypeFloat or OpConstant. unsigned OpCode = MI->getOpcode(); - if (OpCode != SPIRV::OpTypeFloat && OpCode != SPIRV::OpConstantNull) - continue; // Collect the SPIRV type if it's a float. if (OpCode == SPIRV::OpTypeFloat) { @@ -629,14 +626,18 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { continue; } SPIRVFloatTypes.push_back(MI); - } else { + continue; + } + + if (OpCode == SPIRV::OpConstantNull) { // Check if the constant is int32, if not skip it. const MachineRegisterInfo &MRI = MI->getMF()->getRegInfo(); MachineInstr *TypeMI = MRI.getVRegDef(MI->getOperand(1).getReg()); - if (!TypeMI || TypeMI->getOperand(1).getImm() != 32) - continue; - - ConstZero = MI; + bool IsInt32Ty = TypeMI && + TypeMI->getOpcode() == SPIRV::OpTypeInt && + TypeMI->getOperand(1).getImm() == 32; + if (IsInt32Ty) + ConstZeroInt32 = MI; } } @@ -657,9 +658,9 @@ void SPIRVAsmPrinter::outputExecutionMode(const Module &M) { MCRegister TypeReg = MAI->getRegisterAlias(MF, MI->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(TypeReg)); - assert(ConstZero && "There should be a constant zero."); + assert(ConstZeroInt32 && "There should be a constant zero."); MCRegister ConstReg = MAI->getRegisterAlias( - ConstZero->getMF(), ConstZero->getOperand(0).getReg()); + ConstZeroInt32->getMF(), ConstZeroInt32->getOperand(0).getReg()); Inst.addOperand(MCOperand::createReg(ConstReg)); outputMCInst(Inst); } diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index f681b0d..ac09b93 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -29,6 +29,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>> SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float16_add}, {"SPV_EXT_shader_atomic_float_min_max", SPIRV::Extension::Extension::SPV_EXT_shader_atomic_float_min_max}, + {"SPV_INTEL_16bit_atomics", + SPIRV::Extension::Extension::SPV_INTEL_16bit_atomics}, {"SPV_EXT_arithmetic_fence", SPIRV::Extension::Extension::SPV_EXT_arithmetic_fence}, {"SPV_EXT_demote_to_helper_invocation", diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp index ba95ad8..4f8bf43 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp @@ -24,7 +24,7 @@ using namespace llvm; SPIRVInstrInfo::SPIRVInstrInfo(const SPIRVSubtarget &STI) - : SPIRVGenInstrInfo(STI) {} + : SPIRVGenInstrInfo(STI, RI) {} bool SPIRVInstrInfo::isConstantInstr(const MachineInstr &MI) const { switch (MI.getOpcode()) { diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index af76016..b8cd9c1 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -249,17 +249,18 @@ static InstrSignature instrToSignature(const MachineInstr &MI, InstrSignature Signature{MI.getOpcode()}; for (unsigned i = 0; i < MI.getNumOperands(); ++i) { // The only decorations that can be applied more than once to a given <id> - // or structure member are UserSemantic(5635), CacheControlLoadINTEL (6442), - // and CacheControlStoreINTEL (6443). For all the rest of decorations, we - // will only add to the signature the Opcode, the id to which it applies, - // and the decoration id, disregarding any decoration flags. This will - // ensure that any subsequent decoration with the same id will be deemed as - // a duplicate. Then, at the call site, we will be able to handle duplicates - // in the best way. + // or structure member are FuncParamAttr (38), UserSemantic (5635), + // CacheControlLoadINTEL (6442), and CacheControlStoreINTEL (6443). For all + // the rest of decorations, we will only add to the signature the Opcode, + // the id to which it applies, and the decoration id, disregarding any + // decoration flags. This will ensure that any subsequent decoration with + // the same id will be deemed as a duplicate. Then, at the call site, we + // will be able to handle duplicates in the best way. unsigned Opcode = MI.getOpcode(); if ((Opcode == SPIRV::OpDecorate) && i >= 2) { unsigned DecorationID = MI.getOperand(1).getImm(); - if (DecorationID != SPIRV::Decoration::UserSemantic && + if (DecorationID != SPIRV::Decoration::FuncParamAttr && + DecorationID != SPIRV::Decoration::UserSemantic && DecorationID != SPIRV::Decoration::CacheControlLoadINTEL && DecorationID != SPIRV::Decoration::CacheControlStoreINTEL) continue; @@ -1058,6 +1059,13 @@ static void addOpTypeImageReqs(const MachineInstr &MI, } } +static bool isBFloat16Type(const SPIRVType *TypeDef) { + return TypeDef && TypeDef->getNumOperands() == 3 && + TypeDef->getOpcode() == SPIRV::OpTypeFloat && + TypeDef->getOperand(1).getImm() == 16 && + TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR; +} + // Add requirements for handling atomic float instructions #define ATOM_FLT_REQ_EXT_MSG(ExtName) \ "The atomic float instruction requires the following SPIR-V " \ @@ -1081,11 +1089,21 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI, Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_add); switch (BitWidth) { case 16: - if (!ST.canUseExtension( - SPIRV::Extension::SPV_EXT_shader_atomic_float16_add)) - report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false); - Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add); - Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT); + if (isBFloat16Type(TypeDef)) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics)) + report_fatal_error( + "The atomic bfloat16 instruction requires the following SPIR-V " + "extension: SPV_INTEL_16bit_atomics", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics); + Reqs.addCapability(SPIRV::Capability::AtomicBFloat16AddINTEL); + } else { + if (!ST.canUseExtension( + SPIRV::Extension::SPV_EXT_shader_atomic_float16_add)) + report_fatal_error(ATOM_FLT_REQ_EXT_MSG("16_add"), false); + Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float16_add); + Reqs.addCapability(SPIRV::Capability::AtomicFloat16AddEXT); + } break; case 32: Reqs.addCapability(SPIRV::Capability::AtomicFloat32AddEXT); @@ -1104,7 +1122,17 @@ static void AddAtomicFloatRequirements(const MachineInstr &MI, Reqs.addExtension(SPIRV::Extension::SPV_EXT_shader_atomic_float_min_max); switch (BitWidth) { case 16: - Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT); + if (isBFloat16Type(TypeDef)) { + if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics)) + report_fatal_error( + "The atomic bfloat16 instruction requires the following SPIR-V " + "extension: SPV_INTEL_16bit_atomics", + false); + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_16bit_atomics); + Reqs.addCapability(SPIRV::Capability::AtomicBFloat16MinMaxINTEL); + } else { + Reqs.addCapability(SPIRV::Capability::AtomicFloat16MinMaxEXT); + } break; case 32: Reqs.addCapability(SPIRV::Capability::AtomicFloat32MinMaxEXT); @@ -1328,13 +1356,6 @@ void addPrintfRequirements(const MachineInstr &MI, } } -static bool isBFloat16Type(const SPIRVType *TypeDef) { - return TypeDef && TypeDef->getNumOperands() == 3 && - TypeDef->getOpcode() == SPIRV::OpTypeFloat && - TypeDef->getOperand(1).getImm() == 16 && - TypeDef->getOperand(2).getImm() == SPIRV::FPEncoding::BFloat16KHR; -} - void addInstrRequirements(const MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, const SPIRVSubtarget &ST) { diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 65a8885..f02a587 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -389,6 +389,7 @@ defm SPV_INTEL_predicated_io : ExtensionOperand<127, [EnvOpenCL]>; defm SPV_KHR_maximal_reconvergence : ExtensionOperand<128, [EnvVulkan]>; defm SPV_INTEL_bfloat16_arithmetic : ExtensionOperand<129, [EnvVulkan, EnvOpenCL]>; +defm SPV_INTEL_16bit_atomics : ExtensionOperand<130, [EnvVulkan, EnvOpenCL]>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -566,9 +567,11 @@ defm FloatControls2 defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>; defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>; defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>; +defm AtomicBFloat16AddINTEL : CapabilityOperand<6255, 0, 0, [SPV_INTEL_16bit_atomics], []>; defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>; defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>; defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>; +defm AtomicBFloat16MinMaxINTEL : CapabilityOperand<6256, 0, 0, [SPV_INTEL_16bit_atomics], []>; defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variable_length_array], []>; defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>; defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>; |
