diff options
author | Joseph Huber <jhuber6@vols.utk.edu> | 2023-07-11 09:27:22 -0500 |
---|---|---|
committer | Joseph Huber <jhuber6@vols.utk.edu> | 2023-07-11 10:54:40 -0500 |
commit | 8a0763f19ca543a00bb9e2c3b35279d3eb6339ff (patch) | |
tree | 79d092e522226d803fc43f46369ac357305f3901 /openmp | |
parent | 14742f2a689c825adebc54cbade9c89fbe426da8 (diff) | |
download | llvm-8a0763f19ca543a00bb9e2c3b35279d3eb6339ff.zip llvm-8a0763f19ca543a00bb9e2c3b35279d3eb6339ff.tar.gz llvm-8a0763f19ca543a00bb9e2c3b35279d3eb6339ff.tar.bz2 |
[Libomptarget] Remove RPCHandleTy indirection
The 'RPCHandleTy' was intended to capture the intention that a specific
device owns its slot in the RPC server. However, this required creating
a temporary store to hold these pointers. This was causing really weird
spurious failure due to undefined behaviour in the order of library
teardown. For example, the x64 plugin would be torn down, set this to
some invalid memory, and then the CUDA plugin would crash. Rather than
spend the time to fully diagnose this problem I found it pertinent to
simply remove the failure mode.
This patch removes this indirection so now the usage of the RPC server
must always be done with the intended device. This just requires some
extra handling for the AMDGPU indirection where we need to store a
reference to the device.
Reviewed By: JonChesterfield
Differential Revision: https://reviews.llvm.org/D154971
Diffstat (limited to 'openmp')
7 files changed, 27 insertions, 70 deletions
diff --git a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp index d56d3c9..fbe5012 100644 --- a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp @@ -520,9 +520,9 @@ struct AMDGPUSignalTy { } /// Wait until the signal gets a zero value. - Error wait(const uint64_t ActiveTimeout = 0, - RPCHandleTy *RPCHandle = nullptr) const { - if (ActiveTimeout && !RPCHandle) { + Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr, + GenericDeviceTy *Device = nullptr) const { + if (ActiveTimeout && !RPCServer) { hsa_signal_value_t Got = 1; Got = hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0, ActiveTimeout, HSA_WAIT_STATE_ACTIVE); @@ -531,12 +531,12 @@ struct AMDGPUSignalTy { } // If there is an RPC device attached to this stream we run it as a server. - uint64_t Timeout = RPCHandle ? 8192 : UINT64_MAX; - auto WaitState = RPCHandle ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED; + uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX; + auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED; while (hsa_signal_wait_scacquire(Signal, HSA_SIGNAL_CONDITION_EQ, 0, Timeout, WaitState) != 0) { - if (RPCHandle) - if (auto Err = RPCHandle->runServer()) + if (RPCServer && Device) + if (auto Err = RPCServer->runServer(*Device)) return Err; } return Plugin::success(); @@ -888,6 +888,9 @@ private: /// The manager of signals to reuse signals. AMDGPUSignalManagerTy &SignalManager; + /// A reference to the associated device. + GenericDeviceTy &Device; + /// Array of stream slots. Use std::deque because it can dynamically grow /// without invalidating the already inserted elements. For instance, the /// std::vector may invalidate the elements by reallocating the internal @@ -907,7 +910,7 @@ private: /// A pointer associated with an RPC server running on the given device. If /// RPC is not being used this will be a null pointer. Otherwise, this /// indicates that an RPC server is expected to be run on this stream. - RPCHandleTy *RPCHandle; + RPCServerTy *RPCServer; /// Mutex to protect stream's management. mutable std::mutex Mutex; @@ -1064,8 +1067,8 @@ public: /// Deinitialize the stream's signals. Error deinit() { return Plugin::success(); } - /// Attach an RPC handle to this stream. - void setRPCHandle(RPCHandleTy *Handle) { RPCHandle = Handle; } + /// Attach an RPC server to this stream. + void setRPCServer(RPCServerTy *Server) { RPCServer = Server; } /// Push a asynchronous kernel to the stream. The kernel arguments must be /// placed in a special allocation for kernel args and must keep alive until @@ -1281,8 +1284,8 @@ public: return Plugin::success(); // Wait until all previous operations on the stream have completed. - if (auto Err = - Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, RPCHandle)) + if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, + RPCServer, &Device)) return Err; // Reset the stream and perform all pending post actions. @@ -2529,9 +2532,9 @@ Error AMDGPUResourceRef<ResourceTy>::create(GenericDeviceTy &Device) { AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device) : Agent(Device.getAgent()), Queue(Device.getNextQueue()), - SignalManager(Device.getSignalManager()), + SignalManager(Device.getSignalManager()), Device(Device), // Initialize the std::deque with some empty positions. - Slots(32), NextSlot(0), SyncCycle(0), RPCHandle(nullptr), + Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr), StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()) {} /// Class implementing the AMDGPU-specific functionalities of the global @@ -2866,8 +2869,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice, AMDGPUStreamTy &Stream = AMDGPUDevice.getStream(AsyncInfoWrapper); // If this kernel requires an RPC server we attach its pointer to the stream. - if (GenericDevice.getRPCHandle()) - Stream.setRPCHandle(GenericDevice.getRPCHandle()); + if (GenericDevice.getRPCServer()) + Stream.setRPCServer(GenericDevice.getRPCServer()); // Push the kernel launch into the stream. return Stream.pushKernelLaunch(*this, AllArgs, NumThreads, NumBlocks, diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt index 1801b0e..deecda2 100644 --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/CMakeLists.txt @@ -70,7 +70,6 @@ elseif(${LIBOMPTARGET_GPU_LIBC_SUPPORT}) find_library(llvmlibc_rpc_server NAMES llvmlibc_rpc_server PATHS ${LIBOMPTARGET_LLVM_LIBRARY_DIR} NO_DEFAULT_PATH) if(llvmlibc_rpc_server) - message(WARNING ${llvmlibc_rpc_server}) target_link_libraries(PluginInterface PRIVATE llvmlibc_rpc_server) target_compile_definitions(PluginInterface PRIVATE LIBOMPTARGET_RPC_SUPPORT) endif() diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp index 79a968c..4426032 100644 --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp @@ -401,7 +401,7 @@ GenericDeviceTy::GenericDeviceTy(int32_t DeviceId, int32_t NumDevices, OMPX_InitialNumEvents("LIBOMPTARGET_NUM_INITIAL_EVENTS", 32), DeviceId(DeviceId), GridValues(OMPGridValues), PeerAccesses(NumDevices, PeerAccessState::PENDING), PeerAccessesLock(), - PinnedAllocs(*this), RPCHandle(nullptr) { + PinnedAllocs(*this), RPCServer(nullptr) { #ifdef OMPT_SUPPORT OmptInitialized.store(false); // Bind the callbacks to this device's member functions @@ -483,8 +483,8 @@ Error GenericDeviceTy::deinit() { if (RecordReplay.isRecordingOrReplaying()) RecordReplay.deinit(); - if (RPCHandle) - if (auto Err = RPCHandle->deinitDevice()) + if (RPCServer) + if (auto Err = RPCServer->deinitDevice(*this)) return Err; #ifdef OMPT_SUPPORT @@ -599,10 +599,7 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin, if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image)) return Err; - auto DeviceOrErr = Server.getDevice(*this); - if (!DeviceOrErr) - return DeviceOrErr.takeError(); - RPCHandle = *DeviceOrErr; + RPCServer = &Server; DP("Running an RPC server on device %d\n", getDeviceId()); return Plugin::success(); } diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h index ea49234..923bf96 100644 --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h @@ -762,7 +762,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy { } /// Get the RPC server running on this device. - RPCHandleTy *getRPCHandle() const { return RPCHandle; } + RPCServerTy *getRPCServer() const { return RPCServer; } private: /// Register offload entry for global variable. @@ -857,7 +857,7 @@ protected: /// A pointer to an RPC server instance attached to this device if present. /// This is used to run the RPC server during task synchronization. - RPCHandleTy *RPCHandle; + RPCServerTy *RPCServer; #ifdef OMPT_SUPPORT /// OMPT callback functions diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp index 41a3745..5254829 100644 --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.cpp @@ -28,7 +28,6 @@ RPCServerTy::RPCServerTy(uint32_t NumDevices) { // If this fails then something is catastrophically wrong, just exit. if (rpc_status_t Err = rpc_init(NumDevices)) FATAL_MESSAGE(1, "Error initializing the RPC server: %d\n", Err); - Handles.resize(NumDevices); #endif } @@ -118,28 +117,10 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, if (auto Err = Device.dataSubmit(ClientPtr, ClientBuffer, rpc_get_client_size(), nullptr)) return Err; - - Handles[DeviceId] = std::make_unique<RPCHandleTy>(*this, Device); #endif return Error::success(); } -llvm::Expected<RPCHandleTy *> -RPCServerTy::getDevice(plugin::GenericDeviceTy &Device) { -#ifdef LIBOMPTARGET_RPC_SUPPORT - uint32_t DeviceId = Device.getDeviceId(); - if (!Handles[DeviceId] || !rpc_get_buffer(DeviceId) || - !rpc_get_client_buffer(DeviceId)) - return plugin::Plugin::error( - "Attempt to get an RPC device while not initialized"); - - return Handles[DeviceId].get(); -#else - return plugin::Plugin::error( - "Attempt to get an RPC device while not available"); -#endif -} - Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) { #ifdef LIBOMPTARGET_RPC_SUPPORT if (rpc_status_t Err = rpc_handle_server(Device.getDeviceId())) diff --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h index f8ffbc6..e1ebcb2 100644 --- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h +++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/RPC.h @@ -32,21 +32,6 @@ class DeviceImageTy; /// these routines will perform no action. struct RPCServerTy { public: - /// A wrapper around a single instance of the RPC server for a given device. - /// This is provided to simplify ownership of the underlying device. - struct RPCHandleTy { - RPCHandleTy(RPCServerTy &Server, plugin::GenericDeviceTy &Device) - : Server(Server), Device(Device) {} - - llvm::Error runServer() { return Server.runServer(Device); } - - llvm::Error deinitDevice() { return Server.deinitDevice(Device); } - - private: - RPCServerTy &Server; - plugin::GenericDeviceTy &Device; - }; - RPCServerTy(uint32_t NumDevices); /// Check if this device image is using an RPC server. This checks for the @@ -63,9 +48,6 @@ public: plugin::GenericGlobalHandlerTy &Handler, plugin::DeviceImageTy &Image); - /// Gets a reference to this server for a specific device. - llvm::Expected<RPCHandleTy *> getDevice(plugin::GenericDeviceTy &Device); - /// Runs the RPC server associated with the \p Device until the pending work /// is cleared. llvm::Error runServer(plugin::GenericDeviceTy &Device); @@ -75,13 +57,8 @@ public: llvm::Error deinitDevice(plugin::GenericDeviceTy &Device); ~RPCServerTy(); - -private: - llvm::SmallVector<std::unique_ptr<RPCHandleTy>> Handles; }; -using RPCHandleTy = RPCServerTy::RPCHandleTy; - } // namespace llvm::omp::target #endif diff --git a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp index 4ae9fae..f05fbb5 100644 --- a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp +++ b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp @@ -474,12 +474,12 @@ struct CUDADeviceTy : public GenericDeviceTy { CUresult Res; // If we have an RPC server running on this device we will continuously // query it for work rather than blocking. - if (!getRPCHandle()) { + if (!getRPCServer()) { Res = cuStreamSynchronize(Stream); } else { do { Res = cuStreamQuery(Stream); - if (auto Err = getRPCHandle()->runServer()) + if (auto Err = getRPCServer()->runServer(*this)) return Err; } while (Res == CUDA_ERROR_NOT_READY); } |