aboutsummaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen
diff options
context:
space:
mode:
Diffstat (limited to 'offload/plugins-nextgen')
-rw-r--r--offload/plugins-nextgen/amdgpu/src/rtl.cpp14
-rw-r--r--offload/plugins-nextgen/common/include/PluginInterface.h6
-rw-r--r--offload/plugins-nextgen/common/src/PluginInterface.cpp23
-rw-r--r--offload/plugins-nextgen/cuda/src/rtl.cpp15
-rw-r--r--offload/plugins-nextgen/host/src/rtl.cpp5
-rw-r--r--offload/plugins-nextgen/level_zero/include/L0Device.h3
-rw-r--r--offload/plugins-nextgen/level_zero/src/L0Device.cpp21
7 files changed, 61 insertions, 26 deletions
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 287bb14..379c8ec 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -2431,7 +2431,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
}
/// Query for the completion of the pending operations on the async info.
- Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) override {
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = false;
AMDGPUStreamTy *Stream =
reinterpret_cast<AMDGPUStreamTy *>(AsyncInfo.Queue);
assert(Stream && "Invalid stream");
@@ -2444,11 +2447,16 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
if (!(*CompletedOrErr))
return Plugin::success();
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = true;
// Once the stream is completed, return it to stream pool and reset
// AsyncInfo. This is to make sure the synchronization only works for its
// own tasks.
- AsyncInfo.Queue = nullptr;
- return AMDGPUStreamManager.returnResource(Stream);
+ if (ReleaseQueue) {
+ AsyncInfo.Queue = nullptr;
+ return AMDGPUStreamManager.returnResource(Stream);
+ }
+ return Plugin::success();
}
/// Pin the host buffer and return the device pointer that should be used for
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index caf86a9..19db44c 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -854,8 +854,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
/// Query for the completion of the pending operations on the __tgt_async_info
/// structure in a non-blocking manner.
- Error queryAsync(__tgt_async_info *AsyncInfo);
- virtual Error queryAsyncImpl(__tgt_async_info &AsyncInfo) = 0;
+ Error queryAsync(__tgt_async_info *AsyncInfo, bool ReleaseQueue = true,
+ bool *IsQueueWorkCompleted = nullptr);
+ virtual Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) = 0;
/// Check whether the architecture supports VA management
virtual bool supportVAManagement() const { return false; }
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 4ec8366..807df0f 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -849,7 +849,8 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
}
Expected<DeviceImageTy *> GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
StringRef InputTgtImage) {
- ODBG(OLDT_Init) << "Load data from image " << InputTgtImage.bytes_begin();
+ ODBG(OLDT_Init) << "Load data from image "
+ << static_cast<const void *>(InputTgtImage.bytes_begin());
std::unique_ptr<MemoryBuffer> Buffer;
if (identify_magic(InputTgtImage) == file_magic::bitcode) {
@@ -1198,12 +1199,14 @@ Error GenericDeviceTy::synchronize(__tgt_async_info *AsyncInfo,
return Plugin::success();
}
-Error GenericDeviceTy::queryAsync(__tgt_async_info *AsyncInfo) {
+Error GenericDeviceTy::queryAsync(__tgt_async_info *AsyncInfo,
+ bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) {
if (!AsyncInfo || !AsyncInfo->Queue)
return Plugin::error(ErrorCode::INVALID_ARGUMENT,
"invalid async info queue");
- return queryAsyncImpl(*AsyncInfo);
+ return queryAsyncImpl(*AsyncInfo, ReleaseQueue, IsQueueWorkCompleted);
}
Error GenericDeviceTy::memoryVAMap(void **Addr, void *VAddr, size_t *RSize) {
@@ -1656,9 +1659,10 @@ int32_t GenericPluginTy::is_initialized() const { return Initialized; }
int32_t GenericPluginTy::isPluginCompatible(StringRef Image) {
auto HandleError = [&](Error Err) -> bool {
- [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
- ODBG(OLDT_Init) << "Failure to check validity of image " << Image.data()
- << ": " << ErrStr;
+ std::string ErrStr = toString(std::move(Err));
+ ODBG(OLDT_Init) << "Failure to check validity of image "
+ << static_cast<const void *>(Image.data()) << ": "
+ << ErrStr;
return false;
};
switch (identify_magic(Image)) {
@@ -1685,8 +1689,9 @@ int32_t GenericPluginTy::isPluginCompatible(StringRef Image) {
int32_t GenericPluginTy::isDeviceCompatible(int32_t DeviceId, StringRef Image) {
auto HandleError = [&](Error Err) -> bool {
- [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
- ODBG(OLDT_Init) << "Failure to check validity of image " << Image << ": "
+ std::string ErrStr = toString(std::move(Err));
+ ODBG(OLDT_Init) << "Failure to check validity of image "
+ << static_cast<const void *>(Image.data()) << ": "
<< ErrStr;
return false;
};
@@ -2069,7 +2074,7 @@ int32_t GenericPluginTy::use_auto_zero_copy(int32_t DeviceId) {
int32_t GenericPluginTy::is_accessible_ptr(int32_t DeviceId, const void *Ptr,
size_t Size) {
auto HandleError = [&](Error Err) -> bool {
- [[maybe_unused]] std::string ErrStr = toString(std::move(Err));
+ std::string ErrStr = toString(std::move(Err));
ODBG(OLDT_Device) << "Failure while checking accessibility of pointer "
<< Ptr << " for device " << DeviceId << ": " << ErrStr;
return false;
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 621c90e..d5ab0b3 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -793,7 +793,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
}
/// Query for the completion of the pending operations on the async info.
- Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) override {
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = false;
CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
CUresult Res = cuStreamQuery(Stream);
@@ -801,12 +804,16 @@ struct CUDADeviceTy : public GenericDeviceTy {
if (Res == CUDA_ERROR_NOT_READY)
return Plugin::success();
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = true;
// Once the stream is synchronized and the operations completed (or an error
// occurs), return it to stream pool and reset AsyncInfo. This is to make
// sure the synchronization only works for its own tasks.
- AsyncInfo.Queue = nullptr;
- if (auto Err = CUDAStreamManager.returnResource(Stream))
- return Err;
+ if (ReleaseQueue) {
+ AsyncInfo.Queue = nullptr;
+ if (auto Err = CUDAStreamManager.returnResource(Stream))
+ return Err;
+ }
return Plugin::check(Res, "error in cuStreamQuery: %s");
}
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index 81fbb67..6033796 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -336,7 +336,10 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
/// All functions are already synchronous. No need to do anything on this
/// query function.
- Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) override {
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = true;
return Plugin::success();
}
diff --git a/offload/plugins-nextgen/level_zero/include/L0Device.h b/offload/plugins-nextgen/level_zero/include/L0Device.h
index d14e710..001a41b 100644
--- a/offload/plugins-nextgen/level_zero/include/L0Device.h
+++ b/offload/plugins-nextgen/level_zero/include/L0Device.h
@@ -576,7 +576,8 @@ public:
AsyncInfoWrapperTy &AsyncInfoWrapper) override;
Error synchronizeImpl(__tgt_async_info &AsyncInfo,
bool ReleaseQueue) override;
- Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override;
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) override;
Error dataSubmitImpl(void *TgtPtr, const void *HstPtr, int64_t Size,
AsyncInfoWrapperTy &AsyncInfoWrapper) override;
Error dataRetrieveImpl(void *HstPtr, const void *TgtPtr, int64_t Size,
diff --git a/offload/plugins-nextgen/level_zero/src/L0Device.cpp b/offload/plugins-nextgen/level_zero/src/L0Device.cpp
index 2cae1e4..4db3c4e 100644
--- a/offload/plugins-nextgen/level_zero/src/L0Device.cpp
+++ b/offload/plugins-nextgen/level_zero/src/L0Device.cpp
@@ -192,8 +192,7 @@ Error L0DeviceTy::initImpl(GenericPluginTy &Plugin) {
CALL_ZE_RET_ERROR(zeDeviceGetCacheProperties, zeDevice, &Count,
&CacheProperties);
- DeviceName =
- std::string(DeviceProperties.name, sizeof(DeviceProperties.name));
+ DeviceName = std::string(DeviceProperties.name);
ODBG(OLDT_Device) << "Found a GPU device, Name = " << DeviceProperties.name;
@@ -356,10 +355,15 @@ L0DeviceTy::hasPendingWorkImpl(AsyncInfoWrapperTy &AsyncInfoWrapper) {
return true;
}
-Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo) {
+Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) {
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = true;
const bool IsAsync = AsyncInfo.Queue && asyncEnabled();
if (!IsAsync)
return Plugin::success();
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = false;
auto &Plugin = getPlugin();
auto *AsyncQueue = static_cast<AsyncQueueTy *>(AsyncInfo.Queue);
@@ -367,6 +371,9 @@ Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo) {
if (!AsyncQueue->WaitEvents.empty())
return Plugin::success();
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = true;
+
// Commit delayed USM2M copies.
for (auto &USM2M : AsyncQueue->USM2MList) {
std::copy_n(static_cast<const char *>(std::get<0>(USM2M)),
@@ -377,9 +384,11 @@ Error L0DeviceTy::queryAsyncImpl(__tgt_async_info &AsyncInfo) {
std::copy_n(static_cast<char *>(std::get<0>(H2M)), std::get<2>(H2M),
static_cast<char *>(std::get<1>(H2M)));
}
- Plugin.releaseAsyncQueue(AsyncQueue);
- getStagingBuffer().reset();
- AsyncInfo.Queue = nullptr;
+ if (ReleaseQueue) {
+ Plugin.releaseAsyncQueue(AsyncQueue);
+ getStagingBuffer().reset();
+ AsyncInfo.Queue = nullptr;
+ }
return Plugin::success();
}