about | summary | refs | log | tree | commit | diff
path: root/offload/plugins-nextgen/cuda
diff options
context:
space:
mode:
Diffstat (limited to 'offload/plugins-nextgen/cuda')
-rw-r--r-- offload/plugins-nextgen/cuda/src/rtl.cpp | 15
1 files changed, 11 insertions, 4 deletions
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 621c90e..d5ab0b3 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -793,7 +793,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
}
/// Query for the completion of the pending operations on the async info.
- Error queryAsyncImpl(__tgt_async_info &AsyncInfo) override {
+ Error queryAsyncImpl(__tgt_async_info &AsyncInfo, bool ReleaseQueue,
+ bool *IsQueueWorkCompleted) override {
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = false;
CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
CUresult Res = cuStreamQuery(Stream);
@@ -801,12 +804,16 @@ struct CUDADeviceTy : public GenericDeviceTy {
if (Res == CUDA_ERROR_NOT_READY)
return Plugin::success();
+ if (IsQueueWorkCompleted)
+ *IsQueueWorkCompleted = true;
// Once the stream is synchronized and the operations completed (or an error
// occurs), return it to stream pool and reset AsyncInfo. This is to make
// sure the synchronization only works for its own tasks.
- AsyncInfo.Queue = nullptr;
- if (auto Err = CUDAStreamManager.returnResource(Stream))
- return Err;
+ if (ReleaseQueue) {
+ AsyncInfo.Queue = nullptr;
+ if (auto Err = CUDAStreamManager.returnResource(Stream))
+ return Err;
+ }
return Plugin::check(Res, "error in cuStreamQuery: %s");
}