diff options
Diffstat (limited to 'offload/plugins-nextgen')
| -rw-r--r-- | offload/plugins-nextgen/amdgpu/src/rtl.cpp | 14 | ||||
| -rw-r--r-- | offload/plugins-nextgen/common/include/PluginInterface.h | 6 | ||||
| -rw-r--r-- | offload/plugins-nextgen/common/src/PluginInterface.cpp | 23 | ||||
| -rw-r--r-- | offload/plugins-nextgen/cuda/src/rtl.cpp | 15 | ||||
| -rw-r--r-- | offload/plugins-nextgen/host/src/rtl.cpp | 5 | ||||
| -rw-r--r-- | offload/plugins-nextgen/level_zero/include/L0Device.h | 3 | ||||
| -rw-r--r-- | offload/plugins-nextgen/level_zero/src/L0Device.cpp | 21 |
7 files changed, 61 insertions, 26 deletions
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index 287bb14..379c8ec 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -2431,7 +2431,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { } /// Query for the completion of the pending operations on the async info. - Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override { + Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue, + bool *IsQueueWorkCompleted) override { + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = false; AMDGPUStreamTy *Stream = reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue); assert(Stream && "Invalid stream"); @@ -2444,11 +2447,16 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { if (!(*CompletedOrErr)) return Plugin::success(); + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = true; // Once the stream is completed, return it to stream pool and reset // AsyncInfo. This is to make sure the synchronization only works for its // own tasks. - AsyncInfo.Queue = nullptr; - return AMDGPUStreamManager.returnResource(Stream); + if (ReleaseQueue) { + AsyncInfo.Queue = nullptr; + return AMDGPUStreamManager.returnResource(Stream); + } + return Plugin::success(); } /// Pin the host buffer and return the device pointer that should be used for diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index caf86a9..19db44c 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -854,8 +854,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy { /// Query for the completion of the pending operations on the __tgt_async_info /// structure in a non-blocking manner. - Error queryAsync(__tgt_async_info *AsyncInfo); - virtual Error queryAsyncImpl(__tgt_async_info &AsyncInfo) = 0; + Error queryAsync(__tgt_async_info *AsyncInfo, bool ReleaseQueue = true, + bool *IsQueueWorkCompleted = nullptr); + virtual Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue, + bool *IsQueueWorkCompleted) = 0; /// Check whether the architecture supports VA management virtual bool supportVAManagement() const { return false; } diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 4ec8366..807df0f 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -849,7 +849,8 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) { } Expected<DeviceImageTy *> GenericDeviceTy::loadBinary(GenericPluginTy &Plugin, StringRef InputTgtImage) { - ODBG(OLDT_Init) << "Load data from image " << InputTgtImage.bytes_begin(); + ODBG(OLDT_Init) << "Load data from image " + << static_cast<const void *>(InputTgtImage.bytes_begin()); std::unique_ptr<MemoryBuffer> Buffer; if (identify_magic(InputTgtImage) == file_magic::bitcode) { @@ -1198,12 +1199,14 @@ Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo, return Plugin::success(); } -Error GenericDeviceTy::queryAsync(__tgt_async_info *AsyncInfo) { +Error GenericDeviceTy::queryAsync(__tgt_async_info *AsyncInfo, + bool ReleaseQueue, + bool *IsQueueWorkCompleted) { if (!AsyncInfo || !AsyncInfo->Queue) return Plugin::error(ErrorCode::INVALID_ARGUMENT, "invalid async info queue"); - return queryAsyncImpl(*AsyncInfo); + return queryAsyncImpl(*AsyncInfo, ReleaseQueue, IsQueueWorkCompleted); } Error GenericDeviceTy::memoryVAMap(void **Addr, void *VAddr, size_t *RSize) { @@ -1656,9 +1659,10 @@ int32_t GenericPluginTy::is_initialized() const { return Initialized; } int32_t GenericPluginTy::isPluginCompatible(StringRef Image) { auto HandleError = [&](Error Err) -> bool { - [[maybe_unused]] std::string ErrStr = toString(std::move(Err)); - ODBG(OLDT_Init) << "Failure to check validity of image " << Image.data() - << ": " << ErrStr; + std::string ErrStr = toString(std::move(Err)); + ODBG(OLDT_Init) << "Failure to check validity of image " + << static_cast<const void *>(Image.data()) << ": " + << ErrStr; return false; }; switch (identify_magic(Image)) { @@ -1685,8 +1689,9 @@ int32_t GenericPluginTy::isPluginCompatible(StringRef Image) { int32_t GenericPluginTy::isDeviceCompatible(int32_t DeviceId, StringRef Image) { auto HandleError = [&](Error Err) -> bool { - [[maybe_unused]] std::string ErrStr = toString(std::move(Err)); - ODBG(OLDT_Init) << "Failure to check validity of image " << Image << ": " + std::string ErrStr = toString(std::move(Err)); + ODBG(OLDT_Init) << "Failure to check validity of image " + << static_cast<const void *>(Image.data()) << ": " << ErrStr; return false; }; @@ -2069,7 +2074,7 @@ int32_t GenericPluginTy::use_auto_zero_copy(int32_t DeviceId) { int32_t GenericPluginTy::is_accessible_ptr(int32_t DeviceId, const void *Ptr, size_t Size) { auto HandleError = [&](Error Err) -> bool { - [[maybe_unused]] std::string ErrStr = toString(std::move(Err)); + std::string ErrStr = toString(std::move(Err)); ODBG(OLDT_Device) << "Failure while checking accessibility of pointer " << Ptr << " for device " << DeviceId << ": " << ErrStr; return false; diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp index 621c90e..d5ab0b3 100644 --- a/offload/plugins-nextgen/cuda/src/rtl.cpp +++ b/offload/plugins-nextgen/cuda/src/rtl.cpp @@ -793,7 +793,10 @@ struct CUDADeviceTy : public GenericDeviceTy { } /// Query for the completion of the pending operations on the async info. - Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override { + Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue, + bool *IsQueueWorkCompleted) override { + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = false; CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue); CUresult Res = cuStreamQuery(Stream); @@ -801,12 +804,16 @@ struct CUDADeviceTy : public GenericDeviceTy { if (Res == CUDA_ERROR_NOT_READY) return Plugin::success(); + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = true; // Once the stream is synchronized and the operations completed (or an error // occurs), return it to stream pool and reset AsyncInfo. This is to make // sure the synchronization only works for its own tasks. - AsyncInfo.Queue = nullptr; - if (auto Err = CUDAStreamManager.returnResource(Stream)) - return Err; + if (ReleaseQueue) { + AsyncInfo.Queue = nullptr; + if (auto Err = CUDAStreamManager.returnResource(Stream)) + return Err; + } return Plugin::check(Res, "error in cuStreamQuery: %s"); } diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp index 81fbb67..6033796 100644 --- a/offload/plugins-nextgen/host/src/rtl.cpp +++ b/offload/plugins-nextgen/host/src/rtl.cpp @@ -336,7 +336,10 @@ struct GenELF64DeviceTy : public GenericDeviceTy { /// All functions are already synchronous. No need to do anything on this /// query function. - Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override { + Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue, + bool *IsQueueWorkCompleted) override { + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = true; return Plugin::success(); } diff --git a/offload/plugins-nextgen/level_zero/include/L0Device.h b/offload/plugins-nextgen/level_zero/include/L0Device.h index d14e710..001a41b 100644 --- a/offload/plugins-nextgen/level_zero/include/L0Device.h +++ b/offload/plugins-nextgen/level_zero/include/L0Device.h @@ -576,7 +576,8 @@ public: AsyncInfoWrapperTy &AsyncInfoWrapper) override; Error synchronizeImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue) override; - Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override; + Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue, + bool *IsQueueWorkCompleted) override; Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size, AsyncInfoWrapperTy &AsyncInfoWrapper) override; Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size, diff --git a/offload/plugins-nextgen/level_zero/src/L0Device.cpp b/offload/plugins-nextgen/level_zero/src/L0Device.cpp index 2cae1e4..4db3c4e 100644 --- a/offload/plugins-nextgen/level_zero/src/L0Device.cpp +++ b/offload/plugins-nextgen/level_zero/src/L0Device.cpp @@ -192,8 +192,7 @@ Error L0DeviceTy::initImpl(GenericPluginTy &Plugin) { CALL_ZE_RET_ERROR(zeDeviceGetCacheProperties, zeDevice, &Count, &CacheProperties); - DeviceName = - std::string(DeviceProperties.name, sizeof(DeviceProperties.name)); + DeviceName = std::string(DeviceProperties.name); ODBG(OLDT_Device) << "Found a GPU device, Name = " << DeviceProperties.name; @@ -356,10 +355,15 @@ L0DeviceTy::hasPendingWorkImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) { return true; } -Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo) { +Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue, + bool *IsQueueWorkCompleted) { + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = true; const bool IsAsync = AsyncInfo.Queue && asyncEnabled(); if (!IsAsync) return Plugin::success(); + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = false; auto &Plugin = getPlugin(); auto *AsyncQueue = static_cast<AsyncQueueTy *>(AsyncInfo.Queue); @@ -367,6 +371,9 @@ Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo) { if (!AsyncQueue->WaitEvents.empty()) return Plugin::success(); + if (IsQueueWorkCompleted) + *IsQueueWorkCompleted = true; + // Commit delayed USM2M copies. for (auto &USM2M : AsyncQueue->USM2MList) { std::copy_n(static_cast<const char *>(std::get<0>(USM2M)), @@ -377,9 +384,11 @@ Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo) { std::copy_n(static_cast<char *>(std::get<0>(H2M)), std::get<2>(H2M), static_cast<char *>(std::get<1>(H2M))); } - Plugin.releaseAsyncQueue(AsyncQueue); - getStagingBuffer().reset(); - AsyncInfo.Queue = nullptr; + if (ReleaseQueue) { + Plugin.releaseAsyncQueue(AsyncQueue); + getStagingBuffer().reset(); + AsyncInfo.Queue = nullptr; + } return Plugin::success(); } |
