aboutsummaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/cuda/src/rtl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'offload/plugins-nextgen/cuda/src/rtl.cpp')
-rw-r--r--offload/plugins-nextgen/cuda/src/rtl.cpp32
1 files changed, 15 insertions, 17 deletions
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 7649fd9..f3f3783 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -522,16 +522,11 @@ struct CUDADeviceTy : public GenericDeviceTy {
/// Get the stream of the asynchronous info structure or get a new one.
Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
- // Get the stream (if any) from the async info.
- Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
- if (!Stream) {
- // There was no stream; get an idle one.
- if (auto Err = CUDAStreamManager.getResource(Stream))
- return Err;
-
- // Modify the async info's stream.
- AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
- }
+ auto WrapperStream =
+ AsyncInfoWrapper.getOrInitQueue<CUstream>(CUDAStreamManager);
+ if (!WrapperStream)
+ return WrapperStream.takeError();
+ Stream = *WrapperStream;
return Plugin::success();
}
@@ -642,17 +637,20 @@ struct CUDADeviceTy : public GenericDeviceTy {
}
/// Synchronize current thread with the pending operations on the async info.
- Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+ Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+ bool ReleaseQueue) override {
CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
CUresult Res;
Res = cuStreamSynchronize(Stream);
- // Once the stream is synchronized, return it to stream pool and reset
- // AsyncInfo. This is to make sure the synchronization only works for its
- // own tasks.
- AsyncInfo.Queue = nullptr;
- if (auto Err = CUDAStreamManager.returnResource(Stream))
- return Err;
+ // Once the stream is synchronized and we want to release the queue, return
+ // it to stream pool and reset AsyncInfo. This is to make sure the
+ // synchronization only works for its own tasks.
+ if (ReleaseQueue) {
+ AsyncInfo.Queue = nullptr;
+ if (auto Err = CUDAStreamManager.returnResource(Stream))
+ return Err;
+ }
return Plugin::check(Res, "error in cuStreamSynchronize: %s");
}