diff options
Diffstat (limited to 'offload/plugins-nextgen/amdgpu/src/rtl.cpp')
-rw-r--r-- | offload/plugins-nextgen/amdgpu/src/rtl.cpp | 50 |
1 files changed, 32 insertions, 18 deletions
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index a7723b8..20d16fa 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -923,6 +923,10 @@ private: /// devices. This class relies on signals to implement streams and define the /// dependencies between asynchronous operations. struct AMDGPUStreamTy { +public: + /// Function pointer type for `pushHostCallback` + using HostFnType = void (*)(void *); + private: /// Utility struct holding arguments for async H2H memory copies. struct MemcpyArgsTy { @@ -1084,18 +1088,19 @@ private: /// Indicate to spread data transfers across all available SDMAs bool UseMultipleSdmaEngines; + struct CallbackDataType { + HostFnType UserFn; + void *UserData; + AMDGPUSignalTy *OutputSignal; + }; /// Wrapper function for implementing host callbacks - static void CallbackWrapper(AMDGPUSignalTy *InputSignal, - AMDGPUSignalTy *OutputSignal, - void (*Callback)(void *), void *UserData) { - // The wait call will not error in this context. - if (InputSignal) - if (auto Err = InputSignal->wait()) - reportFatalInternalError(std::move(Err)); - - Callback(UserData); - - OutputSignal->signal(); + static bool callbackWrapper([[maybe_unused]] hsa_signal_value_t Signal, + void *UserData) { + auto CallbackData = reinterpret_cast<CallbackDataType *>(UserData); + CallbackData->UserFn(CallbackData->UserData); + CallbackData->OutputSignal->signal(); + delete CallbackData; + return false; } /// Return the current number of asynchronous operations on the stream. @@ -1540,7 +1545,7 @@ public: OutputSignal->get()); } - Error pushHostCallback(void (*Callback)(void *), void *UserData) { + Error pushHostCallback(HostFnType Callback, void *UserData) { // Retrieve an available signal for the operation's output. AMDGPUSignalTy *OutputSignal = nullptr; if (auto Err = SignalManager.getResource(OutputSignal)) @@ -1556,12 +1561,21 @@ public: InputSignal = consume(OutputSignal).second; } - // "Leaking" the thread here is consistent with other work added to the - // queue. The input and output signals will remain valid until the output is - // signaled. - std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData) - .detach(); + auto *CallbackData = new CallbackDataType{Callback, UserData, OutputSignal}; + if (InputSignal && InputSignal->load()) { + hsa_status_t Status = hsa_amd_signal_async_handler( + InputSignal->get(), HSA_SIGNAL_CONDITION_EQ, 0, callbackWrapper, + CallbackData); + return Plugin::check(Status, "error in hsa_amd_signal_async_handler: %s"); + } + + // No dependencies - schedule it now. + // Using a seperate thread because this function should run asynchronously + // and not block the main thread. + std::thread([](void *CallbackData) { callbackWrapper(0, CallbackData); }, + CallbackData) + .detach(); return Plugin::success(); } @@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { return Plugin::success(); } - Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData, + Error enqueueHostCallImpl(AMDGPUStreamTy::HostFnType Callback, void *UserData, AsyncInfoWrapperTy &AsyncInfo) override { AMDGPUStreamTy *Stream = nullptr; if (auto Err = getStream(AsyncInfo, Stream)) |