diff options
-rw-r--r-- | python/qemu/aqmp/protocol.py | 117 | ||||
-rw-r--r-- | python/tests/protocol.py | 10 |
2 files changed, 53 insertions, 74 deletions
diff --git a/python/qemu/aqmp/protocol.py b/python/qemu/aqmp/protocol.py index 7371925..b7e5e63 100644 --- a/python/qemu/aqmp/protocol.py +++ b/python/qemu/aqmp/protocol.py @@ -275,13 +275,25 @@ class AsyncProtocol(Generic[T]): If this call fails, `runstate` is guaranteed to be set back to `IDLE`. :param address: - Address to listen to; UNIX socket path or TCP address/port. + Address to listen on; UNIX socket path or TCP address/port. :param ssl: SSL context to use, if any. :raise StateError: When the `Runstate` is not `IDLE`. - :raise ConnectError: If a connection could not be accepted. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. """ - await self._new_session(address, ssl, accept=True) + await self._session_guard( + self._do_accept(address, ssl), + 'Failed to establish connection') + await self._session_guard( + self._establish_session(), + 'Failed to establish session') + assert self.runstate == Runstate.RUNNING @upper_half @require(Runstate.IDLE) @@ -297,9 +309,21 @@ class AsyncProtocol(Generic[T]): :param ssl: SSL context to use, if any. :raise StateError: When the `Runstate` is not `IDLE`. - :raise ConnectError: If a connection cannot be made to the server. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. """ - await self._new_session(address, ssl) + await self._session_guard( + self._do_connect(address, ssl), + 'Failed to establish connection') + await self._session_guard( + self._establish_session(), + 'Failed to establish session') + assert self.runstate == Runstate.RUNNING @upper_half async def disconnect(self) -> None: @@ -401,73 +425,6 @@ class AsyncProtocol(Generic[T]): self._runstate_event.set() self._runstate_event.clear() - @upper_half - async def _new_session(self, - address: SocketAddrT, - ssl: Optional[SSLContext] = None, - accept: bool = False) -> None: - """ - Establish a new connection and initialize the session. - - Connect or accept a new connection, then begin the protocol - session machinery. If this call fails, `runstate` is guaranteed - to be set back to `IDLE`. - - :param address: - Address to connect to/listen on; - UNIX socket path or TCP address/port. - :param ssl: SSL context to use, if any. - :param accept: Accept a connection instead of connecting when `True`. - - :raise ConnectError: - When a connection or session cannot be established. - - This exception will wrap a more concrete one. In most cases, - the wrapped exception will be `OSError` or `EOFError`. If a - protocol-level failure occurs while establishing a new - session, the wrapped error may also be an `QMPError`. - """ - assert self.runstate == Runstate.IDLE - - await self._session_guard( - self._establish_connection(address, ssl, accept), - 'Failed to establish connection') - - await self._session_guard( - self._establish_session(), - 'Failed to establish session') - - assert self.runstate == Runstate.RUNNING - - @upper_half - async def _establish_connection( - self, - address: SocketAddrT, - ssl: Optional[SSLContext] = None, - accept: bool = False - ) -> None: - """ - Establish a new connection. - - :param address: - Address to connect to/listen on; - UNIX socket path or TCP address/port. - :param ssl: SSL context to use, if any. - :param accept: Accept a connection instead of connecting when `True`. - """ - assert self.runstate == Runstate.IDLE - self._set_state(Runstate.CONNECTING) - - # Allow runstate watchers to witness 'CONNECTING' state; some - # failures in the streaming layer are synchronous and will not - # otherwise yield. - await asyncio.sleep(0) - - if accept: - await self._do_accept(address, ssl) - else: - await self._do_connect(address, ssl) - def _bind_hack(self, address: Union[str, Tuple[str, int]]) -> None: """ Used to create a socket in advance of accept(). @@ -508,6 +465,9 @@ class AsyncProtocol(Generic[T]): :raise OSError: For stream-related errors. """ + assert self.runstate == Runstate.IDLE + self._set_state(Runstate.CONNECTING) + self.logger.debug("Awaiting connection on %s ...", address) connected = asyncio.Event() server: Optional[asyncio.AbstractServer] = None @@ -550,6 +510,11 @@ class AsyncProtocol(Generic[T]): sock=self._sock, ) + # Allow runstate watchers to witness 'CONNECTING' state; some + # failures in the streaming layer are synchronous and will not + # otherwise yield. + await asyncio.sleep(0) + server = await coro # Starts listening await connected.wait() # Waits for the callback to fire (and finish) assert server is None @@ -569,6 +534,14 @@ class AsyncProtocol(Generic[T]): :raise OSError: For stream-related errors. """ + assert self.runstate == Runstate.IDLE + self._set_state(Runstate.CONNECTING) + + # Allow runstate watchers to witness 'CONNECTING' state; some + # failures in the streaming layer are synchronous and will not + # otherwise yield. + await asyncio.sleep(0) + self.logger.debug("Connecting to %s ...", address) if isinstance(address, tuple): diff --git a/python/tests/protocol.py b/python/tests/protocol.py index 354d655..8dd26c4 100644 --- a/python/tests/protocol.py +++ b/python/tests/protocol.py @@ -42,11 +42,17 @@ class NullProtocol(AsyncProtocol[None]): await super()._establish_session() async def _do_accept(self, address, ssl=None): - if not self.fake_session: + if self.fake_session: + self._set_state(Runstate.CONNECTING) + await asyncio.sleep(0) + else: await super()._do_accept(address, ssl) async def _do_connect(self, address, ssl=None): - if not self.fake_session: + if self.fake_session: + self._set_state(Runstate.CONNECTING) + await asyncio.sleep(0) + else: await super()._do_connect(address, ssl) async def _do_recv(self) -> None: |