//===- SPIRVCBufferAccess.cpp - Translate CBuffer Loads ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all accesses to constant buffer global variables with
// accesses to the proper SPIR-V resource.
//
// The pass operates as follows:
// 1. It finds all constant buffers by looking for the `!hlsl.cbs` metadata.
// 2. For each cbuffer, it finds the global variable holding the resource handle
//    and the global variables for each of the cbuffer's members.
// 3. For each member variable, it creates a call to the
//    `llvm.spv.resource.getpointer` intrinsic. This intrinsic takes the
//    resource handle and the member's index within the cbuffer as arguments.
//    The result is a pointer to that member within the SPIR-V resource.
// 4. It then replaces all uses of the original member global variable with the
//    pointer returned by the `getpointer` intrinsic. This effectively retargets
//    all loads and GEPs to the new resource pointer.
// 5. Finally, it cleans up by deleting the original global variables and the
//    `!hlsl.cbs` metadata.
//
// This approach allows subsequent passes, like SPIRVEmitIntrinsics, to
// correctly handle GEPs that operate on the result of the `getpointer` call,
// folding them into a single OpAccessChain instruction.
//
//===----------------------------------------------------------------------===//

#include "SPIRVCBufferAccess.h"
#include "SPIRV.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ReplaceConstant.h"

#define DEBUG_TYPE "spirv-cbuffer-access"

using namespace llvm;

// Finds the single instruction that defines the resource handle.
This is // typically a call to `llvm.spv.resource.handlefrombinding`. static Instruction *findHandleDef(GlobalVariable *HandleVar) { for (User *U : HandleVar->users()) { if (auto *SI = dyn_cast(U)) { if (auto *I = dyn_cast(SI->getValueOperand())) { return I; } } } return nullptr; } static bool replaceCBufferAccesses(Module &M) { std::optional CBufMD = hlsl::CBufferMetadata::get(M); if (!CBufMD) return false; for (const hlsl::CBufferMapping &Mapping : *CBufMD) { Instruction *HandleDef = findHandleDef(Mapping.Handle); if (!HandleDef) { report_fatal_error("Could not find handle definition for cbuffer: " + Mapping.Handle->getName()); } // The handle definition should dominate all uses of the cbuffer members. // We'll insert our getpointer calls right after it. IRBuilder<> Builder(HandleDef->getNextNode()); for (uint32_t Index = 0; Index < Mapping.Members.size(); ++Index) { GlobalVariable *MemberGV = Mapping.Members[Index].GV; if (MemberGV->use_empty()) { continue; } // Create the getpointer intrinsic call. Value *IndexVal = Builder.getInt32(Index); Type *PtrType = MemberGV->getType(); Value *GetPointerCall = Builder.CreateIntrinsic( PtrType, Intrinsic::spv_resource_getpointer, {HandleDef, IndexVal}); // We cannot use replaceAllUsesWith here because some uses may be // ConstantExprs, which cannot be replaced with non-constants. SmallVector Users(MemberGV->users()); for (User *U : Users) { U->replaceUsesOfWith(MemberGV, GetPointerCall); } } } // Now that all uses are replaced, clean up the globals and metadata. for (const hlsl::CBufferMapping &Mapping : *CBufMD) { for (const auto &Member : Mapping.Members) { Member.GV->eraseFromParent(); } // Erase the stores to the handle variable before erasing the handle itself. 
SmallVector HandleStores; for (User *U : Mapping.Handle->users()) { if (auto *SI = dyn_cast(U)) { HandleStores.push_back(SI); } } for (Instruction *I : HandleStores) { I->eraseFromParent(); } Mapping.Handle->eraseFromParent(); } CBufMD->eraseFromModule(); return true; } PreservedAnalyses SPIRVCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) { if (replaceCBufferAccesses(M)) { return PreservedAnalyses::none(); } return PreservedAnalyses::all(); } namespace { class SPIRVCBufferAccessLegacy : public ModulePass { public: bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); } StringRef getPassName() const override { return "SPIRV CBuffer Access"; } SPIRVCBufferAccessLegacy() : ModulePass(ID) {} static char ID; // Pass identification. }; char SPIRVCBufferAccessLegacy::ID = 0; } // end anonymous namespace INITIALIZE_PASS(SPIRVCBufferAccessLegacy, DEBUG_TYPE, "SPIRV CBuffer Access", false, false) ModulePass *llvm::createSPIRVCBufferAccessLegacyPass() { return new SPIRVCBufferAccessLegacy(); }