aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVPushConstantAccess.cpp
blob: 8ec05271fc8fbfb2cb15d352f1b20d9f05df8605 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
//===- SPIRVPushConstantAccess.cpp - Translate Push Constants ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass changes the types of all the globals in the PushConstant
// address space into a target extension type, and makes all references
// to those globals go through a custom SPIR-V intrinsic.
//
// This allows the backend to properly lower the push constant struct type
// to a fully laid out type, and generate the proper OpAccessChain.
//
//===----------------------------------------------------------------------===//

#include "SPIRVPushConstantAccess.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ReplaceConstant.h"

#define DEBUG_TYPE "spirv-pushconstant-access"
using namespace llvm;

/// Replaces every global variable in the PushConstant address space with a
/// global of type `target("spirv.PushConstant", <original value type>)`, and
/// rewrites every use of the old global to go through a call to the
/// `spv_pushconstant_getpointer` intrinsic on the new global.
///
/// \param M  Module whose PushConstant globals are rewritten in place.
/// \param GR SPIR-V global registry; each pointer returned by the intrinsic is
///           registered with the original global's value type as its pointee.
/// \returns true if at least one global was replaced.
static bool replacePushConstantAccesses(Module &M, SPIRVGlobalRegistry *GR) {
  bool Changed = false;
  const unsigned PushConstantAS =
      storageClassToAddressSpace(SPIRV::StorageClass::PushConstant);

  for (GlobalVariable &GV : make_early_inc_range(M.globals())) {
    if (GV.getAddressSpace() != PushConstantAS)
      continue;

    // Constant-expression users (e.g. a constant GEP selecting a struct
    // member) cannot have their operand replaced by an intrinsic call, so
    // expand them into equivalent instructions first. This also removes dead
    // constant users, so afterwards every user of GV is expected to be an
    // instruction.
    Constant *GVAsConstant = &GV;
    convertUsersOfConstantsToInstructions(GVAsConstant);

    Type *PCType = TargetExtType::get(M.getContext(), "spirv.PushConstant",
                                      {GV.getValueType()});
    // Create the replacement unnamed and then take GV's name. Passing
    // GV.getName() while GV is still in the module would make the symbol
    // table hand NewGV a uniqued (suffixed) name instead.
    auto *NewGV = new GlobalVariable(
        M, PCType, GV.isConstant(), GV.getLinkage(),
        /*Initializer=*/nullptr, /*Name=*/"", /*InsertBefore=*/&GV,
        GV.getThreadLocalMode(), GV.getAddressSpace(),
        GV.isExternallyInitialized());
    NewGV->takeName(&GV);
    NewGV->setVisibility(GV.getVisibility());

    // Walk individual uses rather than users. If one instruction references
    // GV in several operands, each operand is rewritten exactly once. The
    // early-inc iterator also never follows a use that has been moved onto
    // the new call's use list.
    for (Use &U : make_early_inc_range(GV.uses())) {
      auto *I = cast<Instruction>(U.getUser());

      // Nothing may be inserted before a PHI; the value for a PHI operand
      // must instead be available at the end of the matching incoming block.
      Instruction *InsertPt = I;
      if (auto *Phi = dyn_cast<PHINode>(I))
        InsertPt = Phi->getIncomingBlock(U)->getTerminator();

      IRBuilder<> Builder(InsertPt);
      Value *GetPointerCall = Builder.CreateIntrinsic(
          NewGV->getType(), Intrinsic::spv_pushconstant_getpointer, {NewGV});
      GR->buildAssignPtr(Builder, GV.getValueType(), GetPointerCall);
      U.set(GetPointerCall);
    }

    GV.eraseFromParent();
    Changed = true;
  }

  return Changed;
}

/// New-pass-manager entry point. Fetches the SPIR-V global registry from the
/// target machine's subtarget and rewrites the module's push constant
/// globals. Analyses are invalidated only if something was rewritten.
PreservedAnalyses SPIRVPushConstantAccess::run(Module &M,
                                               ModuleAnalysisManager &AM) {
  const SPIRVSubtarget *Subtarget = TM.getSubtargetImpl();
  SPIRVGlobalRegistry *Registry = Subtarget->getSPIRVGlobalRegistry();

  const bool Modified = replacePushConstantAccesses(M, Registry);
  if (!Modified)
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}

namespace {
class SPIRVPushConstantAccessLegacy : public ModulePass {
  SPIRVTargetMachine *TM = nullptr;

public:
  bool runOnModule(Module &M) override {
    const SPIRVSubtarget *ST = TM->getSubtargetImpl();
    SPIRVGlobalRegistry *GR = ST->getSPIRVGlobalRegistry();
    return replacePushConstantAccesses(M, GR);
  }
  StringRef getPassName() const override {
    return "SPIRV push constant Access";
  }
  SPIRVPushConstantAccessLegacy(SPIRVTargetMachine *TM)
      : ModulePass(ID), TM(TM) {}

  static char ID; // Pass identification.
};
char SPIRVPushConstantAccessLegacy::ID = 0;
} // end anonymous namespace

// Registers the legacy pass under the DEBUG_TYPE command-line name.
INITIALIZE_PASS(SPIRVPushConstantAccessLegacy, DEBUG_TYPE,
                "SPIRV push constant Access",
                /*CFGOnly=*/false,
                /*IsAnalysis=*/false)

/// Returns a newly allocated legacy push-constant access pass bound to \p TM.
ModulePass *
llvm::createSPIRVPushConstantAccessLegacyPass(SPIRVTargetMachine *TM) {
  ModulePass *Pass = new SPIRVPushConstantAccessLegacy(TM);
  return Pass;
}