aboutsummaryrefslogtreecommitdiff
path: root/offload/plugins-nextgen/amdgpu/src
diff options
context:
space:
mode:
Diffstat (limited to 'offload/plugins-nextgen/amdgpu/src')
-rw-r--r--offload/plugins-nextgen/amdgpu/src/rtl.cpp74
1 files changed, 56 insertions, 18 deletions
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index a7723b8..0b03ef5 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -923,6 +923,10 @@ private:
/// devices. This class relies on signals to implement streams and define the
/// dependencies between asynchronous operations.
struct AMDGPUStreamTy {
+public:
+ /// Function pointer type for `pushHostCallback`
+ using HostFnType = void (*)(void *);
+
private:
/// Utility struct holding arguments for async H2H memory copies.
struct MemcpyArgsTy {
@@ -1084,18 +1088,19 @@ private:
/// Indicate to spread data transfers across all available SDMAs
bool UseMultipleSdmaEngines;
+ struct CallbackDataType {
+ HostFnType UserFn;
+ void *UserData;
+ AMDGPUSignalTy *OutputSignal;
+ };
/// Wrapper function for implementing host callbacks
- static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
- AMDGPUSignalTy *OutputSignal,
- void (*Callback)(void *), void *UserData) {
- // The wait call will not error in this context.
- if (InputSignal)
- if (auto Err = InputSignal->wait())
- reportFatalInternalError(std::move(Err));
-
- Callback(UserData);
-
- OutputSignal->signal();
+ static bool callbackWrapper([[maybe_unused]] hsa_signal_value_t Signal,
+ void *UserData) {
+ auto CallbackData = reinterpret_cast<CallbackDataType *>(UserData);
+ CallbackData->UserFn(CallbackData->UserData);
+ CallbackData->OutputSignal->signal();
+ delete CallbackData;
+ return false;
}
/// Return the current number of asynchronous operations on the stream.
@@ -1540,7 +1545,7 @@ public:
OutputSignal->get());
}
- Error pushHostCallback(void (*Callback)(void *), void *UserData) {
+ Error pushHostCallback(HostFnType Callback, void *UserData) {
// Retrieve an available signal for the operation's output.
AMDGPUSignalTy *OutputSignal = nullptr;
if (auto Err = SignalManager.getResource(OutputSignal))
@@ -1556,12 +1561,21 @@ public:
InputSignal = consume(OutputSignal).second;
}
- // "Leaking" the thread here is consistent with other work added to the
- // queue. The input and output signals will remain valid until the output is
- // signaled.
- std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
- .detach();
+ auto *CallbackData = new CallbackDataType{Callback, UserData, OutputSignal};
+ if (InputSignal && InputSignal->load()) {
+ hsa_status_t Status = hsa_amd_signal_async_handler(
+ InputSignal->get(), HSA_SIGNAL_CONDITION_EQ, 0, callbackWrapper,
+ CallbackData);
+ return Plugin::check(Status, "error in hsa_amd_signal_async_handler: %s");
+ }
+
+ // No dependencies - schedule it now.
+ // Using a seperate thread because this function should run asynchronously
+ // and not block the main thread.
+ std::thread([](void *CallbackData) { callbackWrapper(0, CallbackData); },
+ CallbackData)
+ .detach();
return Plugin::success();
}
@@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}
- Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
+ Error enqueueHostCallImpl(AMDGPUStreamTy::HostFnType Callback, void *UserData,
AsyncInfoWrapperTy &AsyncInfo) override {
AMDGPUStreamTy *Stream = nullptr;
if (auto Err = getStream(AsyncInfo, Stream))
@@ -3048,6 +3062,30 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return ((IsAPU || OMPX_ApuMaps) && IsXnackEnabled);
}
+ Expected<bool> isAccessiblePtrImpl(const void *Ptr, size_t Size) override {
+ hsa_amd_pointer_info_t Info;
+ Info.size = sizeof(hsa_amd_pointer_info_t);
+
+ hsa_agent_t *Agents = nullptr;
+ uint32_t Count = 0;
+ hsa_status_t Status =
+ hsa_amd_pointer_info(Ptr, &Info, malloc, &Count, &Agents);
+
+ if (auto Err = Plugin::check(Status, "error in hsa_amd_pointer_info: %s"))
+ return std::move(Err);
+
+ // Checks if the pointer is known by HSA and accessible by the device
+ for (uint32_t i = 0; i < Count; i++) {
+ if (Agents[i].handle == getAgent().handle)
+ return Info.sizeInBytes >= Size;
+ }
+
+ // If the pointer is unknown to HSA it's assumed a host pointer
+ // in that case the device can access it on unified memory support is
+ // enabled
+ return IsXnackEnabled;
+ }
+
/// Getters and setters for stack and heap sizes.
Error getDeviceStackSize(uint64_t &Value) override {
Value = StackSize;