aboutsummaryrefslogtreecommitdiff
path: root/offload/liboffload/src/OffloadImpl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/liboffload/src/OffloadImpl.cpp')
-rw-r--r--offload/liboffload/src/OffloadImpl.cpp168
1 files changed, 81 insertions, 87 deletions
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index ffc9016b..2444ccd 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -299,78 +299,62 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
return Plugin::error(ErrorCode::UNIMPLEMENTED, ErrBuffer.c_str());
};
- // Find the info if it exists under any of the given names
- auto getInfoString =
- [&](std::vector<std::string> Names) -> llvm::Expected<const char *> {
- for (auto &Name : Names) {
- if (auto Entry = Device->Info.get(Name)) {
- if (!std::holds_alternative<std::string>((*Entry)->Value))
- return makeError(ErrorCode::BACKEND_FAILURE,
- "plugin returned incorrect type");
- return std::get<std::string>((*Entry)->Value).c_str();
- }
- }
+ // These are not implemented by the plugin interface
+ if (PropName == OL_DEVICE_INFO_PLATFORM)
+ return Info.write<void *>(Device->Platform);
+ if (PropName == OL_DEVICE_INFO_TYPE)
+ return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
+ if (PropName >= OL_DEVICE_INFO_LAST)
+ return createOffloadError(ErrorCode::INVALID_ENUMERATION,
+ "getDeviceInfo enum '%i' is invalid", PropName);
+ auto EntryOpt = Device->Info.get(static_cast<DeviceInfo>(PropName));
+ if (!EntryOpt)
return makeError(ErrorCode::UNIMPLEMENTED,
"plugin did not provide a response for this information");
- };
-
- auto getInfoXyz =
- [&](std::vector<std::string> Names) -> llvm::Expected<ol_dimensions_t> {
- for (auto &Name : Names) {
- if (auto Entry = Device->Info.get(Name)) {
- auto Node = *Entry;
- ol_dimensions_t Out{0, 0, 0};
-
- auto getField = [&](StringRef Name, uint32_t &Dest) {
- if (auto F = Node->get(Name)) {
- if (!std::holds_alternative<size_t>((*F)->Value))
- return makeError(
- ErrorCode::BACKEND_FAILURE,
- "plugin returned incorrect type for dimensions element");
- Dest = std::get<size_t>((*F)->Value);
- } else
- return makeError(ErrorCode::BACKEND_FAILURE,
- "plugin didn't provide all values for dimensions");
- return Plugin::success();
- };
-
- if (auto Res = getField("x", Out.x))
- return Res;
- if (auto Res = getField("y", Out.y))
- return Res;
- if (auto Res = getField("z", Out.z))
- return Res;
-
- return Out;
- }
- }
-
- return makeError(ErrorCode::UNIMPLEMENTED,
- "plugin did not provide a response for this information");
- };
+ auto Entry = *EntryOpt;
switch (PropName) {
- case OL_DEVICE_INFO_PLATFORM:
- return Info.write<void *>(Device->Platform);
- case OL_DEVICE_INFO_TYPE:
- return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
case OL_DEVICE_INFO_NAME:
- return Info.writeString(getInfoString({"Device Name"}));
case OL_DEVICE_INFO_VENDOR:
- return Info.writeString(getInfoString({"Vendor Name"}));
- case OL_DEVICE_INFO_DRIVER_VERSION:
- return Info.writeString(
- getInfoString({"CUDA Driver Version", "HSA Runtime Version"}));
- case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
- return Info.write(getInfoXyz({"Workgroup Max Size per Dimension" /*AMD*/,
- "Maximum Block Dimensions" /*CUDA*/}));
- default:
- return createOffloadError(ErrorCode::INVALID_ENUMERATION,
- "getDeviceInfo enum '%i' is invalid", PropName);
+ case OL_DEVICE_INFO_DRIVER_VERSION: {
+ // String values
+ if (!std::holds_alternative<std::string>(Entry->Value))
+ return makeError(ErrorCode::BACKEND_FAILURE,
+ "plugin returned incorrect type");
+ return Info.writeString(std::get<std::string>(Entry->Value).c_str());
}
- return Error::success();
+ case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
+ // {x, y, z} triples
+ ol_dimensions_t Out{0, 0, 0};
+
+ auto getField = [&](StringRef Name, uint32_t &Dest) {
+ if (auto F = Entry->get(Name)) {
+ if (!std::holds_alternative<size_t>((*F)->Value))
+ return makeError(
+ ErrorCode::BACKEND_FAILURE,
+ "plugin returned incorrect type for dimensions element");
+ Dest = std::get<size_t>((*F)->Value);
+ } else
+ return makeError(ErrorCode::BACKEND_FAILURE,
+ "plugin didn't provide all values for dimensions");
+ return Plugin::success();
+ };
+
+ if (auto Res = getField("x", Out.x))
+ return Res;
+ if (auto Res = getField("y", Out.y))
+ return Res;
+ if (auto Res = getField("z", Out.z))
+ return Res;
+
+ return Info.write(Out);
+ }
+
+ default:
+ llvm_unreachable("Unimplemented device info");
+ }
}
Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
@@ -483,7 +467,7 @@ Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); }
-Error olWaitQueue_impl(ol_queue_handle_t Queue) {
+Error olSyncQueue_impl(ol_queue_handle_t Queue) {
// Host plugin doesn't have a queue set so it's not safe to call synchronize
// on it, but we have nothing to synchronize in that situation anyway.
if (Queue->AsyncInfo->Queue) {
@@ -500,6 +484,28 @@ Error olWaitQueue_impl(ol_queue_handle_t Queue) {
return Error::success();
}
+Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
+ size_t NumEvents) {
+ auto *Device = Queue->Device->Device;
+
+ for (size_t I = 0; I < NumEvents; I++) {
+ auto *Event = Events[I];
+
+ if (!Event)
+ return Plugin::error(ErrorCode::INVALID_NULL_HANDLE,
+ "olWaitEvents asked to wait on a NULL event");
+
+ // Do nothing if the event is for this queue
+ if (Event->Queue == Queue)
+ continue;
+
+ if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo))
+ return Err;
+ }
+
+ return Error::success();
+}
+
Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
ol_queue_info_t PropName, size_t PropSize,
void *PropValue, size_t *PropSizeRet) {
@@ -527,7 +533,7 @@ Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
return olGetQueueInfoImplDetail(Queue, PropName, 0, nullptr, PropSizeRet);
}
-Error olWaitEvent_impl(ol_event_handle_t Event) {
+Error olSyncEvent_impl(ol_event_handle_t Event) {
if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo))
return Res;
@@ -569,26 +575,21 @@ Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName,
return olGetEventInfoImplDetail(Event, PropName, 0, nullptr, PropSizeRet);
}
-ol_event_handle_t makeEvent(ol_queue_handle_t Queue) {
- auto EventImpl = std::make_unique<ol_event_impl_t>(nullptr, Queue);
- if (auto Res = Queue->Device->Device->createEvent(&EventImpl->EventInfo)) {
- llvm::consumeError(std::move(Res));
- return nullptr;
- }
+Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
+ *EventOut = new ol_event_impl_t(nullptr, Queue);
+ if (auto Res = Queue->Device->Device->createEvent(&(*EventOut)->EventInfo))
+ return Res;
- if (auto Res = Queue->Device->Device->recordEvent(EventImpl->EventInfo,
- Queue->AsyncInfo)) {
- llvm::consumeError(std::move(Res));
- return nullptr;
- }
+ if (auto Res = Queue->Device->Device->recordEvent((*EventOut)->EventInfo,
+ Queue->AsyncInfo))
+ return Res;
- return EventImpl.release();
+ return Plugin::success();
}
Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
ol_device_handle_t DstDevice, const void *SrcPtr,
- ol_device_handle_t SrcDevice, size_t Size,
- ol_event_handle_t *EventOut) {
+ ol_device_handle_t SrcDevice, size_t Size) {
auto Host = OffloadContext::get().HostDevice();
if (DstDevice == Host && SrcDevice == Host) {
if (!Queue) {
@@ -619,9 +620,6 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
return Res;
}
- if (EventOut)
- *EventOut = makeEvent(Queue);
-
return Error::success();
}
@@ -668,8 +666,7 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) {
Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_symbol_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
- const ol_kernel_launch_size_args_t *LaunchSizeArgs,
- ol_event_handle_t *EventOut) {
+ const ol_kernel_launch_size_args_t *LaunchSizeArgs) {
auto *DeviceImpl = Device->Device;
if (Queue && Device != Queue->Device) {
return createOffloadError(
@@ -707,9 +704,6 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
if (Err)
return Err;
- if (EventOut)
- *EventOut = makeEvent(Queue);
-
return Error::success();
}