diff options
Diffstat (limited to 'offload/plugins-nextgen/cuda/src/rtl.cpp')
-rw-r--r-- | offload/plugins-nextgen/cuda/src/rtl.cpp | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 7649fd9..f3f3783 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -522,16 +522,11 @@ struct CUDADeviceTy : public GenericDeviceTy {
 
   /// Get the stream of the asynchronous info structure or get a new one.
   Error getStream(AsyncInfoWrapperTy &AsyncInfoWrapper, CUstream &Stream) {
-    // Get the stream (if any) from the async info.
-    Stream = AsyncInfoWrapper.getQueueAs<CUstream>();
-    if (!Stream) {
-      // There was no stream; get an idle one.
-      if (auto Err = CUDAStreamManager.getResource(Stream))
-        return Err;
-
-      // Modify the async info's stream.
-      AsyncInfoWrapper.setQueueAs<CUstream>(Stream);
-    }
+    auto WrapperStream =
+        AsyncInfoWrapper.getOrInitQueue<CUstream>(CUDAStreamManager);
+    if (!WrapperStream)
+      return WrapperStream.takeError();
+    Stream = *WrapperStream;
     return Plugin::success();
   }
 
@@ -642,17 +637,20 @@ struct CUDADeviceTy : public GenericDeviceTy {
   }
 
   /// Synchronize current thread with the pending operations on the async info.
-  Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
+  Error synchronizeImpl(__tgt_async_info &AsyncInfo,
+                        bool ReleaseQueue) override {
     CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
     CUresult Res;
     Res = cuStreamSynchronize(Stream);
 
-    // Once the stream is synchronized, return it to stream pool and reset
-    // AsyncInfo. This is to make sure the synchronization only works for its
-    // own tasks.
-    AsyncInfo.Queue = nullptr;
-    if (auto Err = CUDAStreamManager.returnResource(Stream))
-      return Err;
+    // Once the stream is synchronized and we want to release the queue, return
+    // it to stream pool and reset AsyncInfo. This is to make sure the
+    // synchronization only works for its own tasks.
+    if (ReleaseQueue) {
+      AsyncInfo.Queue = nullptr;
+      if (auto Err = CUDAStreamManager.returnResource(Stream))
+        return Err;
+    }
 
     return Plugin::check(Res, "error in cuStreamSynchronize: %s");
   }