aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp')
-rw-r--r--llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp159
1 files changed, 159 insertions, 0 deletions
diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
new file mode 100644
index 0000000..0398e52
--- /dev/null
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizeImplicitBinding.cpp
@@ -0,0 +1,159 @@
+//===- SPIRVLegalizeImplicitBinding.cpp - Legalize implicit bindings ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass legalizes the @llvm.spv.resource.handlefromimplicitbinding
+// intrinsic by replacing it with a call to
+// @llvm.spv.resource.handlefrombinding.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRV.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include <algorithm>
+#include <vector>
+
+using namespace llvm;
+
+namespace {
+class SPIRVLegalizeImplicitBinding : public ModulePass {
+public:
+ static char ID;
+ SPIRVLegalizeImplicitBinding() : ModulePass(ID) {}
+
+ bool runOnModule(Module &M) override;
+
+private:
+ void collectBindingInfo(Module &M);
+ uint32_t getAndReserveFirstUnusedBinding(uint32_t DescSet);
+ void replaceImplicitBindingCalls(Module &M);
+
+ // A map from descriptor set to a bit vector of used binding numbers.
+ std::vector<BitVector> UsedBindings;
+ // A list of all implicit binding calls, to be sorted by order ID.
+ SmallVector<CallInst *, 16> ImplicitBindingCalls;
+};
+
+struct BindingInfoCollector : public InstVisitor<BindingInfoCollector> {
+ std::vector<BitVector> &UsedBindings;
+ SmallVector<CallInst *, 16> &ImplicitBindingCalls;
+
+ BindingInfoCollector(std::vector<BitVector> &UsedBindings,
+ SmallVector<CallInst *, 16> &ImplicitBindingCalls)
+ : UsedBindings(UsedBindings), ImplicitBindingCalls(ImplicitBindingCalls) {
+ }
+
+ void visitCallInst(CallInst &CI) {
+ if (CI.getIntrinsicID() == Intrinsic::spv_resource_handlefrombinding) {
+ const uint32_t DescSet =
+ cast<ConstantInt>(CI.getArgOperand(0))->getZExtValue();
+ const uint32_t Binding =
+ cast<ConstantInt>(CI.getArgOperand(1))->getZExtValue();
+
+ if (UsedBindings.size() <= DescSet) {
+ UsedBindings.resize(DescSet + 1);
+ UsedBindings[DescSet].resize(64);
+ }
+ if (UsedBindings[DescSet].size() <= Binding) {
+ UsedBindings[DescSet].resize(2 * Binding + 1);
+ }
+ UsedBindings[DescSet].set(Binding);
+ } else if (CI.getIntrinsicID() ==
+ Intrinsic::spv_resource_handlefromimplicitbinding) {
+ ImplicitBindingCalls.push_back(&CI);
+ }
+ }
+};
+
+void SPIRVLegalizeImplicitBinding::collectBindingInfo(Module &M) {
+ BindingInfoCollector InfoCollector(UsedBindings, ImplicitBindingCalls);
+ InfoCollector.visit(M);
+
+ // Sort the collected calls by their order ID.
+ std::sort(
+ ImplicitBindingCalls.begin(), ImplicitBindingCalls.end(),
+ [](const CallInst *A, const CallInst *B) {
+ const uint32_t OrderIdArgIdx = 0;
+ const uint32_t OrderA =
+ cast<ConstantInt>(A->getArgOperand(OrderIdArgIdx))->getZExtValue();
+ const uint32_t OrderB =
+ cast<ConstantInt>(B->getArgOperand(OrderIdArgIdx))->getZExtValue();
+ return OrderA < OrderB;
+ });
+}
+
+uint32_t SPIRVLegalizeImplicitBinding::getAndReserveFirstUnusedBinding(
+ uint32_t DescSet) {
+ if (UsedBindings.size() <= DescSet) {
+ UsedBindings.resize(DescSet + 1);
+ UsedBindings[DescSet].resize(64);
+ }
+
+ int NewBinding = UsedBindings[DescSet].find_first_unset();
+ if (NewBinding == -1) {
+ NewBinding = UsedBindings[DescSet].size();
+ UsedBindings[DescSet].resize(2 * NewBinding + 1);
+ }
+
+ UsedBindings[DescSet].set(NewBinding);
+ return NewBinding;
+}
+
+void SPIRVLegalizeImplicitBinding::replaceImplicitBindingCalls(Module &M) {
+ for (CallInst *OldCI : ImplicitBindingCalls) {
+ IRBuilder<> Builder(OldCI);
+ const uint32_t DescSet =
+ cast<ConstantInt>(OldCI->getArgOperand(1))->getZExtValue();
+ const uint32_t NewBinding = getAndReserveFirstUnusedBinding(DescSet);
+
+ SmallVector<Value *, 8> Args;
+ Args.push_back(Builder.getInt32(DescSet));
+ Args.push_back(Builder.getInt32(NewBinding));
+
+ // Copy the remaining arguments from the old call.
+ for (uint32_t i = 2; i < OldCI->arg_size(); ++i) {
+ Args.push_back(OldCI->getArgOperand(i));
+ }
+
+ Function *NewFunc = Intrinsic::getOrInsertDeclaration(
+ &M, Intrinsic::spv_resource_handlefrombinding, OldCI->getType());
+ CallInst *NewCI = Builder.CreateCall(NewFunc, Args);
+ NewCI->setCallingConv(OldCI->getCallingConv());
+
+ OldCI->replaceAllUsesWith(NewCI);
+ OldCI->eraseFromParent();
+ }
+}
+
+bool SPIRVLegalizeImplicitBinding::runOnModule(Module &M) {
+ collectBindingInfo(M);
+ if (ImplicitBindingCalls.empty()) {
+ return false;
+ }
+
+ replaceImplicitBindingCalls(M);
+ return true;
+}
+} // namespace
+
+char SPIRVLegalizeImplicitBinding::ID = 0;
+
+INITIALIZE_PASS(SPIRVLegalizeImplicitBinding, "legalize-spirv-implicit-binding",
+ "Legalize SPIR-V implicit bindings", false, false)
+
+ModulePass *llvm::createSPIRVLegalizeImplicitBindingPass() {
+ return new SPIRVLegalizeImplicitBinding();
+} \ No newline at end of file