diff options
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 97 |
1 files changed, 58 insertions, 39 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index c549ae0..6d22fae 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -42,9 +42,7 @@ using namespace error; struct ol_platform_impl_t { ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin, ol_platform_backend_t BackendType) - : Plugin(std::move(Plugin)), BackendType(BackendType) {} - std::unique_ptr<GenericPluginTy> Plugin; - llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices; + : BackendType(BackendType), Plugin(std::move(Plugin)) {} ol_platform_backend_t BackendType; /// Complete all pending work for this platform and perform any needed @@ -53,6 +51,14 @@ struct ol_platform_impl_t { /// After calling this function, no liboffload functions should be called with /// this platform handle. llvm::Error destroy(); + + /// Initialize the associated plugin and devices. + llvm::Error init(); + + /// Direct access to the plugin, may be uninitialized if accessed here. + std::unique_ptr<GenericPluginTy> Plugin; + + llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices; }; // Handle type definitions. Ideally these would be 1:1 with the plugins, but @@ -130,6 +136,28 @@ llvm::Error ol_platform_impl_t::destroy() { return Result; } +llvm::Error ol_platform_impl_t::init() { + if (!Plugin) + return llvm::Error::success(); + + if (llvm::Error Err = Plugin->init()) + return Err; + + for (auto Id = 0, End = Plugin->getNumDevices(); Id != End; Id++) { + if (llvm::Error Err = Plugin->initDevice(Id)) + return Err; + + auto Device = &Plugin->getDevice(Id); + auto Info = Device->obtainInfoImpl(); + if (llvm::Error Err = Info.takeError()) + return Err; + Devices.emplace_back(std::make_unique<ol_device_impl_t>(Id, Device, *this, + std::move(*Info))); + } + + return llvm::Error::success(); +} + struct ol_queue_impl_t { ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device) : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {} @@ -207,15 +235,11 @@ struct OffloadContext { std::mutex AllocInfoMapMutex{}; // Partitioned list of memory base addresses. Each element in this list is a // key in AllocInfoMap - llvm::SmallVector<void *> AllocBases{}; + SmallVector<void *> AllocBases{}; SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{}; + ol_device_handle_t HostDevice; size_t RefCount; - ol_device_handle_t HostDevice() { - // The host platform is always inserted last - return Platforms.back()->Devices[0].get(); - } - static OffloadContext &get() { assert(OffloadContextVal); return *OffloadContextVal; @@ -259,28 +283,21 @@ Error initPlugins(OffloadContext &Context) { } while (false); #include "Shared/Targets.def" - // Preemptively initialize all devices in the plugin + // Eagerly initialize all of the plugins and devices. We need to make sure + // that the platform is initialized at a consistent point to maintain the + // expected teardown order in the vendor libraries. for (auto &Platform : Context.Platforms) { - auto Err = Platform->Plugin->init(); - [[maybe_unused]] std::string InfoMsg = toString(std::move(Err)); - for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices(); - DevNum++) { - if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) { - auto Device = &Platform->Plugin->getDevice(DevNum); - auto Info = Device->obtainInfoImpl(); - if (auto Err = Info.takeError()) - return Err; - Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>( - DevNum, Device, *Platform, std::move(*Info))); - } - } + if (Error Err = Platform->init()) + return Err; } - // Add the special host device + // Add the special host device. auto &HostPlatform = Context.Platforms.emplace_back( std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST)); - HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>( - -1, nullptr, *HostPlatform, InfoTreeNode{})); + Context.HostDevice = HostPlatform->Devices + .emplace_back(std::make_unique<ol_device_impl_t>( + -1, nullptr, *HostPlatform, InfoTreeNode{})) + .get(); Context.TracingEnabled = std::getenv("OFFLOAD_TRACE"); Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); @@ -312,16 +329,16 @@ Error olShutDown_impl() { if (--OffloadContext::get().RefCount != 0) return Error::success(); - llvm::Error Result = Error::success(); + Error Result = Error::success(); auto *OldContext = OffloadContextVal.exchange(nullptr); - for (auto &P : OldContext->Platforms) { + for (auto &Platform : OldContext->Platforms) { // Host plugin is nullptr and has no deinit - if (!P->Plugin || !P->Plugin->is_initialized()) + if (!Platform->Plugin || !Platform->Plugin->is_initialized()) continue; - if (auto Res = P->destroy()) - Result = llvm::joinErrors(std::move(Result), std::move(Res)); + if (auto Res = Platform->destroy()) + Result = joinErrors(std::move(Result), std::move(Res)); } delete OldContext; @@ -334,6 +351,8 @@ Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform, InfoWriter Info(PropSize, PropValue, PropSizeRet); bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST; + // Note that the plugin is potentially uninitialized here. It will need to be + // initialized once info is added that requires it to be initialized. switch (PropName) { case OL_PLATFORM_INFO_NAME: return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName()); @@ -373,12 +392,12 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform, Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { - assert(Device != OffloadContext::get().HostDevice()); + assert(Device != OffloadContext::get().HostDevice); InfoWriter Info(PropSize, PropValue, PropSizeRet); auto makeError = [&](ErrorCode Code, StringRef Err) { std::string ErrBuffer; - llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err; + raw_string_ostream(ErrBuffer) << PropName << ": " << Err; return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str()); }; @@ -511,7 +530,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { - assert(Device == OffloadContext::get().HostDevice()); + assert(Device == OffloadContext::get().HostDevice); InfoWriter Info(PropSize, PropValue, PropSizeRet); constexpr auto uint32_max = std::numeric_limits<uint32_t>::max(); @@ -579,7 +598,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t PropSize, void *PropValue) { - if (Device == OffloadContext::get().HostDevice()) + if (Device == OffloadContext::get().HostDevice) return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue, nullptr); return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue, @@ -588,7 +607,7 @@ Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName, Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, ol_device_info_t PropName, size_t *PropSizeRet) { - if (Device == OffloadContext::get().HostDevice()) + if (Device == OffloadContext::get().HostDevice) return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr, PropSizeRet); return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet); @@ -598,7 +617,7 @@ Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { for (auto &Platform : OffloadContext::get().Platforms) { for (auto &Device : Platform->Devices) { if (!Callback(Device.get(), UserData)) { - break; + return Error::success(); } } } @@ -949,7 +968,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) { Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice, const void *SrcPtr, ol_device_handle_t SrcDevice, size_t Size) { - auto Host = OffloadContext::get().HostDevice(); + auto Host = OffloadContext::get().HostDevice; if (DstDevice == Host && SrcDevice == Host) { if (!Queue) { std::memcpy(DstPtr, SrcPtr, Size); @@ -1138,7 +1157,7 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol, auto CheckKind = [&](ol_symbol_kind_t Required) { if (Symbol->Kind != Required) { std::string ErrBuffer; - llvm::raw_string_ostream(ErrBuffer) + raw_string_ostream(ErrBuffer) << PropName << ": Expected a symbol of Kind " << Required << " but given a symbol of Kind " << Symbol->Kind; return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str()); |