aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/ExecutionEngine/VulkanRuntimeWrappers.cpp
blob: ddea230addd40bf66171ef7f54c64f615a4e67c7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
//===- VulkanRuntimeWrappers.cpp - MLIR Vulkan runner wrapper library -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C runtime wrappers around the VulkanRuntime.
//
//===----------------------------------------------------------------------===//

#include <iostream>
#include <mutex>
#include <numeric>
#include <string>
#include <vector>

#include "VulkanRuntime.h"

// Explicitly export entry points to the vulkan-runtime-wrapper.

#ifdef _WIN32
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // _WIN32

namespace {

class VulkanModule;

/// Opaque kernel handle handed out by `mgpuModuleGetFunction`. It pairs the
/// module that created it with a private copy of the entry-point name.
struct VulkanFunction {
  /// Non-owning back-pointer to the module that created (and owns) this handle.
  VulkanModule *module;
  /// Entry-point name, copied from the caller's C string.
  std::string name;

  VulkanFunction(VulkanModule *owner, const char *entryPoint)
      : module(owner), name(entryPoint) {}
};

/// Owns a private copy of the SPIR-V passed to `mgpuModuleLoad`, together with
/// every VulkanFunction handle returned by `getFunction`. Those handles remain
/// valid until this module is destroyed.
class VulkanModule {
public:
  VulkanModule(const uint8_t *bytes, size_t sizeInBytes)
      : binary(bytes, bytes + sizeInBytes) {}
  ~VulkanModule() = default;

  /// Creates and returns a new handle for entry point `name`. A fresh handle
  /// is allocated on every call, even when `name` was requested before.
  VulkanFunction *getFunction(const char *name) {
    functionHandles.push_back(std::make_unique<VulkanFunction>(this, name));
    return functionHandles.back().get();
  }

  /// Pointer to the first byte of the stored binary.
  uint8_t *blobData() { return binary.data(); }
  /// Length of the stored binary, in bytes.
  size_t blobSizeInBytes() const { return binary.size(); }

private:
  /// Private copy of the module binary.
  std::vector<uint8_t> binary;
  /// Owning storage for handles from `getFunction`. Holding them through
  /// unique_ptr keeps each handle's address stable as the vector grows.
  std::vector<std::unique_ptr<VulkanFunction>> functionHandles;
};

/// Per-stream state handed out by `mgpuStreamCreate`: a single VulkanRuntime
/// plus a mutex guarding it. Each method holds the lock only for its own
/// duration. A multi-call sequence (set resources, shader, entry point, then
/// run) is therefore not atomic as a whole.
class VulkanRuntimeManager {
public:
  VulkanRuntimeManager() = default;
  VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
  // Canonical copy-assignment signature returns a reference, not a copy.
  VulkanRuntimeManager &operator=(const VulkanRuntimeManager &) = delete;
  ~VulkanRuntimeManager() = default;

  /// Forwards the host buffer for binding `bindIndex` of descriptor set
  /// `setIndex` to the runtime.
  void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
                       const VulkanHostMemoryBuffer &memBuffer) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
  }

  /// Forwards the name of the shader entry point to execute.
  void setEntryPoint(const char *entryPoint) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setEntryPoint(entryPoint);
  }

  /// Forwards the workgroup counts for the dispatch.
  void setNumWorkGroups(NumWorkGroups numWorkGroups) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setNumWorkGroups(numWorkGroups);
  }

  /// Forwards the shader binary: `size` bytes starting at `shader`.
  void setShaderModule(uint8_t *shader, uint32_t size) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setShaderModule(shader, size);
  }

  /// Calls initRuntime, run, updateHostMemoryBuffers and destroy in that
  /// order. Stops at the first step that fails and reports it on stderr.
  void runOnVulkan() {
    std::lock_guard<std::mutex> lock(mutex);
    // NOTE(review): a failure before destroy() short-circuits it, so runtime
    // resources may leak on error. Whether destroy() is safe after a partial
    // initRuntime() is not visible from this file; confirm before changing.
    if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
        failed(vulkanRuntime.updateHostMemoryBuffers()) ||
        failed(vulkanRuntime.destroy())) {
      std::cerr << "runOnVulkan failed\n";
    }
  }

private:
  VulkanRuntime vulkanRuntime;
  std::mutex mutex;
};

} // namespace

/// C++ view of the rank-N memref descriptor that is passed by pointer to the
/// `_mlir_ciface_fillResource*` helpers in this file. The field order and
/// field types form the binary interface with the caller. Do not reorder,
/// add or remove fields.
/// NOTE(review): presumably mirrors MLIR's strided memref descriptor
/// (`StridedMemRefType<T, N>` in CRunnerUtils.h); confirm before changing.
template <typename T, int N>
struct MemRefDescriptor {
  // Pointer to the underlying allocation (per MLIR convention, the one that
  // gets freed); may differ from `aligned`.
  T *allocated;
  // Aligned base pointer; per MLIR convention, element data is addressed
  // relative to this pointer plus `offset`.
  T *aligned;
  // Offset, in elements, from `aligned` to the first element.
  int64_t offset;
  // Extent of each of the N dimensions.
  int64_t sizes[N];
  // Stride of each of the N dimensions, in elements.
  int64_t strides[N];
};

extern "C" {

//===----------------------------------------------------------------------===//
//
// Wrappers intended for mlir-runner. Uses of GPU dialect operations get
// lowered to calls to these functions by GPUToLLVMConversionPass.
//
//===----------------------------------------------------------------------===//

/// Creates a stream: a heap-allocated VulkanRuntimeManager returned as an
/// opaque handle. Release it with `mgpuStreamDestroy`.
VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuStreamCreate() {
  // Keep the parentheses: this is value-initialization, not
  // default-initialization.
  auto *manager = new VulkanRuntimeManager();
  return manager;
}

/// Destroys a stream handle previously returned by `mgpuStreamCreate`.
VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamDestroy(void *vkRuntimeManager) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  delete manager;
}

/// Stream synchronization. Intentionally empty: every other entry point in
/// this library is synchronous, so there is never pending work to wait for.
VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuStreamSynchronize(void * /*stream*/) {}

/// Loads a module by copying `gpuBlobSize` bytes (the SPIR-V binary) from
/// `data` into a new VulkanModule. Returns an opaque handle to be released
/// with `mgpuModuleUnload`.
VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleLoad(const void *data,
                                                  size_t gpuBlobSize) {
  const auto *bytes = static_cast<const uint8_t *>(data);
  auto *module = new VulkanModule(bytes, gpuBlobSize);
  return module;
}

/// Destroys a module handle from `mgpuModuleLoad`. This also frees every
/// function handle the module handed out via `mgpuModuleGetFunction`.
VULKAN_WRAPPER_SYMBOL_EXPORT void mgpuModuleUnload(void *vkModule) {
  auto *module = static_cast<VulkanModule *>(vkModule);
  delete module;
}

/// Returns an opaque handle for entry point `name` in module `vkModule`. The
/// handle is owned by the module. Aborts if `vkModule` is null.
VULKAN_WRAPPER_SYMBOL_EXPORT void *mgpuModuleGetFunction(void *vkModule,
                                                         const char *name) {
  auto *module = static_cast<VulkanModule *>(vkModule);
  if (module == nullptr)
    abort();
  return module->getFunction(name);
}

/// Launches `vkKernel` (a handle from `mgpuModuleGetFunction`) on the stream
/// `vkRuntimeManager` and runs it to completion.
///
/// - The grid dimensions are passed to the runtime as the workgroup counts.
///   Block dimensions, `smem` and `extra` are ignored.
/// - Each memref argument is bound in descriptor set 0, using consecutive
///   binding indices in argument order.
///
/// Aborts on null handles and on a malformed `params` array. Also aborts on
/// any size or count that does not fit the 32-bit fields the runtime takes,
/// rather than silently truncating it.
VULKAN_WRAPPER_SYMBOL_EXPORT void
mgpuLaunchKernel(void *vkKernel, size_t gridX, size_t gridY, size_t gridZ,
                 size_t /*blockX*/, size_t /*blockY*/, size_t /*blockZ*/,
                 size_t /*smem*/, void *vkRuntimeManager, void **params,
                 void ** /*extra*/, size_t paramsCount) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  auto *function = static_cast<VulkanFunction *>(vkKernel);
  // Same policy as mgpuModuleGetFunction: a null handle is a caller bug.
  if (!manager || !function)
    abort();

  // A truncated buffer size would bind a too-small buffer, and a truncated
  // grid would dispatch the wrong number of workgroups. Fail loudly instead.
  auto toUint32 = [](size_t value) -> uint32_t {
    if (static_cast<uint32_t>(value) != value)
      abort();
    return static_cast<uint32_t>(value);
  };

  // GpuToLLVMConversionPass with the kernelBarePtrCallConv and
  // kernelIntersperseSizeCallConv options will set up the params array like:
  // { &memref_ptr0, &memref_size0, &memref_ptr1, &memref_size1, ... }
  const size_t paramsPerMemRef = 2;
  if (paramsCount % paramsPerMemRef != 0) {
    abort(); // This would indicate a serious calling convention mismatch.
  }
  const DescriptorSetIndex setIndex = 0;
  BindingIndex bindIndex = 0;
  for (size_t i = 0; i < paramsCount; i += paramsPerMemRef) {
    void *memrefBufferBasePtr = *static_cast<void **>(params[i + 0]);
    size_t memrefBufferSize = *static_cast<size_t *>(params[i + 1]);
    VulkanHostMemoryBuffer memBuffer{memrefBufferBasePtr,
                                     toUint32(memrefBufferSize)};
    manager->setResourceData(setIndex, bindIndex, memBuffer);
    ++bindIndex;
  }

  manager->setNumWorkGroups(
      NumWorkGroups{toUint32(gridX), toUint32(gridY), toUint32(gridZ)});

  // Expected size should be in bytes.
  manager->setShaderModule(function->module->blobData(),
                           toUint32(function->module->blobSizeInBytes()));
  manager->setEntryPoint(function->name.c_str());

  manager->runOnVulkan();
}

//===----------------------------------------------------------------------===//
//
// Miscellaneous utility functions that can be directly used by tests.
//
//===----------------------------------------------------------------------===//

/// Fills the given 1D float memref with the given float value.
///
/// Elements are addressed as `aligned[offset + i * strides[0]]` (MLIR's
/// strided memref layout) rather than from `allocated`. Views with a non-zero
/// offset, a non-unit stride or an over-aligned base are therefore filled
/// correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
                                 float value) {
  float *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    base[i * ptr->strides[0]] = value;
}

/// Fills the given 2D float memref with the given float value.
///
/// Elements are addressed as
/// `aligned[offset + i * strides[0] + j * strides[1]]` (MLIR's strided memref
/// layout) rather than from `allocated`. Views with a non-zero offset,
/// non-contiguous strides or an over-aligned base are therefore filled
/// correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
                                 float value) {
  float *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    for (int64_t j = 0; j < ptr->sizes[1]; ++j)
      base[i * ptr->strides[0] + j * ptr->strides[1]] = value;
}

/// Fills the given 3D float memref with the given float value.
///
/// Elements are addressed as
/// `aligned[offset + i * strides[0] + j * strides[1] + k * strides[2]]`
/// (MLIR's strided memref layout) rather than from `allocated`. Views with a
/// non-zero offset, non-contiguous strides or an over-aligned base are
/// therefore filled correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
                                 float value) {
  float *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    for (int64_t j = 0; j < ptr->sizes[1]; ++j)
      for (int64_t k = 0; k < ptr->sizes[2]; ++k)
        base[i * ptr->strides[0] + j * ptr->strides[1] +
             k * ptr->strides[2]] = value;
}

/// Fills the given 1D int memref with the given int value.
///
/// Elements are addressed as `aligned[offset + i * strides[0]]` (MLIR's
/// strided memref layout) rather than from `allocated`. Views with a non-zero
/// offset, a non-unit stride or an over-aligned base are therefore filled
/// correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
                               int32_t value) {
  int32_t *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    base[i * ptr->strides[0]] = value;
}

/// Fills the given 2D int memref with the given int value.
///
/// Elements are addressed as
/// `aligned[offset + i * strides[0] + j * strides[1]]` (MLIR's strided memref
/// layout) rather than from `allocated`. Views with a non-zero offset,
/// non-contiguous strides or an over-aligned base are therefore filled
/// correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
                               int32_t value) {
  int32_t *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    for (int64_t j = 0; j < ptr->sizes[1]; ++j)
      base[i * ptr->strides[0] + j * ptr->strides[1]] = value;
}

/// Fills the given 3D int memref with the given int value.
///
/// Elements are addressed as
/// `aligned[offset + i * strides[0] + j * strides[1] + k * strides[2]]`
/// (MLIR's strided memref layout) rather than from `allocated`. Views with a
/// non-zero offset, non-contiguous strides or an over-aligned base are
/// therefore filled correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
                               int32_t value) {
  int32_t *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    for (int64_t j = 0; j < ptr->sizes[1]; ++j)
      for (int64_t k = 0; k < ptr->sizes[2]; ++k)
        base[i * ptr->strides[0] + j * ptr->strides[1] +
             k * ptr->strides[2]] = value;
}

/// Fills the given 1D int8 memref with the given int8 value.
///
/// Elements are addressed as `aligned[offset + i * strides[0]]` (MLIR's
/// strided memref layout) rather than from `allocated`. Views with a non-zero
/// offset, a non-unit stride or an over-aligned base are therefore filled
/// correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
                                int8_t value) {
  int8_t *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    base[i * ptr->strides[0]] = value;
}

/// Fills the given 2D int8 memref with the given int8 value.
///
/// Elements are addressed as
/// `aligned[offset + i * strides[0] + j * strides[1]]` (MLIR's strided memref
/// layout) rather than from `allocated`. Views with a non-zero offset,
/// non-contiguous strides or an over-aligned base are therefore filled
/// correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
                                int8_t value) {
  int8_t *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    for (int64_t j = 0; j < ptr->sizes[1]; ++j)
      base[i * ptr->strides[0] + j * ptr->strides[1]] = value;
}

/// Fills the given 3D int8 memref with the given int8 value.
///
/// Elements are addressed as
/// `aligned[offset + i * strides[0] + j * strides[1] + k * strides[2]]`
/// (MLIR's strided memref layout) rather than from `allocated`. Views with a
/// non-zero offset, non-contiguous strides or an over-aligned base are
/// therefore filled correctly.
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
                                int8_t value) {
  int8_t *base = ptr->aligned + ptr->offset;
  for (int64_t i = 0; i < ptr->sizes[0]; ++i)
    for (int64_t j = 0; j < ptr->sizes[1]; ++j)
      for (int64_t k = 0; k < ptr->sizes[2]; ++k)
        base[i * ptr->strides[0] + j * ptr->strides[1] +
             k * ptr->strides[2]] = value;
}
}