diff options
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 89 |
1 files changed, 45 insertions, 44 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 08a2e25..051882d 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -39,12 +39,28 @@ using namespace llvm::omp::target; using namespace llvm::omp::target::plugin; using namespace error; +struct ol_platform_impl_t { + ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin, + ol_platform_backend_t BackendType) + : Plugin(std::move(Plugin)), BackendType(BackendType) {} + std::unique_ptr<GenericPluginTy> Plugin; + llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices; + ol_platform_backend_t BackendType; + + /// Complete all pending work for this platform and perform any needed + /// cleanup. + /// + /// After calling this function, no liboffload functions should be called with + /// this platform handle. + llvm::Error destroy(); +}; + // Handle type definitions. Ideally these would be 1:1 with the plugins, but // we add some additional data here for now to avoid churn in the plugin // interface. struct ol_device_impl_t { ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device, - ol_platform_handle_t Platform, InfoTreeNode &&DevInfo) + ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo) : DeviceNum(DeviceNum), Device(Device), Platform(Platform), Info(std::forward<InfoTreeNode>(DevInfo)) {} @@ -55,7 +71,7 @@ struct ol_device_impl_t { int DeviceNum; GenericDeviceTy *Device; - ol_platform_handle_t Platform; + ol_platform_impl_t &Platform; InfoTreeNode Info; llvm::SmallVector<__tgt_async_info *> OutstandingQueues; @@ -102,31 +118,17 @@ struct ol_device_impl_t { } }; -struct ol_platform_impl_t { - ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin, - ol_platform_backend_t BackendType) - : Plugin(std::move(Plugin)), BackendType(BackendType) {} - std::unique_ptr<GenericPluginTy> Plugin; - llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices; - ol_platform_backend_t BackendType; - - /// Complete all pending work for this platform and perform any needed - /// cleanup. - /// - /// After calling this function, no liboffload functions should be called with - /// this platform handle. - llvm::Error destroy() { - llvm::Error Result = Plugin::success(); - for (auto &D : Devices) - if (auto Err = D->destroy()) - Result = llvm::joinErrors(std::move(Result), std::move(Err)); +llvm::Error ol_platform_impl_t::destroy() { + llvm::Error Result = Plugin::success(); + for (auto &D : Devices) + if (auto Err = D->destroy()) + Result = llvm::joinErrors(std::move(Result), std::move(Err)); - if (auto Res = Plugin->deinit()) - Result = llvm::joinErrors(std::move(Result), std::move(Res)); + if (auto Res = Plugin->deinit()) + Result = llvm::joinErrors(std::move(Result), std::move(Res)); - return Result; - } -}; + return Result; +} struct ol_queue_impl_t { ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) @@ -206,12 +208,12 @@ struct OffloadContext { // Partitioned list of memory base addresses. Each element in this list is a // key in AllocInfoMap llvm::SmallVector<void *> AllocBases{}; - SmallVector<ol_platform_impl_t, 4> Platforms{}; + SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{}; size_t RefCount; ol_device_handle_t HostDevice() { // The host platform is always inserted last - return Platforms.back().Devices[0].get(); + return Platforms.back()->Devices[0].get(); } static OffloadContext &get() { @@ -251,35 +253,34 @@ Error initPlugins(OffloadContext &Context) { #define PLUGIN_TARGET(Name) \ do { \ if (StringRef(#Name) != "host") \ - Context.Platforms.emplace_back(ol_platform_impl_t{ \ + Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>( \ std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \ - pluginNameToBackend(#Name)}); \ + pluginNameToBackend(#Name))); \ } while (false); #include "Shared/Targets.def" // Preemptively initialize all devices in the plugin for (auto &Platform : Context.Platforms) { - auto Err = Platform.Plugin->init(); + auto Err = Platform->Plugin->init(); [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); - for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices(); + for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices(); DevNum++) { - if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { - auto Device = &Platform.Plugin->getDevice(DevNum); + if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { + auto Device = &Platform->Plugin->getDevice(DevNum); auto Info = Device->obtainInfoImpl(); if (auto Err = Info.takeError()) return Err; - Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>( - DevNum, Device, &Platform, std::move(*Info))); + Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>( + DevNum, Device, *Platform, std::move(*Info))); } } } // Add the special host device auto &HostPlatform = Context.Platforms.emplace_back( - ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST}); - HostPlatform.Devices.emplace_back( - std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{})); - Context.HostDevice()->Platform = &HostPlatform; + std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST)); + HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>( + -1, nullptr, *HostPlatform, InfoTreeNode{})); Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); @@ -316,10 +317,10 @@ Error olShutDown_impl() { for (auto &P : OldContext->Platforms) { // Host plugin is nullptr and has no deinit - if (!P.Plugin || !P.Plugin->is_initialized()) + if (!P->Plugin || !P->Plugin->is_initialized()) continue; - if (auto Res = P.destroy()) + if (auto Res = P->destroy()) Result = llvm::joinErrors(std::move(Result), std::move(Res)); } @@ -384,7 +385,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, // These are not implemented by the plugin interface switch (PropName) { case OL_DEVICE_INFO_PLATFORM: - return Info.write<void *>(Device->Platform); + return Info.write<void *>(&Device->Platform); case OL_DEVICE_INFO_TYPE: return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU); @@ -517,7 +518,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, switch (PropName) { case OL_DEVICE_INFO_PLATFORM: - return Info.write<void *>(Device->Platform); + return Info.write<void *>(&Device->Platform); case OL_DEVICE_INFO_TYPE: return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST); case OL_DEVICE_INFO_NAME: @@ -595,7 +596,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { for (auto &Platform : OffloadContext::get().Platforms) { - for (auto &Device : Platform.Devices) { + for (auto &Device : Platform->Devices) { if (!Callback(Device.get(), UserData)) { break; } |