Diffstat (limited to 'offload/liboffload/src')
-rw-r--r--   offload/liboffload/src/OffloadImpl.cpp | 366
1 file changed, 321 insertions(+), 45 deletions(-)
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 6486b2b..7e8e297 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -47,10 +47,59 @@ struct ol_device_impl_t {
                    ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
       : DeviceNum(DeviceNum), Device(Device), Platform(Platform),
         Info(std::forward<InfoTreeNode>(DevInfo)) {}
+
+  ~ol_device_impl_t() {
+    assert(!OutstandingQueues.size() &&
+           "Device object dropped with outstanding queues");
+  }
+
   int DeviceNum;
   GenericDeviceTy *Device;
   ol_platform_handle_t Platform;
   InfoTreeNode Info;
+
+  llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
+  std::mutex OutstandingQueuesMutex;
+
+  /// If the device has any outstanding queues that are now complete, remove it
+  /// from the list and return it.
+  ///
+  /// Queues may be added to the outstanding queue list by olDestroyQueue if
+  /// they are destroyed but not completed.
+  __tgt_async_info *getOutstandingQueue() {
+    // Not locking the `size()` access is fine here - In the worst case we
+    // either miss a queue that exists or loop through an empty array after
+    // taking the lock. Both are sub-optimal but not that bad.
+    if (OutstandingQueues.size()) {
+      std::lock_guard<std::mutex> Lock(OutstandingQueuesMutex);
+
+      // As queues are pulled and popped from this list, longer running queues
+      // naturally bubble to the start of the array. Hence looping backwards.
+      for (auto Q = OutstandingQueues.rbegin(); Q != OutstandingQueues.rend();
+           Q++) {
+        if (!Device->hasPendingWork(*Q)) {
+          auto OutstandingQueue = *Q;
+          *Q = OutstandingQueues.back();
+          OutstandingQueues.pop_back();
+          return OutstandingQueue;
+        }
+      }
+    }
+    return nullptr;
+  }
+
+  /// Complete all pending work for this device and perform any needed cleanup.
+  ///
+  /// After calling this function, no liboffload functions should be called with
+  /// this device handle.
+  llvm::Error destroy() {
+    llvm::Error Result = Plugin::success();
+    for (auto Q : OutstandingQueues)
+      if (auto Err = Device->synchronize(Q, /*Release=*/true))
+        Result = llvm::joinErrors(std::move(Result), std::move(Err));
+    OutstandingQueues.clear();
+    return Result;
+  }
 };
 
 struct ol_platform_impl_t {
@@ -58,21 +107,51 @@ struct ol_platform_impl_t {
                      ol_platform_backend_t BackendType)
       : Plugin(std::move(Plugin)), BackendType(BackendType) {}
   std::unique_ptr<GenericPluginTy> Plugin;
-  std::vector<ol_device_impl_t> Devices;
+  llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
   ol_platform_backend_t BackendType;
+
+  /// Complete all pending work for this platform and perform any needed
+  /// cleanup.
+  ///
+  /// After calling this function, no liboffload functions should be called with
+  /// this platform handle.
+  llvm::Error destroy() {
+    llvm::Error Result = Plugin::success();
+    for (auto &D : Devices)
+      if (auto Err = D->destroy())
+        Result = llvm::joinErrors(std::move(Result), std::move(Err));
+
+    if (auto Res = Plugin->deinit())
+      Result = llvm::joinErrors(std::move(Result), std::move(Res));
+
+    return Result;
+  }
 };
 
 struct ol_queue_impl_t {
   ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
-      : AsyncInfo(AsyncInfo), Device(Device) {}
+      : AsyncInfo(AsyncInfo), Device(Device), Id(IdCounter++) {}
   __tgt_async_info *AsyncInfo;
   ol_device_handle_t Device;
+  // A unique identifier for the queue
+  size_t Id;
+  static std::atomic<size_t> IdCounter;
 };
+std::atomic<size_t> ol_queue_impl_t::IdCounter(0);
 
 struct ol_event_impl_t {
-  ol_event_impl_t(void *EventInfo, ol_queue_handle_t Queue)
-      : EventInfo(EventInfo), Queue(Queue) {}
+  ol_event_impl_t(void *EventInfo, ol_device_handle_t Device,
+                  ol_queue_handle_t Queue)
+      : EventInfo(EventInfo), Device(Device), QueueId(Queue->Id), Queue(Queue) {
+  }
+  // EventInfo may be null, in which case the event should be considered always
+  // complete
   void *EventInfo;
+  ol_device_handle_t Device;
+  size_t QueueId;
+  // Events may outlive the queue - don't assume this is always valid.
+  // It is provided only to implement OL_EVENT_INFO_QUEUE. Use QueueId to check
+  // for queue equality instead.
   ol_queue_handle_t Queue;
 };
 
@@ -123,12 +202,13 @@ struct OffloadContext {
   bool TracingEnabled = false;
   bool ValidationEnabled = true;
   DenseMap<void *, AllocInfo> AllocInfoMap{};
+  std::mutex AllocInfoMapMutex{};
   SmallVector<ol_platform_impl_t, 4> Platforms{};
   size_t RefCount;
 
   ol_device_handle_t HostDevice() {
     // The host platform is always inserted last
-    return &Platforms.back().Devices[0];
+    return Platforms.back().Devices[0].get();
   }
 
   static OffloadContext &get() {
@@ -187,8 +267,8 @@ Error initPlugins(OffloadContext &Context) {
         auto Info = Device->obtainInfoImpl();
         if (auto Err = Info.takeError())
           return Err;
-        Platform.Devices.emplace_back(DevNum, Device, &Platform,
-                                      std::move(*Info));
+        Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>(
+            DevNum, Device, &Platform, std::move(*Info)));
       }
     }
   }
@@ -196,7 +276,8 @@
   // Add the special host device
   auto &HostPlatform = Context.Platforms.emplace_back(
       ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
-  HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
+  HostPlatform.Devices.emplace_back(
+      std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{}));
   Context.HostDevice()->Platform = &HostPlatform;
 
   Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
@@ -206,7 +287,7 @@
 }
 
 Error olInit_impl() {
-  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
 
   if (isOffloadInitialized()) {
     OffloadContext::get().RefCount++;
@@ -224,7 +305,7 @@
 }
 
 Error olShutDown_impl() {
-  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+  std::lock_guard<std::mutex> Lock(OffloadContextValMutex);
 
   if (--OffloadContext::get().RefCount != 0)
     return Error::success();
@@ -237,7 +318,7 @@
     if (!P.Plugin || !P.Plugin->is_initialized())
       continue;
 
-    if (auto Res = P.Plugin->deinit())
+    if (auto Res = P.destroy())
      Result = llvm::joinErrors(std::move(Result), std::move(Res));
   }
 
@@ -300,10 +381,57 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
   };
 
   // These are not implemented by the plugin interface
-  if (PropName == OL_DEVICE_INFO_PLATFORM)
+  switch (PropName) {
+  case OL_DEVICE_INFO_PLATFORM:
     return Info.write<void *>(Device->Platform);
-  if (PropName == OL_DEVICE_INFO_TYPE)
+
+  case OL_DEVICE_INFO_TYPE:
     return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
+
+  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
+  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG: {
+    ol_device_fp_capability_flags_t flags{0};
+    flags |= OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
+             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
+             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
+             OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
+             OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
+             OL_DEVICE_FP_CAPABILITY_FLAG_DENORM |
+             OL_DEVICE_FP_CAPABILITY_FLAG_FMA;
+    return Info.write(flags);
+  }
+
+  case OL_DEVICE_INFO_HALF_FP_CONFIG:
+    return Info.write<ol_device_fp_capability_flags_t>(0);
+
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
+    return Info.write<uint32_t>(1);
+
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
+    return Info.write<uint32_t>(0);
+
+  // None of the existing plugins specify a limit on a single allocation,
+  // so return the global memory size instead
+  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
+    [[fallthrough]];
+  // AMD doesn't provide the global memory size (trivially) with the device info
+  // struct, so use the plugin interface
+  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE: {
+    uint64_t Mem;
+    if (auto Err = Device->Device->getDeviceMemorySize(Mem))
+      return Err;
+    return Info.write<uint64_t>(Mem);
+  } break;
+
+  default:
+    break;
+  }
+
   if (PropName >= OL_DEVICE_INFO_LAST)
     return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                               "getDeviceInfo enum '%i' is invalid", PropName);
@@ -314,8 +442,10 @@
                      "plugin did not provide a response for this information");
   auto Entry = *EntryOpt;
 
+  // Retrieve properties from the plugin interface
   switch (PropName) {
   case OL_DEVICE_INFO_NAME:
+  case OL_DEVICE_INFO_PRODUCT_NAME:
   case OL_DEVICE_INFO_VENDOR:
   case OL_DEVICE_INFO_DRIVER_VERSION: {
     // String values
@@ -325,7 +455,13 @@
     return Info.writeString(std::get<std::string>(Entry->Value).c_str());
   }
 
-  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE: {
+  case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
+  case OL_DEVICE_INFO_MAX_WORK_SIZE:
+  case OL_DEVICE_INFO_VENDOR_ID:
+  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
+  case OL_DEVICE_INFO_ADDRESS_BITS:
+  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
+  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE: {
     // Uint32 values
     if (!std::holds_alternative<uint64_t>(Entry->Value))
       return makeError(ErrorCode::BACKEND_FAILURE,
@@ -337,6 +473,7 @@
     return Info.write(static_cast<uint32_t>(Value));
   }
 
+  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
   case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: {
     // {x, y, z} triples
     ol_dimensions_t Out{0, 0, 0};
@@ -375,6 +512,8 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
   assert(Device == OffloadContext::get().HostDevice());
   InfoWriter Info(PropSize, PropValue, PropSizeRet);
 
+  constexpr auto uint32_max = std::numeric_limits<uint32_t>::max();
+
   switch (PropName) {
   case OL_DEVICE_INFO_PLATFORM:
     return Info.write<void *>(Device->Platform);
@@ -382,14 +521,52 @@
     return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
   case OL_DEVICE_INFO_NAME:
     return Info.writeString("Virtual Host Device");
+  case OL_DEVICE_INFO_PRODUCT_NAME:
+    return Info.writeString("Virtual Host Device");
   case OL_DEVICE_INFO_VENDOR:
     return Info.writeString("Liboffload");
   case OL_DEVICE_INFO_DRIVER_VERSION:
     return Info.writeString(LLVM_VERSION_STRING);
   case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
-    return Info.write<uint64_t>(1);
+    return Info.write<uint32_t>(1);
   case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION:
     return Info.write<ol_dimensions_t>(ol_dimensions_t{1, 1, 1});
+  case OL_DEVICE_INFO_MAX_WORK_SIZE:
+    return Info.write<uint32_t>(uint32_max);
+  case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
+    return Info.write<ol_dimensions_t>(
+        ol_dimensions_t{uint32_max, uint32_max, uint32_max});
+  case OL_DEVICE_INFO_VENDOR_ID:
+    return Info.write<uint32_t>(0);
+  case OL_DEVICE_INFO_NUM_COMPUTE_UNITS:
+    return Info.write<uint32_t>(1);
+  case OL_DEVICE_INFO_SINGLE_FP_CONFIG:
+  case OL_DEVICE_INFO_DOUBLE_FP_CONFIG:
+    return Info.write<ol_device_fp_capability_flags_t>(
+        OL_DEVICE_FP_CAPABILITY_FLAG_CORRECTLY_ROUNDED_DIVIDE_SQRT |
+        OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_NEAREST |
+        OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_ZERO |
+        OL_DEVICE_FP_CAPABILITY_FLAG_ROUND_TO_INF |
+        OL_DEVICE_FP_CAPABILITY_FLAG_INF_NAN |
+        OL_DEVICE_FP_CAPABILITY_FLAG_DENORM | OL_DEVICE_FP_CAPABILITY_FLAG_FMA);
+  case OL_DEVICE_INFO_HALF_FP_CONFIG:
+    return Info.write<ol_device_fp_capability_flags_t>(0);
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_CHAR:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_SHORT:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_INT:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_LONG:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_FLOAT:
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_DOUBLE:
+    return Info.write<uint32_t>(1);
+  case OL_DEVICE_INFO_NATIVE_VECTOR_WIDTH_HALF:
+    return Info.write<uint32_t>(0);
+  case OL_DEVICE_INFO_MAX_CLOCK_FREQUENCY:
+  case OL_DEVICE_INFO_MEMORY_CLOCK_RATE:
+  case OL_DEVICE_INFO_ADDRESS_BITS:
+    return Info.write<uint32_t>(std::numeric_limits<uintptr_t>::digits);
+  case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
+  case OL_DEVICE_INFO_GLOBAL_MEM_SIZE:
+    return Info.write<uint64_t>(0);
   default:
     return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                               "getDeviceInfo enum '%i' is invalid", PropName);
@@ -418,7 +595,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
 Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
   for (auto &Platform : OffloadContext::get().Platforms) {
     for (auto &Device : Platform.Devices) {
-      if (!Callback(&Device, UserData)) {
+      if (!Callback(Device.get(), UserData)) {
         break;
       }
     }
@@ -447,54 +624,90 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
     return Alloc.takeError();
 
   *AllocationOut = *Alloc;
-  OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc,
-                                                      AllocInfo{Device, Type});
+  {
+    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
+    OffloadContext::get().AllocInfoMap.insert_or_assign(
+        *Alloc, AllocInfo{Device, Type});
+  }
   return Error::success();
 }
 
 Error olMemFree_impl(void *Address) {
-  if (!OffloadContext::get().AllocInfoMap.contains(Address))
-    return createOffloadError(ErrorCode::INVALID_ARGUMENT,
-                              "address is not a known allocation");
-
-  auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
-  auto Device = AllocInfo.Device;
-  auto Type = AllocInfo.Type;
+  ol_device_handle_t Device;
+  ol_alloc_type_t Type;
+  {
+    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
+    if (!OffloadContext::get().AllocInfoMap.contains(Address))
+      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
+                                "address is not a known allocation");
+
+    auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
+    Device = AllocInfo.Device;
+    Type = AllocInfo.Type;
+    OffloadContext::get().AllocInfoMap.erase(Address);
+  }
 
   if (auto Res =
           Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
     return Res;
 
-  OffloadContext::get().AllocInfoMap.erase(Address);
-
   return Error::success();
 }
 
 Error olCreateQueue_impl(ol_device_handle_t Device, ol_queue_handle_t *Queue) {
   auto CreatedQueue = std::make_unique<ol_queue_impl_t>(nullptr, Device);
-  if (auto Err = Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo)))
+
+  auto OutstandingQueue = Device->getOutstandingQueue();
+  if (OutstandingQueue) {
+    // The queue is empty, but we still need to sync it to release any temporary
+    // memory allocations or do other cleanup.
+    if (auto Err =
+            Device->Device->synchronize(OutstandingQueue, /*Release=*/false))
+      return Err;
+    CreatedQueue->AsyncInfo = OutstandingQueue;
+  } else if (auto Err =
+                 Device->Device->initAsyncInfo(&(CreatedQueue->AsyncInfo))) {
     return Err;
+  }
 
   *Queue = CreatedQueue.release();
   return Error::success();
 }
 
-Error olDestroyQueue_impl(ol_queue_handle_t Queue) { return olDestroy(Queue); }
+Error olDestroyQueue_impl(ol_queue_handle_t Queue) {
+  auto *Device = Queue->Device;
+  // This is safe; as soon as olDestroyQueue is called it is not possible to add
+  // any more work to the queue, so if it's finished now it will remain finished
+  // forever.
+  auto Res = Device->Device->hasPendingWork(Queue->AsyncInfo);
+  if (!Res)
+    return Res.takeError();
+
+  if (!*Res) {
+    // The queue is complete, so sync it and throw it back into the pool.
+    if (auto Err = Device->Device->synchronize(Queue->AsyncInfo,
+                                               /*Release=*/true))
+      return Err;
+  } else {
+    // The queue still has outstanding work. Store it so we can check it later.
+    std::lock_guard<std::mutex> Lock(Device->OutstandingQueuesMutex);
+    Device->OutstandingQueues.push_back(Queue->AsyncInfo);
+  }
+
+  return olDestroy(Queue);
+}
 
 Error olSyncQueue_impl(ol_queue_handle_t Queue) {
   // Host plugin doesn't have a queue set so it's not safe to call synchronize
   // on it, but we have nothing to synchronize in that situation anyway.
   if (Queue->AsyncInfo->Queue) {
-    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo))
+    // We don't need to release the queue and we would like the ability for
+    // other offload threads to submit work concurrently, so pass "false" here
+    // so we don't release the underlying queue object.
+    if (auto Err = Queue->Device->Device->synchronize(Queue->AsyncInfo, false))
       return Err;
   }
 
-  // Recreate the stream resource so the queue can be reused
-  // TODO: Would be easier for the synchronization to (optionally) not release
-  // it to begin with.
-  if (auto Res = Queue->Device->Device->initAsyncInfo(&Queue->AsyncInfo))
-    return Res;
-
   return Error::success();
 }
 
@@ -509,8 +722,8 @@ Error olWaitEvents_impl(ol_queue_handle_t Queue, ol_event_handle_t *Events,
       return Plugin::error(ErrorCode::INVALID_NULL_HANDLE,
                            "olWaitEvents asked to wait on a NULL event");
 
-    // Do nothing if the event is for this queue
-    if (Event->Queue == Queue)
+    // Do nothing if the event is for this queue or the event is always complete
+    if (Event->QueueId == Queue->Id || !Event->EventInfo)
       continue;
 
     if (auto Err = Device->waitEvent(Event->EventInfo, Queue->AsyncInfo))
@@ -528,6 +741,12 @@ Error olGetQueueInfoImplDetail(ol_queue_handle_t Queue,
   switch (PropName) {
   case OL_QUEUE_INFO_DEVICE:
     return Info.write<ol_device_handle_t>(Queue->Device);
+  case OL_QUEUE_INFO_EMPTY: {
+    auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo);
+    if (auto Err = Pending.takeError())
+      return Err;
+    return Info.write<bool>(!*Pending);
+  }
   default:
     return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                               "olGetQueueInfo enum '%i' is invalid", PropName);
@@ -548,15 +767,20 @@ Error olGetQueueInfoSize_impl(ol_queue_handle_t Queue, ol_queue_info_t PropName,
 }
 
 Error olSyncEvent_impl(ol_event_handle_t Event) {
-  if (auto Res = Event->Queue->Device->Device->syncEvent(Event->EventInfo))
+  // No event info means that this event was complete on creation
+  if (!Event->EventInfo)
+    return Plugin::success();
+
+  if (auto Res = Event->Device->Device->syncEvent(Event->EventInfo))
     return Res;
 
   return Error::success();
 }
 
 Error olDestroyEvent_impl(ol_event_handle_t Event) {
-  if (auto Res = Event->Queue->Device->Device->destroyEvent(Event->EventInfo))
-    return Res;
+  if (Event->EventInfo)
+    if (auto Res = Event->Device->Device->destroyEvent(Event->EventInfo))
+      return Res;
 
   return olDestroy(Event);
 }
@@ -565,10 +789,22 @@ Error olGetEventInfoImplDetail(ol_event_handle_t Event,
                                ol_event_info_t PropName, size_t PropSize,
                                void *PropValue, size_t *PropSizeRet) {
   InfoWriter Info(PropSize, PropValue, PropSizeRet);
+  auto Queue = Event->Queue;
 
   switch (PropName) {
   case OL_EVENT_INFO_QUEUE:
-    return Info.write<ol_queue_handle_t>(Event->Queue);
+    return Info.write<ol_queue_handle_t>(Queue);
+  case OL_EVENT_INFO_IS_COMPLETE: {
+    // No event info means that this event was complete on creation
+    if (!Event->EventInfo)
+      return Info.write<bool>(true);
+
+    auto Res = Queue->Device->Device->isEventComplete(Event->EventInfo,
+                                                      Queue->AsyncInfo);
+    if (auto Err = Res.takeError())
+      return Err;
+    return Info.write<bool>(*Res);
+  }
   default:
     return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                               "olGetEventInfo enum '%i' is invalid", PropName);
@@ -590,7 +826,16 @@ Error olGetEventInfoSize_impl(ol_event_handle_t Event, ol_event_info_t PropName,
 }
 
 Error olCreateEvent_impl(ol_queue_handle_t Queue, ol_event_handle_t *EventOut) {
-  *EventOut = new ol_event_impl_t(nullptr, Queue);
+  auto Pending = Queue->Device->Device->hasPendingWork(Queue->AsyncInfo);
+  if (auto Err = Pending.takeError())
+    return Err;
+
+  *EventOut = new ol_event_impl_t(nullptr, Queue->Device, Queue);
+  if (!*Pending)
+    // Queue is empty, don't record an event and consider the event always
+    // complete
+    return Plugin::success();
+
   if (auto Res = Queue->Device->Device->createEvent(&(*EventOut)->EventInfo))
     return Res;
 
@@ -637,6 +882,12 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr,
   return Error::success();
 }
 
+Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
+                     const void *PatternPtr, size_t FillSize) {
+  return Queue->Device->Device->dataFill(Ptr, PatternPtr, PatternSize, FillSize,
+                                         Queue->AsyncInfo);
+}
+
 Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
                            size_t ProgDataSize, ol_program_handle_t *Program) {
   // Make a copy of the program binary in case it is released by the caller.
@@ -677,6 +928,24 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) {
   return olDestroy(Program);
 }
 
+Error olCalculateOptimalOccupancy_impl(ol_device_handle_t Device,
+                                       ol_symbol_handle_t Kernel,
+                                       size_t DynamicMemSize,
+                                       size_t *GroupSize) {
+  if (Kernel->Kind != OL_SYMBOL_KIND_KERNEL)
+    return createOffloadError(ErrorCode::SYMBOL_KIND,
+                              "provided symbol is not a kernel");
+  auto *KernelImpl = std::get<GenericKernelTy *>(Kernel->PluginImpl);
+
+  auto Res = KernelImpl->maxGroupSize(*Device->Device, DynamicMemSize);
+  if (auto Err = Res.takeError())
+    return Err;
+
+  *GroupSize = *Res;
+
+  return Error::success();
+}
+
 Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
                           ol_symbol_handle_t Kernel, const void *ArgumentsData,
                           size_t ArgumentsSize,
@@ -725,7 +994,7 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
                        ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
   auto &Device = Program->Image->getDevice();
 
-  std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
+  std::lock_guard<std::mutex> Lock(Program->SymbolListMutex);
 
   switch (Kind) {
   case OL_SYMBOL_KIND_KERNEL: {
@@ -746,7 +1015,7 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
     return Error::success();
   }
   case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
-    auto &Global = Program->KernelSymbols[Name];
+    auto &Global = Program->GlobalSymbols[Name];
     if (!Global) {
       GlobalTy GlobalObj{Name};
       if (auto Res =
@@ -814,5 +1083,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
   return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
 }
 
+Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
+                                ol_host_function_cb_t Callback,
+                                void *UserData) {
+  return Queue->Device->Device->enqueueHostCall(Callback, UserData,
+                                                Queue->AsyncInfo);
+}
+
 } // namespace offload
 } // namespace llvm
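Editor's note: the snippets below are illustrative usage sketches, not part of the patch. They assume the public liboffload entry points (olCreateQueue, olDestroyQueue, olMemFill, and so on) mirror the *_impl signatures shown in the diff, that they are declared in OffloadAPI.h, and that a null ol_result_t return means success; every helper name is hypothetical.

The first sketch exercises the queue pooling introduced here: olDestroyQueue parks a still-busy underlying __tgt_async_info on the device's OutstandingQueues list instead of blocking, and a later olCreateQueue recycles it once it has drained.

#include <OffloadAPI.h>
#include <cstdint>

// Create many short-lived queues that each enqueue an asynchronous fill.
static ol_result_t fillInBursts(ol_device_handle_t Device, void *Ptr,
                                size_t Bytes) {
  const uint32_t Pattern = 0xDEADBEEF; // 4-byte fill pattern
  for (int I = 0; I < 100; ++I) {
    ol_queue_handle_t Queue;
    if (auto Res = olCreateQueue(Device, &Queue))
      return Res; // may hand back a recycled, already-drained queue
    if (auto Res = olMemFill(Queue, Ptr, sizeof(Pattern), &Pattern, Bytes))
      return Res; // asynchronous; Bytes is assumed to be a pattern multiple
    // The fill may still be in flight; the queue's stream is pooled for
    // reuse rather than leaked or torn down.
    if (auto Res = olDestroyQueue(Queue))
      return Res;
  }
  return nullptr; // success, assuming a null ol_result_t means OL_SUCCESS
}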
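A second sketch, under the same assumptions, touches the always-complete events, the new OL_EVENT_INFO_IS_COMPLETE query, and olLaunchHostFunction; ol_host_function_cb_t is assumed to be a void(void *) callback and olGetEventInfo is assumed to take (event, property, size, value).

#include <OffloadAPI.h>
#include <cstdio>

// Runs on a runtime-managed thread once all prior work on the queue is done.
static void NotifyDone(void *UserData) {
  std::printf("%s\n", static_cast<const char *>(UserData));
}

static ol_result_t watchQueue(ol_queue_handle_t Queue) {
  static char Msg[] = "queue drained";
  if (auto Res = olLaunchHostFunction(Queue, NotifyDone, Msg))
    return Res;

  // An event recorded on an already-empty queue is created "always complete"
  // (its EventInfo stays null), so this is cheap in the common case.
  ol_event_handle_t Event;
  if (auto Res = olCreateEvent(Queue, &Event))
    return Res;

  bool Complete = false;
  do { // busy-poll the new OL_EVENT_INFO_IS_COMPLETE property
    if (auto Res = olGetEventInfo(Event, OL_EVENT_INFO_IS_COMPLETE,
                                  sizeof(Complete), &Complete))
      return Res;
  } while (!Complete);

  return olDestroyEvent(Event);
}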
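A minimal sketch of the new occupancy query, assuming the public olCalculateOptimalOccupancy mirrors olCalculateOptimalOccupancy_impl above and that Kernel was obtained via olGetSymbol with OL_SYMBOL_KIND_KERNEL; feeding the returned group size into a launch is left out because the launch-size struct's fields are not shown in this diff.

#include <OffloadAPI.h>

static ol_result_t pickGroupSize(ol_device_handle_t Device,
                                 ol_symbol_handle_t Kernel,
                                 size_t DynamicSharedMemBytes,
                                 size_t *GroupSizeOut) {
  // Fails with an ErrorCode::SYMBOL_KIND error if Kernel is not a kernel
  // symbol (e.g. a global variable handle).
  return olCalculateOptimalOccupancy(Device, Kernel, DynamicSharedMemBytes,
                                     GroupSizeOut);
}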
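Finally, the new device-info enums can be read like the existing ones; this assumes the public olGetDeviceInfo takes (device, property, size, value). Note that the patch reports the plugin's total device memory for both properties on GPU devices, and 0 for the host device.

#include <OffloadAPI.h>
#include <cstdint>
#include <cstdio>

static ol_result_t printMemSizes(ol_device_handle_t Device) {
  uint64_t GlobalMem = 0, MaxAlloc = 0;
  if (auto Res = olGetDeviceInfo(Device, OL_DEVICE_INFO_GLOBAL_MEM_SIZE,
                                 sizeof(GlobalMem), &GlobalMem))
    return Res;
  if (auto Res = olGetDeviceInfo(Device, OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE,
                                 sizeof(MaxAlloc), &MaxAlloc))
    return Res;
  std::printf("global: %llu bytes, max alloc: %llu bytes\n",
              (unsigned long long)GlobalMem, (unsigned long long)MaxAlloc);
  return nullptr; // success
}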