diff options
Diffstat (limited to 'python/qemu/qmp/protocol.py')
-rw-r--r-- | python/qemu/qmp/protocol.py | 1048 |
1 files changed, 1048 insertions, 0 deletions
diff --git a/python/qemu/qmp/protocol.py b/python/qemu/qmp/protocol.py new file mode 100644 index 0000000..6ea8665 --- /dev/null +++ b/python/qemu/qmp/protocol.py @@ -0,0 +1,1048 @@ +""" +Generic Asynchronous Message-based Protocol Support + +This module provides a generic framework for sending and receiving +messages over an asyncio stream. `AsyncProtocol` is an abstract class +that implements the core mechanisms of a simple send/receive protocol, +and is designed to be extended. + +In this package, it is used as the implementation for the `QMPClient` +class. +""" + +# It's all the docstrings ... ! It's long for a good reason ^_^; +# pylint: disable=too-many-lines + +import asyncio +from asyncio import StreamReader, StreamWriter +from enum import Enum +from functools import wraps +import logging +from ssl import SSLContext +from typing import ( + Any, + Awaitable, + Callable, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) + +from .error import QMPError +from .util import ( + bottom_half, + create_task, + exception_summary, + flush, + is_closing, + pretty_traceback, + upper_half, + wait_closed, +) + + +T = TypeVar('T') +_U = TypeVar('_U') +_TaskFN = Callable[[], Awaitable[None]] # aka ``async def func() -> None`` + +InternetAddrT = Tuple[str, int] +UnixAddrT = str +SocketAddrT = Union[UnixAddrT, InternetAddrT] + + +class Runstate(Enum): + """Protocol session runstate.""" + + #: Fully quiesced and disconnected. + IDLE = 0 + #: In the process of connecting or establishing a session. + CONNECTING = 1 + #: Fully connected and active session. + RUNNING = 2 + #: In the process of disconnecting. + #: Runstate may be returned to `IDLE` by calling `disconnect()`. + DISCONNECTING = 3 + + +class ConnectError(QMPError): + """ + Raised when the initial connection process has failed. + + This Exception always wraps a "root cause" exception that can be + interrogated for additional information. + + :param error_message: Human-readable string describing the error. + :param exc: The root-cause exception. + """ + def __init__(self, error_message: str, exc: Exception): + super().__init__(error_message) + #: Human-readable error string + self.error_message: str = error_message + #: Wrapped root cause exception + self.exc: Exception = exc + + def __str__(self) -> str: + cause = str(self.exc) + if not cause: + # If there's no error string, use the exception name. + cause = exception_summary(self.exc) + return f"{self.error_message}: {cause}" + + +class StateError(QMPError): + """ + An API command (connect, execute, etc) was issued at an inappropriate time. + + This error is raised when a command like + :py:meth:`~AsyncProtocol.connect()` is issued at an inappropriate + time. + + :param error_message: Human-readable string describing the state violation. + :param state: The actual `Runstate` seen at the time of the violation. + :param required: The `Runstate` required to process this command. + """ + def __init__(self, error_message: str, + state: Runstate, required: Runstate): + super().__init__(error_message) + self.error_message = error_message + self.state = state + self.required = required + + +F = TypeVar('F', bound=Callable[..., Any]) # pylint: disable=invalid-name + + +# Don't Panic. +def require(required_state: Runstate) -> Callable[[F], F]: + """ + Decorator: protect a method so it can only be run in a certain `Runstate`. + + :param required_state: The `Runstate` required to invoke this method. + :raise StateError: When the required `Runstate` is not met. + """ + def _decorator(func: F) -> F: + # _decorator is the decorator that is built by calling the + # require() decorator factory; e.g.: + # + # @require(Runstate.IDLE) def foo(): ... + # will replace 'foo' with the result of '_decorator(foo)'. + + @wraps(func) + def _wrapper(proto: 'AsyncProtocol[Any]', + *args: Any, **kwargs: Any) -> Any: + # _wrapper is the function that gets executed prior to the + # decorated method. + + name = type(proto).__name__ + + if proto.runstate != required_state: + if proto.runstate == Runstate.CONNECTING: + emsg = f"{name} is currently connecting." + elif proto.runstate == Runstate.DISCONNECTING: + emsg = (f"{name} is disconnecting." + " Call disconnect() to return to IDLE state.") + elif proto.runstate == Runstate.RUNNING: + emsg = f"{name} is already connected and running." + elif proto.runstate == Runstate.IDLE: + emsg = f"{name} is disconnected and idle." + else: + assert False + raise StateError(emsg, proto.runstate, required_state) + # No StateError, so call the wrapped method. + return func(proto, *args, **kwargs) + + # Return the decorated method; + # Transforming Func to Decorated[Func]. + return cast(F, _wrapper) + + # Return the decorator instance from the decorator factory. Phew! + return _decorator + + +class AsyncProtocol(Generic[T]): + """ + AsyncProtocol implements a generic async message-based protocol. + + This protocol assumes the basic unit of information transfer between + client and server is a "message", the details of which are left up + to the implementation. It assumes the sending and receiving of these + messages is full-duplex and not necessarily correlated; i.e. it + supports asynchronous inbound messages. + + It is designed to be extended by a specific protocol which provides + the implementations for how to read and send messages. These must be + defined in `_do_recv()` and `_do_send()`, respectively. + + Other callbacks have a default implementation, but are intended to be + either extended or overridden: + + - `_establish_session`: + The base implementation starts the reader/writer tasks. + A protocol implementation can override this call, inserting + actions to be taken prior to starting the reader/writer tasks + before the super() call; actions needing to occur afterwards + can be written after the super() call. + - `_on_message`: + Actions to be performed when a message is received. + - `_cb_outbound`: + Logging/Filtering hook for all outbound messages. + - `_cb_inbound`: + Logging/Filtering hook for all inbound messages. + This hook runs *before* `_on_message()`. + + :param name: + Name used for logging messages, if any. By default, messages + will log to 'qemu.qmp.protocol', but each individual connection + can be given its own logger by giving it a name; messages will + then log to 'qemu.qmp.protocol.${name}'. + """ + # pylint: disable=too-many-instance-attributes + + #: Logger object for debugging messages from this connection. + logger = logging.getLogger(__name__) + + # Maximum allowable size of read buffer + _limit = (64 * 1024) + + # ------------------------- + # Section: Public interface + # ------------------------- + + def __init__(self, name: Optional[str] = None) -> None: + #: The nickname for this connection, if any. + self.name: Optional[str] = name + if self.name is not None: + self.logger = self.logger.getChild(self.name) + + # stream I/O + self._reader: Optional[StreamReader] = None + self._writer: Optional[StreamWriter] = None + + # Outbound Message queue + self._outgoing: asyncio.Queue[T] + + # Special, long-running tasks: + self._reader_task: Optional[asyncio.Future[None]] = None + self._writer_task: Optional[asyncio.Future[None]] = None + + # Aggregate of the above two tasks, used for Exception management. + self._bh_tasks: Optional[asyncio.Future[Tuple[None, None]]] = None + + #: Disconnect task. The disconnect implementation runs in a task + #: so that asynchronous disconnects (initiated by the + #: reader/writer) are allowed to wait for the reader/writers to + #: exit. + self._dc_task: Optional[asyncio.Future[None]] = None + + self._runstate = Runstate.IDLE + self._runstate_changed: Optional[asyncio.Event] = None + + # Server state for start_server() and _incoming() + self._server: Optional[asyncio.AbstractServer] = None + self._accepted: Optional[asyncio.Event] = None + + def __repr__(self) -> str: + cls_name = type(self).__name__ + tokens = [] + if self.name is not None: + tokens.append(f"name={self.name!r}") + tokens.append(f"runstate={self.runstate.name}") + return f"<{cls_name} {' '.join(tokens)}>" + + @property # @upper_half + def runstate(self) -> Runstate: + """The current `Runstate` of the connection.""" + return self._runstate + + @upper_half + async def runstate_changed(self) -> Runstate: + """ + Wait for the `runstate` to change, then return that runstate. + """ + await self._runstate_event.wait() + return self.runstate + + @upper_half + @require(Runstate.IDLE) + async def start_server_and_accept( + self, address: SocketAddrT, + ssl: Optional[SSLContext] = None + ) -> None: + """ + Accept a connection and begin processing message queues. + + If this call fails, `runstate` is guaranteed to be set back to `IDLE`. + This method is precisely equivalent to calling `start_server()` + followed by `accept()`. + + :param address: + Address to listen on; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise StateError: When the `Runstate` is not `IDLE`. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. + """ + await self.start_server(address, ssl) + await self.accept() + assert self.runstate == Runstate.RUNNING + + @upper_half + @require(Runstate.IDLE) + async def start_server(self, address: SocketAddrT, + ssl: Optional[SSLContext] = None) -> None: + """ + Start listening for an incoming connection, but do not wait for a peer. + + This method starts listening for an incoming connection, but + does not block waiting for a peer. This call will return + immediately after binding and listening on a socket. A later + call to `accept()` must be made in order to finalize the + incoming connection. + + :param address: + Address to listen on; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise StateError: When the `Runstate` is not `IDLE`. + :raise ConnectError: + When the server could not start listening on this address. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError`. + """ + await self._session_guard( + self._do_start_server(address, ssl), + 'Failed to establish connection') + assert self.runstate == Runstate.CONNECTING + + @upper_half + @require(Runstate.CONNECTING) + async def accept(self) -> None: + """ + Accept an incoming connection and begin processing message queues. + + If this call fails, `runstate` is guaranteed to be set back to `IDLE`. + + :raise StateError: When the `Runstate` is not `CONNECTING`. + :raise QMPError: When `start_server()` was not called yet. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. + """ + if self._accepted is None: + raise QMPError("Cannot call accept() before start_server().") + await self._session_guard( + self._do_accept(), + 'Failed to establish connection') + await self._session_guard( + self._establish_session(), + 'Failed to establish session') + assert self.runstate == Runstate.RUNNING + + @upper_half + @require(Runstate.IDLE) + async def connect(self, address: SocketAddrT, + ssl: Optional[SSLContext] = None) -> None: + """ + Connect to the server and begin processing message queues. + + If this call fails, `runstate` is guaranteed to be set back to `IDLE`. + + :param address: + Address to connect to; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise StateError: When the `Runstate` is not `IDLE`. + :raise ConnectError: + When a connection or session cannot be established. + + This exception will wrap a more concrete one. In most cases, + the wrapped exception will be `OSError` or `EOFError`. If a + protocol-level failure occurs while establishing a new + session, the wrapped error may also be an `QMPError`. + """ + await self._session_guard( + self._do_connect(address, ssl), + 'Failed to establish connection') + await self._session_guard( + self._establish_session(), + 'Failed to establish session') + assert self.runstate == Runstate.RUNNING + + @upper_half + async def disconnect(self) -> None: + """ + Disconnect and wait for all tasks to fully stop. + + If there was an exception that caused the reader/writers to + terminate prematurely, it will be raised here. + + :raise Exception: When the reader or writer terminate unexpectedly. + """ + self.logger.debug("disconnect() called.") + self._schedule_disconnect() + await self._wait_disconnect() + + # -------------------------- + # Section: Session machinery + # -------------------------- + + async def _session_guard(self, coro: Awaitable[None], emsg: str) -> None: + """ + Async guard function used to roll back to `IDLE` on any error. + + On any Exception, the state machine will be reset back to + `IDLE`. Most Exceptions will be wrapped with `ConnectError`, but + `BaseException` events will be left alone (This includes + asyncio.CancelledError, even prior to Python 3.8). + + :param error_message: + Human-readable string describing what connection phase failed. + + :raise BaseException: + When `BaseException` occurs in the guarded block. + :raise ConnectError: + When any other error is encountered in the guarded block. + """ + # Note: After Python 3.6 support is removed, this should be an + # @asynccontextmanager instead of accepting a callback. + try: + await coro + except BaseException as err: + self.logger.error("%s: %s", emsg, exception_summary(err)) + self.logger.debug("%s:\n%s\n", emsg, pretty_traceback()) + try: + # Reset the runstate back to IDLE. + await self.disconnect() + except: + # We don't expect any Exceptions from the disconnect function + # here, because we failed to connect in the first place. + # The disconnect() function is intended to perform + # only cannot-fail cleanup here, but you never know. + emsg = ( + "Unexpected bottom half exception. " + "This is a bug in the QMP library. " + "Please report it to <qemu-devel@nongnu.org> and " + "CC: John Snow <jsnow@redhat.com>." + ) + self.logger.critical("%s:\n%s\n", emsg, pretty_traceback()) + raise + + # CancelledError is an Exception with special semantic meaning; + # We do NOT want to wrap it up under ConnectError. + # NB: CancelledError is not a BaseException before Python 3.8 + if isinstance(err, asyncio.CancelledError): + raise + + # Any other kind of error can be treated as some kind of connection + # failure broadly. Inspect the 'exc' field to explore the root + # cause in greater detail. + if isinstance(err, Exception): + raise ConnectError(emsg, err) from err + + # Raise BaseExceptions un-wrapped, they're more important. + raise + + @property + def _runstate_event(self) -> asyncio.Event: + # asyncio.Event() objects should not be created prior to entrance into + # an event loop, so we can ensure we create it in the correct context. + # Create it on-demand *only* at the behest of an 'async def' method. + if not self._runstate_changed: + self._runstate_changed = asyncio.Event() + return self._runstate_changed + + @upper_half + @bottom_half + def _set_state(self, state: Runstate) -> None: + """ + Change the `Runstate` of the protocol connection. + + Signals the `runstate_changed` event. + """ + if state == self._runstate: + return + + self.logger.debug("Transitioning from '%s' to '%s'.", + str(self._runstate), str(state)) + self._runstate = state + self._runstate_event.set() + self._runstate_event.clear() + + @bottom_half + async def _stop_server(self) -> None: + """ + Stop listening for / accepting new incoming connections. + """ + if self._server is None: + return + + try: + self.logger.debug("Stopping server.") + self._server.close() + await self._server.wait_closed() + self.logger.debug("Server stopped.") + finally: + self._server = None + + @bottom_half # However, it does not run from the R/W tasks. + async def _incoming(self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter) -> None: + """ + Accept an incoming connection and signal the upper_half. + + This method does the minimum necessary to accept a single + incoming connection. It signals back to the upper_half ASAP so + that any errors during session initialization can occur + naturally in the caller's stack. + + :param reader: Incoming `asyncio.StreamReader` + :param writer: Incoming `asyncio.StreamWriter` + """ + peer = writer.get_extra_info('peername', 'Unknown peer') + self.logger.debug("Incoming connection from %s", peer) + + if self._reader or self._writer: + # Sadly, we can have more than one pending connection + # because of https://bugs.python.org/issue46715 + # Close any extra connections we don't actually want. + self.logger.warning("Extraneous connection inadvertently accepted") + writer.close() + return + + # A connection has been accepted; stop listening for new ones. + assert self._accepted is not None + await self._stop_server() + self._reader, self._writer = (reader, writer) + self._accepted.set() + + @upper_half + async def _do_start_server(self, address: SocketAddrT, + ssl: Optional[SSLContext] = None) -> None: + """ + Start listening for an incoming connection, but do not wait for a peer. + + This method starts listening for an incoming connection, but does not + block waiting for a peer. This call will return immediately after + binding and listening to a socket. A later call to accept() must be + made in order to finalize the incoming connection. + + :param address: + Address to listen on; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise OSError: For stream-related errors. + """ + assert self.runstate == Runstate.IDLE + self._set_state(Runstate.CONNECTING) + + self.logger.debug("Awaiting connection on %s ...", address) + self._accepted = asyncio.Event() + + if isinstance(address, tuple): + coro = asyncio.start_server( + self._incoming, + host=address[0], + port=address[1], + ssl=ssl, + backlog=1, + limit=self._limit, + ) + else: + coro = asyncio.start_unix_server( + self._incoming, + path=address, + ssl=ssl, + backlog=1, + limit=self._limit, + ) + + # Allow runstate watchers to witness 'CONNECTING' state; some + # failures in the streaming layer are synchronous and will not + # otherwise yield. + await asyncio.sleep(0) + + # This will start the server (bind(2), listen(2)). It will also + # call accept(2) if we yield, but we don't block on that here. + self._server = await coro + self.logger.debug("Server listening on %s", address) + + @upper_half + async def _do_accept(self) -> None: + """ + Wait for and accept an incoming connection. + + Requires that we have not yet accepted an incoming connection + from the upper_half, but it's OK if the server is no longer + running because the bottom_half has already accepted the + connection. + """ + assert self._accepted is not None + await self._accepted.wait() + assert self._server is None + self._accepted = None + + self.logger.debug("Connection accepted.") + + @upper_half + async def _do_connect(self, address: SocketAddrT, + ssl: Optional[SSLContext] = None) -> None: + """ + Acting as the transport client, initiate a connection to a server. + + :param address: + Address to connect to; UNIX socket path or TCP address/port. + :param ssl: SSL context to use, if any. + + :raise OSError: For stream-related errors. + """ + assert self.runstate == Runstate.IDLE + self._set_state(Runstate.CONNECTING) + + # Allow runstate watchers to witness 'CONNECTING' state; some + # failures in the streaming layer are synchronous and will not + # otherwise yield. + await asyncio.sleep(0) + + self.logger.debug("Connecting to %s ...", address) + + if isinstance(address, tuple): + connect = asyncio.open_connection( + address[0], + address[1], + ssl=ssl, + limit=self._limit, + ) + else: + connect = asyncio.open_unix_connection( + path=address, + ssl=ssl, + limit=self._limit, + ) + self._reader, self._writer = await connect + + self.logger.debug("Connected.") + + @upper_half + async def _establish_session(self) -> None: + """ + Establish a new session. + + Starts the readers/writer tasks; subclasses may perform their + own negotiations here. The Runstate will be RUNNING upon + successful conclusion. + """ + assert self.runstate == Runstate.CONNECTING + + self._outgoing = asyncio.Queue() + + reader_coro = self._bh_loop_forever(self._bh_recv_message, 'Reader') + writer_coro = self._bh_loop_forever(self._bh_send_message, 'Writer') + + self._reader_task = create_task(reader_coro) + self._writer_task = create_task(writer_coro) + + self._bh_tasks = asyncio.gather( + self._reader_task, + self._writer_task, + ) + + self._set_state(Runstate.RUNNING) + await asyncio.sleep(0) # Allow runstate_event to process + + @upper_half + @bottom_half + def _schedule_disconnect(self) -> None: + """ + Initiate a disconnect; idempotent. + + This method is used both in the upper-half as a direct + consequence of `disconnect()`, and in the bottom-half in the + case of unhandled exceptions in the reader/writer tasks. + + It can be invoked no matter what the `runstate` is. + """ + if not self._dc_task: + self._set_state(Runstate.DISCONNECTING) + self.logger.debug("Scheduling disconnect.") + self._dc_task = create_task(self._bh_disconnect()) + + @upper_half + async def _wait_disconnect(self) -> None: + """ + Waits for a previously scheduled disconnect to finish. + + This method will gather any bottom half exceptions and re-raise + the one that occurred first; presuming it to be the root cause + of any subsequent Exceptions. It is intended to be used in the + upper half of the call chain. + + :raise Exception: + Arbitrary exception re-raised on behalf of the reader/writer. + """ + assert self.runstate == Runstate.DISCONNECTING + assert self._dc_task + + aws: List[Awaitable[object]] = [self._dc_task] + if self._bh_tasks: + aws.insert(0, self._bh_tasks) + all_defined_tasks = asyncio.gather(*aws) + + # Ensure disconnect is done; Exception (if any) is not raised here: + await asyncio.wait((self._dc_task,)) + + try: + await all_defined_tasks # Raise Exceptions from the bottom half. + finally: + self._cleanup() + self._set_state(Runstate.IDLE) + + @upper_half + def _cleanup(self) -> None: + """ + Fully reset this object to a clean state and return to `IDLE`. + """ + def _paranoid_task_erase(task: Optional['asyncio.Future[_U]'] + ) -> Optional['asyncio.Future[_U]']: + # Help to erase a task, ENSURING it is fully quiesced first. + assert (task is None) or task.done() + return None if (task and task.done()) else task + + assert self.runstate == Runstate.DISCONNECTING + self._dc_task = _paranoid_task_erase(self._dc_task) + self._reader_task = _paranoid_task_erase(self._reader_task) + self._writer_task = _paranoid_task_erase(self._writer_task) + self._bh_tasks = _paranoid_task_erase(self._bh_tasks) + + self._reader = None + self._writer = None + self._accepted = None + + # NB: _runstate_changed cannot be cleared because we still need it to + # send the final runstate changed event ...! + + # ---------------------------- + # Section: Bottom Half methods + # ---------------------------- + + @bottom_half + async def _bh_disconnect(self) -> None: + """ + Disconnect and cancel all outstanding tasks. + + It is designed to be called from its task context, + :py:obj:`~AsyncProtocol._dc_task`. By running in its own task, + it is free to wait on any pending actions that may still need to + occur in either the reader or writer tasks. + """ + assert self.runstate == Runstate.DISCONNECTING + + def _done(task: Optional['asyncio.Future[Any]']) -> bool: + return task is not None and task.done() + + # If the server is running, stop it. + await self._stop_server() + + # Are we already in an error pathway? If either of the tasks are + # already done, or if we have no tasks but a reader/writer; we + # must be. + # + # NB: We can't use _bh_tasks to check for premature task + # completion, because it may not yet have had a chance to run + # and gather itself. + tasks = tuple(filter(None, (self._writer_task, self._reader_task))) + error_pathway = _done(self._reader_task) or _done(self._writer_task) + if not tasks: + error_pathway |= bool(self._reader) or bool(self._writer) + + try: + # Try to flush the writer, if possible. + # This *may* cause an error and force us over into the error path. + if not error_pathway: + await self._bh_flush_writer() + except BaseException as err: + error_pathway = True + emsg = "Failed to flush the writer" + self.logger.error("%s: %s", emsg, exception_summary(err)) + self.logger.debug("%s:\n%s\n", emsg, pretty_traceback()) + raise + finally: + # Cancel any still-running tasks (Won't raise): + if self._writer_task is not None and not self._writer_task.done(): + self.logger.debug("Cancelling writer task.") + self._writer_task.cancel() + if self._reader_task is not None and not self._reader_task.done(): + self.logger.debug("Cancelling reader task.") + self._reader_task.cancel() + + # Close out the tasks entirely (Won't raise): + if tasks: + self.logger.debug("Waiting for tasks to complete ...") + await asyncio.wait(tasks) + + # Lastly, close the stream itself. (*May raise*!): + await self._bh_close_stream(error_pathway) + self.logger.debug("Disconnected.") + + @bottom_half + async def _bh_flush_writer(self) -> None: + if not self._writer_task: + return + + self.logger.debug("Draining the outbound queue ...") + await self._outgoing.join() + if self._writer is not None: + self.logger.debug("Flushing the StreamWriter ...") + await flush(self._writer) + + @bottom_half + async def _bh_close_stream(self, error_pathway: bool = False) -> None: + # NB: Closing the writer also implcitly closes the reader. + if not self._writer: + return + + if not is_closing(self._writer): + self.logger.debug("Closing StreamWriter.") + self._writer.close() + + self.logger.debug("Waiting for StreamWriter to close ...") + try: + await wait_closed(self._writer) + except Exception: # pylint: disable=broad-except + # It's hard to tell if the Stream is already closed or + # not. Even if one of the tasks has failed, it may have + # failed for a higher-layered protocol reason. The + # stream could still be open and perfectly fine. + # I don't know how to discern its health here. + + if error_pathway: + # We already know that *something* went wrong. Let's + # just trust that the Exception we already have is the + # better one to present to the user, even if we don't + # genuinely *know* the relationship between the two. + self.logger.debug( + "Discarding Exception from wait_closed:\n%s\n", + pretty_traceback(), + ) + else: + # Oops, this is a brand-new error! + raise + finally: + self.logger.debug("StreamWriter closed.") + + @bottom_half + async def _bh_loop_forever(self, async_fn: _TaskFN, name: str) -> None: + """ + Run one of the bottom-half methods in a loop forever. + + If the bottom half ever raises any exception, schedule a + disconnect that will terminate the entire loop. + + :param async_fn: The bottom-half method to run in a loop. + :param name: The name of this task, used for logging. + """ + try: + while True: + await async_fn() + except asyncio.CancelledError: + # We have been cancelled by _bh_disconnect, exit gracefully. + self.logger.debug("Task.%s: cancelled.", name) + return + except BaseException as err: + self.logger.log( + logging.INFO if isinstance(err, EOFError) else logging.ERROR, + "Task.%s: %s", + name, exception_summary(err) + ) + self.logger.debug("Task.%s: failure:\n%s\n", + name, pretty_traceback()) + self._schedule_disconnect() + raise + finally: + self.logger.debug("Task.%s: exiting.", name) + + @bottom_half + async def _bh_send_message(self) -> None: + """ + Wait for an outgoing message, then send it. + + Designed to be run in `_bh_loop_forever()`. + """ + msg = await self._outgoing.get() + try: + await self._send(msg) + finally: + self._outgoing.task_done() + + @bottom_half + async def _bh_recv_message(self) -> None: + """ + Wait for an incoming message and call `_on_message` to route it. + + Designed to be run in `_bh_loop_forever()`. + """ + msg = await self._recv() + await self._on_message(msg) + + # -------------------- + # Section: Message I/O + # -------------------- + + @upper_half + @bottom_half + def _cb_outbound(self, msg: T) -> T: + """ + Callback: outbound message hook. + + This is intended for subclasses to be able to add arbitrary + hooks to filter or manipulate outgoing messages. The base + implementation does nothing but log the message without any + manipulation of the message. + + :param msg: raw outbound message + :return: final outbound message + """ + self.logger.debug("--> %s", str(msg)) + return msg + + @upper_half + @bottom_half + def _cb_inbound(self, msg: T) -> T: + """ + Callback: inbound message hook. + + This is intended for subclasses to be able to add arbitrary + hooks to filter or manipulate incoming messages. The base + implementation does nothing but log the message without any + manipulation of the message. + + This method does not "handle" incoming messages; it is a filter. + The actual "endpoint" for incoming messages is `_on_message()`. + + :param msg: raw inbound message + :return: processed inbound message + """ + self.logger.debug("<-- %s", str(msg)) + return msg + + @upper_half + @bottom_half + async def _readline(self) -> bytes: + """ + Wait for a newline from the incoming reader. + + This method is provided as a convenience for upper-layer + protocols, as many are line-based. + + This method *may* return a sequence of bytes without a trailing + newline if EOF occurs, but *some* bytes were received. In this + case, the next call will raise `EOFError`. It is assumed that + the layer 5 protocol will decide if there is anything meaningful + to be done with a partial message. + + :raise OSError: For stream-related errors. + :raise EOFError: + If the reader stream is at EOF and there are no bytes to return. + :return: bytes, including the newline. + """ + assert self._reader is not None + msg_bytes = await self._reader.readline() + + if not msg_bytes: + if self._reader.at_eof(): + raise EOFError + + return msg_bytes + + @upper_half + @bottom_half + async def _do_recv(self) -> T: + """ + Abstract: Read from the stream and return a message. + + Very low-level; intended to only be called by `_recv()`. + """ + raise NotImplementedError + + @upper_half + @bottom_half + async def _recv(self) -> T: + """ + Read an arbitrary protocol message. + + .. warning:: + This method is intended primarily for `_bh_recv_message()` + to use in an asynchronous task loop. Using it outside of + this loop will "steal" messages from the normal routing + mechanism. It is safe to use prior to `_establish_session()`, + but should not be used otherwise. + + This method uses `_do_recv()` to retrieve the raw message, and + then transforms it using `_cb_inbound()`. + + :return: A single (filtered, processed) protocol message. + """ + message = await self._do_recv() + return self._cb_inbound(message) + + @upper_half + @bottom_half + def _do_send(self, msg: T) -> None: + """ + Abstract: Write a message to the stream. + + Very low-level; intended to only be called by `_send()`. + """ + raise NotImplementedError + + @upper_half + @bottom_half + async def _send(self, msg: T) -> None: + """ + Send an arbitrary protocol message. + + This method will transform any outgoing messages according to + `_cb_outbound()`. + + .. warning:: + Like `_recv()`, this method is intended to be called by + the writer task loop that processes outgoing + messages. Calling it directly may circumvent logic + implemented by the caller meant to correlate outgoing and + incoming messages. + + :raise OSError: For problems with the underlying stream. + """ + msg = self._cb_outbound(msg) + self._do_send(msg) + + @bottom_half + async def _on_message(self, msg: T) -> None: + """ + Called to handle the receipt of a new message. + + .. caution:: + This is executed from within the reader loop, so be advised + that waiting on either the reader or writer task will lead + to deadlock. Additionally, any unhandled exceptions will + directly cause the loop to halt, so logic may be best-kept + to a minimum if at all possible. + + :param msg: The incoming message, already logged/filtered. + """ + # Nothing to do in the abstract case. |