diff options
Diffstat (limited to 'llvm/lib/Target/SPIRV')
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 11 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVCBufferAccess.cpp | 14 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 14 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 14 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h | 14 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 71 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h | 4 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVUtils.cpp | 15 | ||||
-rw-r--r-- | llvm/lib/Target/SPIRV/SPIRVUtils.h | 3 |
10 files changed, 87 insertions, 79 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index dbe8e18..d91923b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -507,7 +507,9 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister, static Register buildBuiltinVariableLoad( MachineIRBuilder &MIRBuilder, SPIRVType *VariableType, SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType, - Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) { + Register Reg = Register(0), bool isConst = true, + const std::optional<SPIRV::LinkageType::LinkageType> &LinkageTy = { + SPIRV::LinkageType::Import}) { Register NewRegister = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::pIDRegClass); MIRBuilder.getMRI()->setType( @@ -521,9 +523,8 @@ static Register buildBuiltinVariableLoad( // Set up the global OpVariable with the necessary builtin decorations. Register Variable = GR->buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr, - SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst, - /* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder, - false); + SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst, LinkageTy, + MIRBuilder, false); // Load the value from the global variable. Register LoadedRegister = @@ -1851,7 +1852,7 @@ static bool generateWaveInst(const SPIRV::IncomingCall *Call, return buildBuiltinVariableLoad( MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister, - /* isConst= */ false, /* hasLinkageTy= */ false); + /* isConst= */ false, /* LinkageType= */ std::nullopt); } // We expect a builtin diff --git a/llvm/lib/Target/SPIRV/SPIRVCBufferAccess.cpp b/llvm/lib/Target/SPIRV/SPIRVCBufferAccess.cpp index f7fb886..3ca0b40 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCBufferAccess.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCBufferAccess.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicsSPIRV.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ReplaceConstant.h" #define DEBUG_TYPE "spirv-cbuffer-access" using namespace llvm; @@ -57,6 +58,12 @@ static bool replaceCBufferAccesses(Module &M) { if (!CBufMD) return false; + SmallVector<Constant *> CBufferGlobals; + for (const hlsl::CBufferMapping &Mapping : *CBufMD) + for (const hlsl::CBufferMember &Member : Mapping.Members) + CBufferGlobals.push_back(Member.GV); + convertUsersOfConstantsToInstructions(CBufferGlobals); + for (const hlsl::CBufferMapping &Mapping : *CBufMD) { Instruction *HandleDef = findHandleDef(Mapping.Handle); if (!HandleDef) { @@ -80,12 +87,7 @@ static bool replaceCBufferAccesses(Module &M) { Value *GetPointerCall = Builder.CreateIntrinsic( PtrType, Intrinsic::spv_resource_getpointer, {HandleDef, IndexVal}); - // We cannot use replaceAllUsesWith here because some uses may be - // ConstantExprs, which cannot be replaced with non-constants. - SmallVector<User *, 4> Users(MemberGV->users()); - for (User *U : Users) { - U->replaceUsesOfWith(MemberGV, GetPointerCall); - } + MemberGV->replaceAllUsesWith(GetPointerCall); } } diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 1a7c02c..9e11c3a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -479,19 +479,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, .addImm(static_cast<uint32_t>(getExecutionModel(*ST, F))) .addUse(FuncVReg); addStringImm(F.getName(), MIB); - } else if (F.getLinkage() != GlobalValue::InternalLinkage && - F.getLinkage() != GlobalValue::PrivateLinkage && - F.getVisibility() != GlobalValue::HiddenVisibility) { - SPIRV::LinkageType::LinkageType LnkTy = - F.isDeclaration() - ? SPIRV::LinkageType::Import - : (F.getLinkage() == GlobalValue::LinkOnceODRLinkage && - ST->canUseExtension( - SPIRV::Extension::SPV_KHR_linkonce_odr) - ? SPIRV::LinkageType::LinkOnceODR - : SPIRV::LinkageType::Export); + } else if (const auto LnkTy = getSpirvLinkageTypeFor(*ST, F)) { buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, - {static_cast<uint32_t>(LnkTy)}, F.getName()); + {static_cast<uint32_t>(*LnkTy)}, F.getName()); } // Handle function pointers decoration diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 6fd1c7e..6181abb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -712,9 +712,9 @@ SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode, Register SPIRVGlobalRegistry::buildGlobalVariable( Register ResVReg, SPIRVType *BaseType, StringRef Name, const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, - const MachineInstr *Init, bool IsConst, bool HasLinkageTy, - SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, - bool IsInstSelector) { + const MachineInstr *Init, bool IsConst, + const std::optional<SPIRV::LinkageType::LinkageType> &LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector) { const GlobalVariable *GVar = nullptr; if (GV) { GVar = cast<const GlobalVariable>(GV); @@ -792,9 +792,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable( buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); } - if (HasLinkageTy) + if (LinkageType) buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, - {static_cast<uint32_t>(LinkageType)}, Name); + {static_cast<uint32_t>(*LinkageType)}, Name); SPIRV::BuiltIn::BuiltIn BuiltInId; if (getSpirvBuiltInIdByName(Name, BuiltInId)) @@ -821,8 +821,8 @@ Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding( MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); buildGlobalVariable(VarReg, VarType, Name, nullptr, - getPointerStorageClass(VarType), nullptr, false, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + getPointerStorageClass(VarType), nullptr, false, + std::nullopt, MIRBuilder, false); buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set}); buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding}); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index a648def..c230e62 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -548,14 +548,12 @@ public: MachineIRBuilder &MIRBuilder); Register getOrCreateUndef(MachineInstr &I, SPIRVType *SpvType, const SPIRVInstrInfo &TII); - Register buildGlobalVariable(Register Reg, SPIRVType *BaseType, - StringRef Name, const GlobalValue *GV, - SPIRV::StorageClass::StorageClass Storage, - const MachineInstr *Init, bool IsConst, - bool HasLinkageTy, - SPIRV::LinkageType::LinkageType LinkageType, - MachineIRBuilder &MIRBuilder, - bool IsInstSelector); + Register buildGlobalVariable( + Register Reg, SPIRVType *BaseType, StringRef Name, const GlobalValue *GV, + SPIRV::StorageClass::StorageClass Storage, const MachineInstr *Init, + bool IsConst, + const std::optional<SPIRV::LinkageType::LinkageType> &LinkageType, + MachineIRBuilder &MIRBuilder, bool IsInstSelector); Register getOrCreateGlobalVariableWithBinding(const SPIRVType *VarType, uint32_t Set, uint32_t Binding, StringRef Name, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index a0cff4d..021353a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -355,9 +355,9 @@ private: SPIRVType *widenTypeToVec4(const SPIRVType *Type, MachineInstr &I) const; bool extractSubvector(Register &ResVReg, const SPIRVType *ResType, Register &ReadReg, MachineInstr &InsertionPoint) const; - bool generateImageRead(Register &ResVReg, const SPIRVType *ResType, - Register ImageReg, Register IdxReg, DebugLoc Loc, - MachineInstr &Pos) const; + bool generateImageReadOrFetch(Register &ResVReg, const SPIRVType *ResType, + Register ImageReg, Register IdxReg, + DebugLoc Loc, MachineInstr &Pos) const; bool BuildCOPY(Register DestReg, Register SrcReg, MachineInstr &I) const; bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg, const SPIRVType *ResType, @@ -1321,8 +1321,8 @@ bool SPIRVInstructionSelector::selectLoad(Register ResVReg, } Register IdxReg = IntPtrDef->getOperand(3).getReg(); - return generateImageRead(ResVReg, ResType, NewHandleReg, IdxReg, - I.getDebugLoc(), I); + return generateImageReadOrFetch(ResVReg, ResType, NewHandleReg, IdxReg, + I.getDebugLoc(), I); } } @@ -3639,27 +3639,33 @@ bool SPIRVInstructionSelector::selectReadImageIntrinsic( DebugLoc Loc = I.getDebugLoc(); MachineInstr &Pos = I; - return generateImageRead(ResVReg, ResType, NewImageReg, IdxReg, Loc, Pos); + return generateImageReadOrFetch(ResVReg, ResType, NewImageReg, IdxReg, Loc, + Pos); } -bool SPIRVInstructionSelector::generateImageRead(Register &ResVReg, - const SPIRVType *ResType, - Register ImageReg, - Register IdxReg, DebugLoc Loc, - MachineInstr &Pos) const { +bool SPIRVInstructionSelector::generateImageReadOrFetch( + Register &ResVReg, const SPIRVType *ResType, Register ImageReg, + Register IdxReg, DebugLoc Loc, MachineInstr &Pos) const { SPIRVType *ImageType = GR.getSPIRVTypeForVReg(ImageReg); assert(ImageType && ImageType->getOpcode() == SPIRV::OpTypeImage && "ImageReg is not an image type."); + bool IsSignedInteger = sampledTypeIsSignedInteger(GR.getTypeForSPIRVType(ImageType)); + // Check if the "sampled" operand of the image type is 1. + // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpImageFetch + auto SampledOp = ImageType->getOperand(6); + bool IsFetch = (SampledOp.getImm() == 1); uint64_t ResultSize = GR.getScalarOrVectorComponentCount(ResType); if (ResultSize == 4) { - auto BMI = BuildMI(*Pos.getParent(), Pos, Loc, TII.get(SPIRV::OpImageRead)) - .addDef(ResVReg) - .addUse(GR.getSPIRVTypeID(ResType)) - .addUse(ImageReg) - .addUse(IdxReg); + auto BMI = + BuildMI(*Pos.getParent(), Pos, Loc, + TII.get(IsFetch ? SPIRV::OpImageFetch : SPIRV::OpImageRead)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)) + .addUse(ImageReg) + .addUse(IdxReg); if (IsSignedInteger) BMI.addImm(0x1000); // SignExtend @@ -3668,11 +3674,13 @@ bool SPIRVInstructionSelector::generateImageRead(Register &ResVReg, SPIRVType *ReadType = widenTypeToVec4(ResType, Pos); Register ReadReg = MRI->createVirtualRegister(GR.getRegClass(ReadType)); - auto BMI = BuildMI(*Pos.getParent(), Pos, Loc, TII.get(SPIRV::OpImageRead)) - .addDef(ReadReg) - .addUse(GR.getSPIRVTypeID(ReadType)) - .addUse(ImageReg) - .addUse(IdxReg); + auto BMI = + BuildMI(*Pos.getParent(), Pos, Loc, + TII.get(IsFetch ? SPIRV::OpImageFetch : SPIRV::OpImageRead)) + .addDef(ReadReg) + .addUse(GR.getSPIRVTypeID(ReadType)) + .addUse(ImageReg) + .addUse(IdxReg); if (IsSignedInteger) BMI.addImm(0x1000); // SignExtend bool Succeed = BMI.constrainAllUses(TII, TRI, RBI); @@ -4350,15 +4358,8 @@ bool SPIRVInstructionSelector::selectGlobalValue( if (hasInitializer(GlobalVar) && !Init) return true; - bool HasLnkTy = !GV->hasInternalLinkage() && !GV->hasPrivateLinkage() && - !GV->hasHiddenVisibility(); - SPIRV::LinkageType::LinkageType LnkType = - GV->isDeclarationForLinker() - ? SPIRV::LinkageType::Import - : (GV->hasLinkOnceODRLinkage() && - STI.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr) - ? SPIRV::LinkageType::LinkOnceODR - : SPIRV::LinkageType::Export); + const std::optional<SPIRV::LinkageType::LinkageType> LnkType = + getSpirvLinkageTypeFor(STI, *GV); const unsigned AddrSpace = GV->getAddressSpace(); SPIRV::StorageClass::StorageClass StorageClass = @@ -4366,7 +4367,7 @@ bool SPIRVInstructionSelector::selectGlobalValue( SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass); Register Reg = GR.buildGlobalVariable( ResVReg, ResType, GlobalIdent, GV, StorageClass, Init, - GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true); + GlobalVar->isConstant(), LnkType, MIRBuilder, true); return Reg.isValid(); } @@ -4517,8 +4518,8 @@ bool SPIRVInstructionSelector::loadVec3BuiltinInputID( // builtin variable. Register Variable = GR.buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, - SPIRV::StorageClass::Input, nullptr, true, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder, + false); // Create new register for loading value. MachineRegisterInfo *MRI = MIRBuilder.getMRI(); @@ -4570,8 +4571,8 @@ bool SPIRVInstructionSelector::loadBuiltinInputID( // builtin variable. Register Variable = GR.buildGlobalVariable( NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, - SPIRV::StorageClass::Input, nullptr, true, false, - SPIRV::LinkageType::Import, MIRBuilder, false); + SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder, + false); // Load uint value from the global variable. auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad)) diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 61a0bbe..f7cdfcb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -547,9 +547,9 @@ void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI, if (MI.getOpcode() == SPIRV::OpDecorate) { // If it's got Import linkage. auto Dec = MI.getOperand(1).getImm(); - if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) { + if (Dec == SPIRV::Decoration::LinkageAttributes) { auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm(); - if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) { + if (Lnk == SPIRV::LinkageType::Import) { // Map imported function name to function ID register. const Function *ImportedFunc = F->getParent()->getFunction(getStringImm(MI, 2)); @@ -635,7 +635,7 @@ static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) { InstrTraces IS; for (auto F = M.begin(), E = M.end(); F != E; ++F) { - if ((*F).isDeclaration()) + if (F->isDeclaration()) continue; MachineFunction *MF = MMI->getMachineFunction(*F); assert(MF); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h index d8376cd..2d19f6de 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -169,9 +169,7 @@ struct ModuleAnalysisInfo { MCRegister getFuncReg(const Function *F) { assert(F && "Function is null"); - auto FuncPtrRegPair = FuncMap.find(F); - return FuncPtrRegPair == FuncMap.end() ? MCRegister() - : FuncPtrRegPair->second; + return FuncMap.lookup(F); } MCRegister getExtInstSetReg(unsigned SetNum) { return ExtInstSetMap[SetNum]; } InstrList &getMSInstrs(unsigned MSType) { return MS[MSType]; } diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp index 1d47c89..4e2cc88 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp @@ -1040,4 +1040,19 @@ getFirstValidInstructionInsertPoint(MachineBasicBlock &BB) { : VarPos; } +std::optional<SPIRV::LinkageType::LinkageType> +getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV) { + if (GV.hasLocalLinkage() || GV.hasHiddenVisibility()) + return std::nullopt; + + if (GV.isDeclarationForLinker()) + return SPIRV::LinkageType::Import; + + if (GV.hasLinkOnceODRLinkage() && + ST.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr)) + return SPIRV::LinkageType::LinkOnceODR; + + return SPIRV::LinkageType::Export; +} + } // namespace llvm diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index 5777a24..99d9d40 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -559,5 +559,8 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI, const MachineInstr *ResType); MachineBasicBlock::iterator getFirstValidInstructionInsertPoint(MachineBasicBlock &BB); + +std::optional<SPIRV::LinkageType::LinkageType> +getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV); } // namespace llvm #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H |