aboutsummaryrefslogtreecommitdiff
path: root/offload/liboffload/src/OffloadImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r--offload/liboffload/src/OffloadImpl.cpp97
1 files changed, 58 insertions, 39 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index c549ae0..6d22fae 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -42,9 +42,7 @@ using namespace error;
struct ol_platform_impl_t {
ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
ol_platform_backend_t BackendType)
- : Plugin(std::move(Plugin)), BackendType(BackendType) {}
- std::unique_ptr<GenericPluginTy> Plugin;
- llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
+ : BackendType(BackendType), Plugin(std::move(Plugin)) {}
ol_platform_backend_t BackendType;
/// Complete all pending work for this platform and perform any needed
@@ -53,6 +51,14 @@ struct ol_platform_impl_t {
/// After calling this function, no liboffload functions should be called with
/// this platform handle.
llvm::Error destroy();
+
+ /// Initialize the associated plugin and devices.
+ llvm::Error init();
+
+ /// Direct access to the plugin, may be uninitialized if accessed here.
+ std::unique_ptr<GenericPluginTy> Plugin;
+
+ llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
};
// Handle type definitions. Ideally these would be 1:1 with the plugins, but
@@ -130,6 +136,28 @@ llvm::Error ol_platform_impl_t::destroy() {
return Result;
}
+llvm::Error ol_platform_impl_t::init() {
+ if (!Plugin)
+ return llvm::Error::success();
+
+ if (llvm::Error Err = Plugin->init())
+ return Err;
+
+ for (auto Id = 0, End = Plugin->getNumDevices(); Id != End; Id++) {
+ if (llvm::Error Err = Plugin->initDevice(Id))
+ return Err;
+
+ auto Device = &Plugin->getDevice(Id);
+ auto Info = Device->obtainInfoImpl();
+ if (llvm::Error Err = Info.takeError())
+ return Err;
+ Devices.emplace_back(std::make_unique<ol_device_impl_t>(Id, Device, *this,
+ std::move(*Info)));
+ }
+
+ return llvm::Error::success();
+}
+
struct ol_queue_impl_t {
ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
: AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
@@ -207,15 +235,11 @@ struct OffloadContext {
std::mutex AllocInfoMapMutex{};
// Partitioned list of memory base addresses. Each element in this list is a
// key in AllocInfoMap
- llvm::SmallVector<void *> AllocBases{};
+ SmallVector<void *> AllocBases{};
SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
+ ol_device_handle_t HostDevice;
size_t RefCount;
- ol_device_handle_t HostDevice() {
- // The host platform is always inserted last
- return Platforms.back()->Devices[0].get();
- }
-
static OffloadContext &get() {
assert(OffloadContextVal);
return *OffloadContextVal;
@@ -259,28 +283,21 @@ Error initPlugins(OffloadContext &Context) {
} while (false);
#include "Shared/Targets.def"
- // Preemptively initialize all devices in the plugin
+ // Eagerly initialize all of the plugins and devices. We need to make sure
+ // that the platform is initialized at a consistent point to maintain the
+ // expected teardown order in the vendor libraries.
for (auto &Platform : Context.Platforms) {
- auto Err = Platform->Plugin->init();
- [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
- for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
- DevNum++) {
- if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
- auto Device = &Platform->Plugin->getDevice(DevNum);
- auto Info = Device->obtainInfoImpl();
- if (auto Err = Info.takeError())
- return Err;
- Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
- DevNum, Device, *Platform, std::move(*Info)));
- }
- }
+ if (Error Err = Platform->init())
+ return Err;
}
- // Add the special host device
+ // Add the special host device.
auto &HostPlatform = Context.Platforms.emplace_back(
std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
- HostPlatform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
- -1, nullptr, *HostPlatform, InfoTreeNode{}));
+ Context.HostDevice = HostPlatform->Devices
+ .emplace_back(std::make_unique<ol_device_impl_t>(
+ -1, nullptr, *HostPlatform, InfoTreeNode{}))
+ .get();
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
@@ -312,16 +329,16 @@ Error olShutDown_impl() {
if (--OffloadContext::get().RefCount != 0)
return Error::success();
- llvm::Error Result = Error::success();
+ Error Result = Error::success();
auto *OldContext = OffloadContextVal.exchange(nullptr);
- for (auto &P : OldContext->Platforms) {
+ for (auto &Platform : OldContext->Platforms) {
// Host plugin is nullptr and has no deinit
- if (!P->Plugin || !P->Plugin->is_initialized())
+ if (!Platform->Plugin || !Platform->Plugin->is_initialized())
continue;
- if (auto Res = P->destroy())
- Result = llvm::joinErrors(std::move(Result), std::move(Res));
+ if (auto Res = Platform->destroy())
+ Result = joinErrors(std::move(Result), std::move(Res));
}
delete OldContext;
@@ -334,6 +351,8 @@ Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
InfoWriter Info(PropSize, PropValue, PropSizeRet);
bool IsHost = Platform->BackendType == OL_PLATFORM_BACKEND_HOST;
+ // Note that the plugin is potentially uninitialized here. It will need to be
+ // initialized once info is added that requires it to be initialized.
switch (PropName) {
case OL_PLATFORM_INFO_NAME:
return Info.writeString(IsHost ? "Host" : Platform->Plugin->getName());
@@ -373,12 +392,12 @@ Error olGetPlatformInfoSize_impl(ol_platform_handle_t Platform,
Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
- assert(Device != OffloadContext::get().HostDevice());
+ assert(Device != OffloadContext::get().HostDevice);
InfoWriter Info(PropSize, PropValue, PropSizeRet);
auto makeError = [&](ErrorCode Code, StringRef Err) {
std::string ErrBuffer;
- llvm::raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
+ raw_string_ostream(ErrBuffer) << PropName << ": " << Err;
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
};
@@ -511,7 +530,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
ol_device_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
- assert(Device == OffloadContext::get().HostDevice());
+ assert(Device == OffloadContext::get().HostDevice);
InfoWriter Info(PropSize, PropValue, PropSizeRet);
constexpr auto uint32_max = std::numeric_limits<uint32_t>::max();
@@ -579,7 +598,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
size_t PropSize, void *PropValue) {
- if (Device == OffloadContext::get().HostDevice())
+ if (Device == OffloadContext::get().HostDevice)
return olGetDeviceInfoImplDetailHost(Device, PropName, PropSize, PropValue,
nullptr);
return olGetDeviceInfoImplDetail(Device, PropName, PropSize, PropValue,
@@ -588,7 +607,7 @@ Error olGetDeviceInfo_impl(ol_device_handle_t Device, ol_device_info_t PropName,
Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
ol_device_info_t PropName, size_t *PropSizeRet) {
- if (Device == OffloadContext::get().HostDevice())
+ if (Device == OffloadContext::get().HostDevice)
return olGetDeviceInfoImplDetailHost(Device, PropName, 0, nullptr,
PropSizeRet);
return olGetDeviceInfoImplDetail(Device, PropName, 0, nullptr, PropSizeRet);
@@ -598,7 +617,7 @@ Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
for (auto &Platform : OffloadContext::get().Platforms) {
for (auto &Device : Platform->Devices) {
if (!Callback(Device.get(), UserData)) {
- break;
+ return Error::success();
}
}
}
@@ -949,7 +968,7 @@ Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, const void *SrcPtr,
ol_device_handle_t SrcDevice, size_t Size) {
- auto Host = OffloadContext::get().HostDevice();
+ auto Host = OffloadContext::get().HostDevice;
if (DstDevice == Host && SrcDevice == Host) {
if (!Queue) {
std::memcpy(DstPtr, SrcPtr, Size);
@@ -1138,7 +1157,7 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
auto CheckKind = [&](ol_symbol_kind_t Required) {
if (Symbol->Kind != Required) {
std::string ErrBuffer;
- llvm::raw_string_ostream(ErrBuffer)
+ raw_string_ostream(ErrBuffer)
<< PropName << ": Expected a symbol of Kind " << Required
<< " but given a symbol of Kind " << Symbol->Kind;
return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());