aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVBuiltins.td7
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp17
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp5
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVISelLowering.h3
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstrInfo.td9
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp36
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp16
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp16
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp11
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp10
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td4
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVUtils.cpp19
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVUtils.h3
13 files changed, 127 insertions, 29 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index e6e3560..28a63b9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -619,7 +619,8 @@ class GroupBuiltin<string name, Op operation> {
!eq(operation, OpGroupNonUniformShuffleDown),
!eq(operation, OpGroupBroadcast),
!eq(operation, OpGroupNonUniformBroadcast),
- !eq(operation, OpGroupNonUniformBroadcastFirst));
+ !eq(operation, OpGroupNonUniformBroadcastFirst),
+ !eq(operation, OpGroupNonUniformRotateKHR));
bit HasBoolArg = !or(!and(IsAllOrAny, !eq(IsAllEqual, false)), IsBallot, IsLogical);
}
@@ -877,6 +878,10 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
+// cl_khr_subgroup_rotate / SPV_KHR_subgroup_rotate
+defm : DemangledGroupBuiltin<"group_rotate", OnlySub, OpGroupNonUniformRotateKHR>;
+defm : DemangledGroupBuiltin<"group_clustered_rotate", OnlySub, OpGroupNonUniformRotateKHR>;
+
// cl_khr_work_group_uniform_arithmetic / SPV_KHR_uniform_group_instructions
defm : DemangledGroupBuiltin<"group_reduce_imul", OnlyWork, OpGroupIMulKHR>;
defm : DemangledGroupBuiltin<"group_reduce_mulu", OnlyWork, OpGroupIMulKHR>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cc438b2..10569ef 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder) {
+ MachineIRBuilder &MIRBuilder,
+ const SPIRVSubtarget &ST) {
// Read argument's access qualifier from metadata or default.
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
getArgAccessQual(F, ArgIdx);
@@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
if (MDTypeStr.ends_with("*"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
MDTypeStr, MIRBuilder,
- addressSpaceToStorageClass(
- OriginalArgType->getPointerAddressSpace()));
+ addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
+ ST));
else if (MDTypeStr.ends_with("_t"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
"opencl." + MDTypeStr.str(), MIRBuilder,
@@ -206,6 +207,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
assert(GR && "Must initialize the SPIRV type registry before lowering args.");
GR->setCurrentFunc(MIRBuilder.getMF());
+ // Get access to information about available extensions
+ const SPIRVSubtarget *ST =
+ static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+
// Assign types and names to all args, and store their types for later.
FunctionType *FTy = getOriginalFunctionType(F);
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
@@ -216,7 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
- auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+ auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
ArgTypeVRegs.push_back(SpirvTy);
@@ -318,10 +323,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (F.hasName())
buildOpName(FuncVReg, F.getName(), MIRBuilder);
- // Get access to information about available extensions
- const auto *ST =
- static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
-
// Handle entry points and function linkage.
if (isEntryPoint(F)) {
const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 47fec74..a1cb630 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -709,7 +709,10 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
// TODO: change the implementation once opaque pointers are supported
// in the SPIR-V specification.
SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
- auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
+ // Get access to information about available extensions
+ const SPIRVSubtarget *ST =
+ static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+ auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
// Null pointer means we have a loop in type definitions, make and
// return corresponding OpTypeForwardPointer.
if (SpvElementType == nullptr) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
index f317b26..d34f802 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
@@ -31,6 +31,9 @@ public:
return true;
}
+ // prevent creation of jump tables
+ bool areJTsAllowed(const Function *) const override { return false; }
+
// This is to prevent sexts of non-i64 vector indices which are generated
// within general IRTranslator hence type generation for it is omitted.
MVT getVectorIdxTy(const DataLayout &DL) const override {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 0f11bc3..7c5252e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -430,6 +430,10 @@ def OpGenericCastToPtrExplicit : Op<123, (outs ID:$r), (ins TYPE:$t, ID:$p, Stor
"$r = OpGenericCastToPtrExplicit $t $p $s">;
def OpBitcast : UnOp<"OpBitcast", 124>;
+// SPV_INTEL_usm_storage_classes
+def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
+def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
+
// 3.42.12 Composite Instructions
def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),
@@ -765,6 +769,11 @@ def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;
+// SPV_KHR_subgroup_rotate
+def OpGroupNonUniformRotateKHR: Op<4431, (outs ID:$res),
+ (ins TYPE:$type, ID:$scope, ID:$value, ID:$delta, variable_ops),
+ "$res = OpGroupNonUniformRotateKHR $type $scope $value $delta">;
+
// 3.49.7, Constant-Creation Instructions
// - SPV_INTEL_function_pointers
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 53d19a1..7258d3b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -828,8 +828,18 @@ static bool isGenericCastablePtr(SPIRV::StorageClass::StorageClass SC) {
}
}
+static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
+ switch (SC) {
+ case SPIRV::StorageClass::DeviceOnlyINTEL:
+ case SPIRV::StorageClass::HostOnlyINTEL:
+ return true;
+ default:
+ return false;
+ }
+}
+
// In SPIR-V address space casting can only happen to and from the Generic
-// storage class. We can also only case Workgroup, CrossWorkgroup, or Function
+// storage class. We can also only cast Workgroup, CrossWorkgroup, or Function
// pointers to and from Generic pointers. As such, we can convert e.g. from
// Workgroup to Function by going via a Generic pointer as an intermediary. All
// other combinations can only be done by a bitcast, and are probably not safe.
@@ -862,13 +872,17 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
- // Casting from an eligable pointer to Generic.
+ // don't generate a cast between identical storage classes
+ if (SrcSC == DstSC)
+ return true;
+
+ // Casting from an eligible pointer to Generic.
if (DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC))
return selectUnOp(ResVReg, ResType, I, SPIRV::OpPtrCastToGeneric);
- // Casting from Generic to an eligable pointer.
+ // Casting from Generic to an eligible pointer.
if (SrcSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(DstSC))
return selectUnOp(ResVReg, ResType, I, SPIRV::OpGenericCastToPtr);
- // Casting between 2 eligable pointers using Generic as an intermediary.
+ // Casting between 2 eligible pointers using Generic as an intermediary.
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
@@ -886,6 +900,16 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
.addUse(Tmp)
.constrainAllUses(TII, TRI, RBI);
}
+
+ // Check if instructions from the SPV_INTEL_usm_storage_classes extension may
+ // be applied
+ if (isUSMStorageClass(SrcSC) && DstSC == SPIRV::StorageClass::CrossWorkgroup)
+ return selectUnOp(ResVReg, ResType, I,
+ SPIRV::OpPtrCastToCrossWorkgroupINTEL);
+ if (SrcSC == SPIRV::StorageClass::CrossWorkgroup && isUSMStorageClass(DstSC))
+ return selectUnOp(ResVReg, ResType, I,
+ SPIRV::OpCrossWorkgroupCastToPtrINTEL);
+
// TODO Should this case just be disallowed completely?
// We're casting 2 other arbitrary address spaces, so have to bitcast.
return selectUnOp(ResVReg, ResType, I, SPIRV::OpBitcast);
@@ -1545,7 +1569,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
}
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(
PointerBaseType, I, TII,
- addressSpaceToStorageClass(GV->getAddressSpace()));
+ addressSpaceToStorageClass(GV->getAddressSpace(), STI));
std::string GlobalIdent;
if (!GV->hasName()) {
@@ -1618,7 +1642,7 @@ bool SPIRVInstructionSelector::selectGlobalValue(
unsigned AddrSpace = GV->getAddressSpace();
SPIRV::StorageClass::StorageClass Storage =
- addressSpaceToStorageClass(AddrSpace);
+ addressSpaceToStorageClass(AddrSpace, STI);
bool HasLnkTy = GV->getLinkage() != GlobalValue::InternalLinkage &&
Storage != SPIRV::StorageClass::Function;
SPIRV::LinkageType::LinkageType LnkType =
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
index 011a550..4f2e7a2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
@@ -102,14 +102,16 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
const LLT p3 = LLT::pointer(3, PSize); // Workgroup
const LLT p4 = LLT::pointer(4, PSize); // Generic
- const LLT p5 = LLT::pointer(5, PSize); // Input
+ const LLT p5 =
+ LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)
+ const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)
// TODO: remove copy-pasting here by using concatenation in some way.
auto allPtrsScalarsAndVectors = {
- p0, p1, p2, p3, p4, p5, s1, s8, s16,
- s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
- v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
- v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
+ p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,
+ s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,
+ v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,
+ v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
@@ -133,8 +135,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
auto allFloatAndIntScalars = allIntScalars;
- auto allPtrs = {p0, p1, p2, p3, p4, p5};
- auto allWritablePtrs = {p0, p1, p3, p4};
+ auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
+ auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};
for (auto Opc : TypeFoldingSupportingOpcs)
getActionDefinitionsBuilder(Opc).custom();
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index dbda287..3be28c9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1063,12 +1063,28 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
}
break;
+ case SPIRV::OpPtrCastToCrossWorkgroupINTEL:
+ case SPIRV::OpCrossWorkgroupCastToPtrINTEL:
+ if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes);
+ Reqs.addCapability(SPIRV::Capability::USMStorageClassesINTEL);
+ }
+ break;
case SPIRV::OpConstantFunctionPointerINTEL:
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers)) {
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
Reqs.addCapability(SPIRV::Capability::FunctionPointersINTEL);
}
break;
+ case SPIRV::OpGroupNonUniformRotateKHR:
+ if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate))
+ report_fatal_error("OpGroupNonUniformRotateKHR instruction requires the "
+ "following SPIR-V extension: SPV_KHR_subgroup_rotate",
+ false);
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_subgroup_rotate);
+ Reqs.addCapability(SPIRV::Capability::GroupNonUniformRotateKHR);
+ Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
+ break;
case SPIRV::OpGroupIMulKHR:
case SPIRV::OpGroupFMulKHR:
case SPIRV::OpGroupBitwiseAndKHR:
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index cbc16fa..1442168 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -122,6 +122,9 @@ static void foldConstantsIntoIntrinsics(MachineFunction &MF) {
static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
+ // Get access to information about available extensions
+ const SPIRVSubtarget *ST =
+ static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
SmallVector<MachineInstr *, 10> ToErase;
for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
@@ -141,7 +144,7 @@ static void insertBitcasts(MachineFunction &MF, SPIRVGlobalRegistry *GR,
getMDOperandAsType(MI.getOperand(3).getMetadata(), 0), MIB);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
- addressSpaceToStorageClass(MI.getOperand(4).getImm()));
+ addressSpaceToStorageClass(MI.getOperand(4).getImm(), *ST));
// If the bitcast would be redundant, replace all uses with the source
// register.
@@ -250,6 +253,10 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
+ // Get access to information about available extensions
+ const SPIRVSubtarget *ST =
+ static_cast<const SPIRVSubtarget *>(&MIB.getMF().getSubtarget());
+
MachineRegisterInfo &MRI = MF.getRegInfo();
SmallVector<MachineInstr *, 10> ToErase;
@@ -269,7 +276,7 @@ static void generateAssignInstrs(MachineFunction &MF, SPIRVGlobalRegistry *GR,
getMDOperandAsType(MI.getOperand(2).getMetadata(), 0), MIB);
SPIRVType *AssignedPtrType = GR->getOrCreateSPIRVPointerType(
BaseTy, MI, *MF.getSubtarget<SPIRVSubtarget>().getInstrInfo(),
- addressSpaceToStorageClass(MI.getOperand(3).getImm()));
+ addressSpaceToStorageClass(MI.getOperand(3).getImm(), *ST));
MachineInstr *Def = MRI.getVRegDef(Reg);
assert(Def && "Expecting an instruction that defines the register");
insertAssignInstr(Reg, nullptr, AssignedPtrType, GR, MIB,
diff --git a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
index e186154..79f1614 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp
@@ -49,6 +49,12 @@ cl::list<SPIRV::Extension::Extension> Extensions(
clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
"Adds OptNoneINTEL value for Function Control mask that "
"indicates a request to not optimize the function."),
+ clEnumValN(SPIRV::Extension::SPV_INTEL_usm_storage_classes,
+ "SPV_INTEL_usm_storage_classes",
+ "Introduces two new storage classes that are sub classes of "
+ "the CrossWorkgroup storage class "
+ "that provides additional information that can enable "
+ "optimization."),
clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
"Allows work items in a subgroup to share data without the "
"use of local memory and work group barriers, and to "
@@ -75,6 +81,10 @@ cl::list<SPIRV::Extension::Extension> Extensions(
"Allows to use the LinkOnceODR linkage type that is to let "
"a function or global variable to be merged with other functions "
"or global variables of the same name when linkage occurs."),
+ clEnumValN(SPIRV::Extension::SPV_KHR_subgroup_rotate,
+ "SPV_KHR_subgroup_rotate",
+ "Adds a new instruction that enables rotating values across "
+ "invocations within a subgroup."),
clEnumValN(SPIRV::Extension::SPV_INTEL_function_pointers,
"SPV_INTEL_function_pointers",
"Allows translation of function pointers.")));
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 4e5ac0d..b022b97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -455,6 +455,7 @@ defm BitInstructions : CapabilityOperand<6025, 0, 0, [SPV_KHR_bit_instructions],
defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []>;
defm FunctionPointersINTEL : CapabilityOperand<5603, 0, 0, [SPV_INTEL_function_pointers], []>;
defm IndirectReferencesINTEL : CapabilityOperand<5604, 0, 0, [SPV_INTEL_function_pointers], []>;
+defm GroupNonUniformRotateKHR : CapabilityOperand<6026, 0, 0, [SPV_KHR_subgroup_rotate], [GroupNonUniform]>;
defm AtomicFloat32AddEXT : CapabilityOperand<6033, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat64AddEXT : CapabilityOperand<6034, 0, 0, [SPV_EXT_shader_atomic_float_add], []>;
defm AtomicFloat16AddEXT : CapabilityOperand<6095, 0, 0, [SPV_EXT_shader_atomic_float16_add], []>;
@@ -462,6 +463,7 @@ defm AtomicFloat16MinMaxEXT : CapabilityOperand<5616, 0, 0, [SPV_EXT_shader_atom
defm AtomicFloat32MinMaxEXT : CapabilityOperand<5612, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm AtomicFloat64MinMaxEXT : CapabilityOperand<5613, 0, 0, [SPV_EXT_shader_atomic_float_min_max], []>;
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
+defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
@@ -699,6 +701,8 @@ defm IncomingRayPayloadNV : StorageClassOperand<5342, [RayTracingNV]>;
defm ShaderRecordBufferNV : StorageClassOperand<5343, [RayTracingNV]>;
defm PhysicalStorageBufferEXT : StorageClassOperand<5349, [PhysicalStorageBufferAddressesEXT]>;
defm CodeSectionINTEL : StorageClassOperand<5605, [FunctionPointersINTEL]>;
+defm DeviceOnlyINTEL : StorageClassOperand<5936, [USMStorageClassesINTEL]>;
+defm HostOnlyINTEL : StorageClassOperand<5937, [USMStorageClassesINTEL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Dim enum values and at the same time
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 05f766d..169d7cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -14,6 +14,7 @@
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
+#include "SPIRVSubtarget.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -146,15 +147,19 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) {
return 3;
case SPIRV::StorageClass::Generic:
return 4;
+ case SPIRV::StorageClass::DeviceOnlyINTEL:
+ return 5;
+ case SPIRV::StorageClass::HostOnlyINTEL:
+ return 6;
case SPIRV::StorageClass::Input:
return 7;
default:
- llvm_unreachable("Unable to get address space id");
+ report_fatal_error("Unable to get address space id");
}
}
SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace) {
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
switch (AddrSpace) {
case 0:
return SPIRV::StorageClass::Function;
@@ -166,10 +171,18 @@ addressSpaceToStorageClass(unsigned AddrSpace) {
return SPIRV::StorageClass::Workgroup;
case 4:
return SPIRV::StorageClass::Generic;
+ case 5:
+ return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+ ? SPIRV::StorageClass::DeviceOnlyINTEL
+ : SPIRV::StorageClass::CrossWorkgroup;
+ case 6:
+ return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes)
+ ? SPIRV::StorageClass::HostOnlyINTEL
+ : SPIRV::StorageClass::CrossWorkgroup;
case 7:
return SPIRV::StorageClass::Input;
default:
- llvm_unreachable("Unknown address space");
+ report_fatal_error("Unknown address space");
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index a33dc02..1af53dc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -27,6 +27,7 @@ class MachineRegisterInfo;
class Register;
class StringRef;
class SPIRVInstrInfo;
+class SPIRVSubtarget;
// Add the given string as a series of integer operand, inserting null
// terminators and padding to make sure the operands all have 32-bit
@@ -62,7 +63,7 @@ unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC);
// Convert an LLVM IR address space to a SPIR-V storage class.
SPIRV::StorageClass::StorageClass
-addressSpaceToStorageClass(unsigned AddrSpace);
+addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI);
SPIRV::MemorySemantics::MemorySemantics
getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC);