aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp17
1 files changed, 9 insertions, 8 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index cc438b2..10569ef 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -150,7 +150,8 @@ getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
SPIRVGlobalRegistry *GR,
- MachineIRBuilder &MIRBuilder) {
+ MachineIRBuilder &MIRBuilder,
+ const SPIRVSubtarget &ST) {
// Read argument's access qualifier from metadata or default.
SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
getArgAccessQual(F, ArgIdx);
@@ -169,8 +170,8 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
if (MDTypeStr.ends_with("*"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
MDTypeStr, MIRBuilder,
- addressSpaceToStorageClass(
- OriginalArgType->getPointerAddressSpace()));
+ addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace(),
+ ST));
else if (MDTypeStr.ends_with("_t"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
"opencl." + MDTypeStr.str(), MIRBuilder,
@@ -206,6 +207,10 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
assert(GR && "Must initialize the SPIRV type registry before lowering args.");
GR->setCurrentFunc(MIRBuilder.getMF());
+ // Get access to information about available extensions
+ const SPIRVSubtarget *ST =
+ static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+
// Assign types and names to all args, and store their types for later.
FunctionType *FTy = getOriginalFunctionType(F);
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
@@ -216,7 +221,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
- auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
+ auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
ArgTypeVRegs.push_back(SpirvTy);
@@ -318,10 +323,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
if (F.hasName())
buildOpName(FuncVReg, F.getName(), MIRBuilder);
- // Get access to information about available extensions
- const auto *ST =
- static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
-
// Handle entry points and function linkage.
if (isEntryPoint(F)) {
const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();