diff options
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r-- | offload/liboffload/src/OffloadImpl.cpp | 168 |
1 files changed, 81 insertions, 87 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index ffc9016b..2444ccd 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -299,78 +299,62 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str()); }; - // Find the info if it exists under any of the given names - auto getInfoString = - [&](std::vector<std::string> Names) -> llvm::Expected<const char *> { - for (auto &Name : Names) { - if (auto Entry = Device->Info.get(Name)) { - if (!std::holds_alternative<std::string>((*Entry)->Value)) - return makeError(ErrorCode::BACKEND_FAILURE, - "plugin returned incorrect type"); - return std::get<std::string>((*Entry)->Value).c_str(); - } - } + // These are not implemented by the plugin interface + if (PropName == OL_DEVICE_INFO_PLATFORM) + return Info.write<void *>(Device->Platform); + if (PropName == OL_DEVICE_INFO_TYPE) + return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU); + if (PropName >= OL_DEVICE_INFO_LAST) + return createOffloadError(ErrorCode::INVALID_ENUMERATION, + "getDeviceInfo enum '%i' is invalid", PropName); + auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName)); + if (!EntryOpt) return makeError(ErrorCode::UNIMPLEMENTED, "plugin did not provide a response for this information"); - }; - - auto getInfoXyz = - [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> { - for (auto &Name : Names) { - if (auto Entry = Device->Info.get(Name)) { - auto Node = *Entry; - ol_dimensions_t Out{0, 0, 0}; - - auto getField = [&](StringRef Name, uint32_t &Dest) { - if (auto F = Node->get(Name)) { - if (!std::holds_alternative<size_t>((*F)->Value)) - return makeError( - ErrorCode::BACKEND_FAILURE, - "plugin returned incorrect type for dimensions element"); - Dest = std::get<size_t>((*F)->Value); - } else - return makeError(ErrorCode::BACKEND_FAILURE, - "plugin didn't provide all values for dimensions"); - return Plugin::success(); - }; - - if (auto Res = getField("x", Out.x)) - return Res; - if (auto Res = getField("y", Out.y)) - return Res; - if (auto Res = getField("z", Out.z)) - return Res; - - return Out; - } - } - - return makeError(ErrorCode::UNIMPLEMENTED, - "plugin did not provide a response for this information"); - }; + auto Entry = *EntryOpt; switch (PropName) { - case OL_DEVICE_INFO_PLATFORM: - return Info.write<void *>(Device->Platform); - case OL_DEVICE_INFO_TYPE: - return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU); case OL_DEVICE_INFO_NAME: - return Info.writeString(getInfoString({"Device Name"})); case OL_DEVICE_INFO_VENDOR: - return Info.writeString(getInfoString({"Vendor Name"})); - case OL_DEVICE_INFO_DRIVER_VERSION: - return Info.writeString( - getInfoString({"CUDA Driver Version", "HSA Runtime Version"})); - case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: - return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/, - "Maximum Block Dimensions" /*CUDA*/})); - default: - return createOffloadError(ErrorCode::INVALID_ENUMERATION, - "getDeviceInfo enum '%i' is invalid", PropName); + case OL_DEVICE_INFO_DRIVER_VERSION: { + // String values + if (!std::holds_alternative<std::string>(Entry->Value)) + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type"); + return Info.writeString(std::get<std::string>(Entry->Value).c_str()); } - return Error::success(); + case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: { + // {x, y, z} triples + ol_dimensions_t Out{0, 0, 0}; + + auto getField = [&](StringRef Name, uint32_t &Dest) { + if (auto F = Entry->get(Name)) { + if (!std::holds_alternative<size_t>((*F)->Value)) + return makeError( + ErrorCode::BACKEND_FAILURE, + "plugin returned incorrect type for dimensions element"); + Dest = std::get<size_t>((*F)->Value); + } else + return makeError(ErrorCode::BACKEND_FAILURE, + "plugin didn't provide all values for dimensions"); + return Plugin::success(); + }; + + if (auto Res = getField("x", Out.x)) + return Res; + if (auto Res = getField("y", Out.y)) + return Res; + if (auto Res = getField("z", Out.z)) + return Res; + + return Info.write(Out); + } + + default: + llvm_unreachable("Unimplemented device info"); + } } Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device, @@ -483,7 +467,7 @@ Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) { Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); } -Error olWaitQueue_impl(ol_queue_handle_t Queue) { +Error olSyncQueue_impl(ol_queue_handle_t Queue) { // Host plugin doesn't have a queue set so it's not safe to call synchronize // on it, but we have nothing to synchronize in that situation anyway. if (Queue->AsyncInfo->Queue) { @@ -500,6 +484,28 @@ Error olWaitQueue_impl(ol_queue_handle_t Queue) { return Error::success(); } +Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events, + size_t NumEvents) { + auto *Device = Queue->Device->Device; + + for (size_t I = 0; I < NumEvents; I++) { + auto *Event = Events[I]; + + if (!Event) + return Plugin::error(ErrorCode::INVALID_NULL_HANDLE, + "olWaitEvents asked to wait on a NULL event"); + + // Do nothing if the event is for this queue + if (Event->Queue == Queue) + continue; + + if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo)) + return Err; + } + + return Error::success(); +} + Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue, ol_queue_info_t PropName, size_t PropSize, void *PropValue, size_t *PropSizeRet) { @@ -527,7 +533,7 @@ Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName, return olGetQueueInfoImplDetail(Queue, PropName, 0, nullptr, PropSizeRet); } -Error olWaitEvent_impl(ol_event_handle_t Event) { +Error olSyncEvent_impl(ol_event_handle_t Event) { if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo)) return Res; @@ -569,26 +575,21 @@ Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName, return olGetEventInfoImplDetail(Event, PropName, 0, nullptr, PropSizeRet); } -ol_event_handle_t makeEvent(ol_queue_handle_t Queue) { - auto EventImpl = std::make_unique<ol_event_impl_t>(nullptr, Queue); - if (auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo)) { - llvm::consumeError(std::move(Res)); - return nullptr; - } +Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) { + *EventOut = new ol_event_impl_t(nullptr, Queue); + if (auto Res = Queue->Device->Device->createEvent(&(*EventOut)->EventInfo)) + return Res; - if (auto Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo, - Queue->AsyncInfo)) { - llvm::consumeError(std::move(Res)); - return nullptr; - } + if (auto Res = Queue->Device->Device->recordEvent((*EventOut)->EventInfo, + Queue->AsyncInfo)) + return Res; - return EventImpl.release(); + return Plugin::success(); } Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice, const void *SrcPtr, - ol_device_handle_t SrcDevice, size_t Size, - ol_event_handle_t *EventOut) { + ol_device_handle_t SrcDevice, size_t Size) { auto Host = OffloadContext::get().HostDevice(); if (DstDevice == Host && SrcDevice == Host) { if (!Queue) { @@ -619,9 +620,6 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, return Res; } - if (EventOut) - *EventOut = makeEvent(Queue); - return Error::success(); } @@ -668,8 +666,7 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) { Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, ol_symbol_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize, - const ol_kernel_launch_size_args_t *LaunchSizeArgs, - ol_event_handle_t *EventOut) { + const ol_kernel_launch_size_args_t *LaunchSizeArgs) { auto *DeviceImpl = Device->Device; if (Queue && Device != Queue->Device) { return createOffloadError( @@ -707,9 +704,6 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, if (Err) return Err; - if (EventOut) - *EventOut = makeEvent(Queue); - return Error::success(); } |