aboutsummaryrefslogtreecommitdiff
path: root/offload/liboffload/src/OffloadImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r--offload/liboffload/src/OffloadImpl.cpp89
1 files changed, 45 insertions, 44 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 08a2e25..051882d 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -39,12 +39,28 @@ using namespace llvm::omp::target;
using namespace llvm::omp::target::plugin;
using namespace error;
+struct ol_platform_impl_t {
+ ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
+ ol_platform_backend_t BackendType)
+ : Plugin(std::move(Plugin)), BackendType(BackendType) {}
+ std::unique_ptr<GenericPluginTy> Plugin;
+ llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
+ ol_platform_backend_t BackendType;
+
+ /// Complete all pending work for this platform and perform any needed
+ /// cleanup.
+ ///
+ /// After calling this function, no liboffload functions should be called with
+ /// this platform handle.
+ llvm::Error destroy();
+};
+
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
// we add some additional data here for now to avoid churn in the plugin
// interface.
struct ol_device_impl_t {
ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
- ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
+ ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
: DeviceNum(DeviceNum), Device(Device), Platform(Platform),
Info(std::forward<InfoTreeNode>(DevInfo)) {}
@@ -55,7 +71,7 @@ struct ol_device_impl_t {
int DeviceNum;
GenericDeviceTy *Device;
- ol_platform_handle_t Platform;
+ ol_platform_impl_t &Platform;
InfoTreeNode Info;
llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
@@ -102,31 +118,17 @@ struct ol_device_impl_t {
}
};
-struct ol_platform_impl_t {
- ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
- ol_platform_backend_t BackendType)
- : Plugin(std::move(Plugin)), BackendType(BackendType) {}
- std::unique_ptr<GenericPluginTy> Plugin;
- llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
- ol_platform_backend_t BackendType;
-
- /// Complete all pending work for this platform and perform any needed
- /// cleanup.
- ///
- /// After calling this function, no liboffload functions should be called with
- /// this platform handle.
- llvm::Error destroy() {
- llvm::Error Result = Plugin::success();
- for (auto &D : Devices)
- if (auto Err = D->destroy())
- Result = llvm::joinErrors(std::move(Result), std::move(Err));
+llvm::Error ol_platform_impl_t::destroy() {
+ llvm::Error Result = Plugin::success();
+ for (auto &D : Devices)
+ if (auto Err = D->destroy())
+ Result = llvm::joinErrors(std::move(Result), std::move(Err));
- if (auto Res = Plugin->deinit())
- Result = llvm::joinErrors(std::move(Result), std::move(Res));
+ if (auto Res = Plugin->deinit())
+ Result = llvm::joinErrors(std::move(Result), std::move(Res));
- return Result;
- }
-};
+ return Result;
+}
struct ol_queue_impl_t {
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
@@ -206,12 +208,12 @@ struct OffloadContext {
// Partitioned list of memory base addresses. Each element in this list is a
// key in AllocInfoMap
llvm::SmallVector<void *> AllocBases{};
- SmallVector<ol_platform_impl_t, 4> Platforms{};
+ SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
size_t RefCount;
ol_device_handle_t HostDevice() {
// The host platform is always inserted last
- return Platforms.back().Devices[0].get();
+ return Platforms.back()->Devices[0].get();
}
static OffloadContext &get() {
@@ -251,35 +253,34 @@ Error initPlugins(OffloadContext &Context) {
#define PLUGIN_TARGET(Name) \
do { \
if (StringRef(#Name) != "host") \
- Context.Platforms.emplace_back(ol_platform_impl_t{ \
+ Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>( \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
- pluginNameToBackend(#Name)}); \
+ pluginNameToBackend(#Name))); \
} while (false);
#include "Shared/Targets.def"
// Preemptively initialize all devices in the plugin
for (auto &Platform : Context.Platforms) {
- auto Err = Platform.Plugin->init();
+ auto Err = Platform->Plugin->init();
[[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
- for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
+ for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
DevNum++) {
- if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
- auto Device = &Platform.Plugin->getDevice(DevNum);
+ if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
+ auto Device = &Platform->Plugin->getDevice(DevNum);
auto Info = Device->obtainInfoImpl();
if (auto Err = Info.takeError())
return Err;
- Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>(
- DevNum, Device, &Platform, std::move(*Info)));
+ Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
+ DevNum, Device, *Platform, std::move(*Info)));
}
}
}
// Add the special host device
auto &HostPlatform = Context.Platforms.emplace_back(
- ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
- HostPlatform.Devices.emplace_back(
- std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{}));
- Context.HostDevice()->Platform = &HostPlatform;
+ std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
+ HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
+ -1, nullptr, *HostPlatform, InfoTreeNode{}));
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
@@ -316,10 +317,10 @@ Error olShutDown_impl() {
for (auto &P : OldContext->Platforms) {
// Host plugin is nullptr and has no deinit
- if (!P.Plugin || !P.Plugin->is_initialized())
+ if (!P->Plugin || !P->Plugin->is_initialized())
continue;
- if (auto Res = P.destroy())
+ if (auto Res = P->destroy())
Result = llvm::joinErrors(std::move(Result), std::move(Res));
}
@@ -384,7 +385,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
// These are not implemented by the plugin interface
switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
- return Info.write<void *>(Device->Platform);
+ return Info.write<void *>(&Device->Platform);
case OL_DEVICE_INFO_TYPE:
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
@@ -517,7 +518,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
switch (PropName) {
case OL_DEVICE_INFO_PLATFORM:
- return Info.write<void *>(Device->Platform);
+ return Info.write<void *>(&Device->Platform);
case OL_DEVICE_INFO_TYPE:
return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
case OL_DEVICE_INFO_NAME:
@@ -595,7 +596,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : OffloadContext::get().Platforms) {
- for (auto &Device : Platform.Devices) {
+ for (auto &Device : Platform->Devices) {
if (!Callback(Device.get(), UserData)) {
break;
}