aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVCBufferAccess.cpp
blob: f7fb886e7391db36fe5c2700d3afa8ca29b1b740 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
//===- SPIRVCBufferAccess.cpp - Translate CBuffer Loads ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all accesses to constant buffer global variables with
// accesses to the proper SPIR-V resource.
//
// The pass operates as follows:
// 1. It finds all constant buffers by looking for the `!hlsl.cbs` metadata.
// 2. For each cbuffer, it finds the global variable holding the resource handle
//    and the global variables for each of the cbuffer's members.
// 3. For each member variable, it creates a call to the
//    `llvm.spv.resource.getpointer` intrinsic. This intrinsic takes the
//    resource handle and the member's index within the cbuffer as arguments.
//    The result is a pointer to that member within the SPIR-V resource.
// 4. It then replaces all uses of the original member global variable with the
//    pointer returned by the `getpointer` intrinsic. This effectively retargets
//    all loads and GEPs to the new resource pointer.
// 5. Finally, it cleans up by deleting the original global variables and the
//    `!hlsl.cbs` metadata.
//
// This approach allows subsequent passes, like SPIRVEmitIntrinsics, to
// correctly handle GEPs that operate on the result of the `getpointer` call,
// folding them into a single OpAccessChain instruction.
//
//===----------------------------------------------------------------------===//

#include "SPIRVCBufferAccess.h"
#include "SPIRV.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ReplaceConstant.h"

#define DEBUG_TYPE "spirv-cbuffer-access"
using namespace llvm;

// Returns the instruction whose result is stored into the resource-handle
// global \p HandleVar, or nullptr if no user of \p HandleVar is a store of an
// instruction-defined value. Only the first such store (in use-list order) is
// considered. The definition is expected to be a call to
// `llvm.spv.resource.handlefrombinding`, but that is not checked here.
static Instruction *findHandleDef(GlobalVariable *HandleVar) {
  for (User *HandleUser : HandleVar->users()) {
    auto *Store = dyn_cast<StoreInst>(HandleUser);
    if (!Store)
      continue;

    // Stored constants or arguments are not handle definitions; keep looking.
    if (auto *Def = dyn_cast<Instruction>(Store->getValueOperand()))
      return Def;
  }
  return nullptr;
}

// Rewrites every use of a cbuffer member global into a use of the pointer
// returned by `llvm.spv.resource.getpointer`. It then deletes:
//   - the member globals,
//   - the stores to the handle global,
//   - the handle global itself,
//   - the `!hlsl.cbs` metadata.
//
// Returns false (module untouched) when there is no `!hlsl.cbs` metadata, and
// true otherwise. Calls report_fatal_error if a cbuffer's handle global has
// no store of an instruction-defined value.
static bool replaceCBufferAccesses(Module &M) {
  std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
  if (!CBufMD)
    return false;

  // Member globals may be used through ConstantExprs, e.g. a constant GEP
  // that selects a field of a struct-typed member. Such uses cannot point at
  // the (non-constant) getpointer call:
  //   - a ConstantExpr may only have constant operands;
  //   - User::replaceUsesOfWith must not be called on a constant.
  // Expand every ConstantExpr user into equivalent instructions first, so the
  // member globals' uses inside functions are plain instruction operands.
  SmallVector<Constant *> MemberGlobals;
  for (const hlsl::CBufferMapping &Mapping : *CBufMD)
    for (const auto &Member : Mapping.Members)
      MemberGlobals.push_back(Member.GV);
  convertUsersOfConstantsToInstructions(MemberGlobals);

  for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
    Instruction *HandleDef = findHandleDef(Mapping.Handle);
    if (!HandleDef) {
      report_fatal_error("Could not find handle definition for cbuffer: " +
                         Mapping.Handle->getName());
    }

    // The handle definition should dominate all uses of the cbuffer members.
    // We'll insert our getpointer calls right after it.
    IRBuilder<> Builder(HandleDef->getNextNode());

    for (uint32_t Index = 0; Index < Mapping.Members.size(); ++Index) {
      GlobalVariable *MemberGV = Mapping.Members[Index].GV;
      if (MemberGV->use_empty())
        continue;

      // The member's position within the cbuffer is the getpointer index.
      Value *IndexVal = Builder.getInt32(Index);
      Type *PtrType = MemberGV->getType();
      Value *GetPointerCall = Builder.CreateIntrinsic(
          PtrType, Intrinsic::spv_resource_getpointer, {HandleDef, IndexVal});

      // ConstantExpr users were expanded above, so every remaining in-function
      // user is an instruction and can be redirected directly.
      MemberGV->replaceAllUsesWith(GetPointerCall);
    }
  }

  // Now that all uses are replaced, clean up the globals and metadata.
  for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
    for (const auto &Member : Mapping.Members)
      Member.GV->eraseFromParent();

    // Erase the stores to the handle variable before erasing the handle
    // itself. Collect them first: erasing while iterating users() would
    // invalidate the use-list iterator.
    SmallVector<Instruction *, 4> HandleStores;
    for (User *U : Mapping.Handle->users())
      if (auto *SI = dyn_cast<StoreInst>(U))
        HandleStores.push_back(SI);
    for (Instruction *I : HandleStores)
      I->eraseFromParent();
    Mapping.Handle->eraseFromParent();
  }

  CBufMD->eraseFromModule();
  return true;
}

/// New-pass-manager entry point. Runs the cbuffer rewrite over \p M; if it
/// modified the module, no analyses are preserved, otherwise all of them are.
/// \p AM is not consulted.
PreservedAnalyses SPIRVCBufferAccess::run(Module &M,
                                          ModuleAnalysisManager &AM) {
  const bool Changed = replaceCBufferAccesses(M);
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

namespace {
/// Legacy-pass-manager wrapper that forwards to replaceCBufferAccesses.
class SPIRVCBufferAccessLegacy : public ModulePass {
public:
  static char ID; // Pass identification.

  SPIRVCBufferAccessLegacy() : ModulePass(ID) {}

  StringRef getPassName() const override { return "SPIRV CBuffer Access"; }

  /// Performs the cbuffer rewrite; returns true when \p M was modified.
  bool runOnModule(Module &M) override {
    const bool Modified = replaceCBufferAccesses(M);
    return Modified;
  }
};

char SPIRVCBufferAccessLegacy::ID = 0;
} // end anonymous namespace

// Registers the legacy pass with the PassRegistry under the command-line
// argument DEBUG_TYPE ("spirv-cbuffer-access") and the display name
// "SPIRV CBuffer Access". Per the INITIALIZE_PASS macro signature, the two
// trailing `false` arguments are the CFG-only and is-analysis flags.
INITIALIZE_PASS(SPIRVCBufferAccessLegacy, DEBUG_TYPE, "SPIRV CBuffer Access",
                false, false)

/// Factory for the legacy pass. Returns a newly heap-allocated pass; the
/// caller (presumably a legacy PassManager) takes ownership of it.
ModulePass *llvm::createSPIRVCBufferAccessLegacyPass() {
  auto *Pass = new SPIRVCBufferAccessLegacy();
  return Pass;
}