From 1569a37a54ecb63bd4008708c76339ccf7d06115 Mon Sep 17 00:00:00 2001 From: Mattias Nissler <122288598+mnissler-rivos@users.noreply.github.com> Date: Fri, 15 Sep 2023 12:33:37 +0200 Subject: Pass server->client command over a separate socket pair (#762) Use separate socket for server->client commands This change adds support for a separate socket to carry commands in the server-to-client direction. It has proven problematic to send commands in both directions over a single socket, since matching replies to commands can become non-trivial when both sides send commands at the same time and adds significant complexity. See issue #279 for details. To set up the reverse communication channel, the client indicates support for it via a new capability flag in the version message. The server will then create a fresh pair of sockets and pass one end to the client in its version reply. When the server wishes to send commands to the client at a later point, it now uses its end of the new socket pair rather than the main socket. Corresponding replies are also passed back over the new socket pair. Signed-off-by: Mattias Nissler --- test/py/libvfio_user.py | 66 ++++++++++++-- test/py/meson.build | 1 + test/py/test_sgl_read_write.py | 192 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 254 insertions(+), 5 deletions(-) create mode 100644 test/py/test_sgl_read_write.py (limited to 'test/py') diff --git a/test/py/libvfio_user.py b/test/py/libvfio_user.py index 86c8cbd..a701d1b 100644 --- a/test/py/libvfio_user.py +++ b/test/py/libvfio_user.py @@ -32,6 +32,7 @@ # from types import SimpleNamespace +import collections.abc import ctypes as c import array import errno @@ -484,6 +485,15 @@ class vfio_user_dma_unmap(Structure): ] +class vfio_user_dma_region_access(Structure): + """Payload for VFIO_USER_DMA_READ and VFIO_USER_DMA_WRITE.""" + _pack_ = 1 + _fields_ = [ + ("addr", c.c_uint64), + ("count", c.c_uint64), + ] + + class vfu_dma_info_t(Structure): _fields_ = [ ("iova", iovec_t), @@ -642,6 +652,10 @@ lib.vfu_sgl_get.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.POINTER(iovec_t), c.c_size_t, c.c_int) lib.vfu_sgl_put.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.POINTER(iovec_t), c.c_size_t) +lib.vfu_sgl_read.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t, + c.c_void_p) +lib.vfu_sgl_write.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t, + c.c_void_p) lib.vfu_create_ioeventfd.argtypes = (c.c_void_p, c.c_uint32, c.c_int, c.c_size_t, c.c_uint32, c.c_uint32, @@ -695,22 +709,52 @@ class Client: self.sock = sock self.client_cmd_socket = None - def connect(self, ctx): + def connect(self, ctx, capabilities={}): self.sock = connect_sock() - json = b'{ "capabilities": { "max_msg_fds": 8 } }' + client_caps = { + "capabilities": { + "max_data_xfer_size": VFIO_USER_DEFAULT_MAX_DATA_XFER_SIZE, + "max_msg_fds": 8, + }, + } + + def update(target, overrides): + for k, v in overrides.items(): + if isinstance(v, collections.abc.Mapping): + target[k] = target.get(k, {}) + update(target[k], v) + else: + target[k] = v + + update(client_caps, capabilities) + caps_json = json.dumps(client_caps) + # struct vfio_user_version - payload = struct.pack("HH%dsc" % len(json), LIBVFIO_USER_MAJOR, - LIBVFIO_USER_MINOR, json, b'\0') + payload = struct.pack("HH%dsc" % len(caps_json), LIBVFIO_USER_MAJOR, + LIBVFIO_USER_MINOR, caps_json.encode(), b'\0') hdr = vfio_user_header(VFIO_USER_VERSION, size=len(payload)) self.sock.send(hdr + payload) vfu_attach_ctx(ctx, expect=0) - payload = get_reply(self.sock, expect=0) + fds, payload = get_reply_fds(self.sock, expect=0) + + server_caps = json.loads(payload[struct.calcsize("HH"):-1].decode()) + try: + if (client_caps["capabilities"]["twin_socket"]["supported"] and + server_caps["capabilities"]["twin_socket"]["supported"]): + index = server_caps["capabilities"]["twin_socket"]["fd_index"] + self.client_cmd_socket = socket.socket(fileno=fds[index]) + except KeyError: + pass + return self.sock def disconnect(self, ctx): self.sock.close() self.sock = None + if self.client_cmd_socket is not None: + self.client_cmd_socket.close() + self.client_cmd_socket = None # notice client closed connection vfu_run_ctx(ctx, errno.ENOTCONN) @@ -1274,6 +1318,18 @@ def vfu_sgl_put(ctx, sg, iovec, cnt=1): return lib.vfu_sgl_put(ctx, sg, iovec, cnt) +def vfu_sgl_read(ctx, sg, cnt=1): + data = bytearray(sum([sge.length for sge in sg])) + buf = (c.c_byte * len(data)).from_buffer(data) + return lib.vfu_sgl_read(ctx, sg, cnt, buf), data + + +def vfu_sgl_write(ctx, sg, cnt=1, data=bytearray()): + assert len(data) == sum([sge.length for sge in sg]) + buf = (c.c_byte * len(data)).from_buffer(data) + return lib.vfu_sgl_write(ctx, sg, cnt, buf) + + def vfu_create_ioeventfd(ctx, region_idx, fd, gpa_offset, size, flags, datamatch, shadow_fd=-1, shadow_offset=0): assert ctx is not None diff --git a/test/py/meson.build b/test/py/meson.build index 0ea9f08..ecd2fe2 100644 --- a/test/py/meson.build +++ b/test/py/meson.build @@ -45,6 +45,7 @@ python_tests = [ 'test_request_errors.py', 'test_setup_region.py', 'test_sgl_get_put.py', + 'test_sgl_read_write.py', 'test_vfu_create_ctx.py', 'test_vfu_realize_ctx.py', ] diff --git a/test/py/test_sgl_read_write.py b/test/py/test_sgl_read_write.py new file mode 100644 index 0000000..2f4e992 --- /dev/null +++ b/test/py/test_sgl_read_write.py @@ -0,0 +1,192 @@ +# +# Copyright (c) 2023 Nutanix Inc. All rights reserved. +# Copyright (c) 2023 Rivos Inc. All rights reserved. +# +# Authors: Mattias Nissler +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Nutanix nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. +# + +from libvfio_user import * +import select +import threading + +MAP_ADDR = 0x10000000 +MAP_SIZE = 16 << PAGE_SHIFT + +ctx = None +client = None + + +class DMARegionHandler: + """ + A helper to service DMA region accesses arriving over a socket. Accesses + are performed against an internal bytearray buffer. DMA request processing + takes place on a separate thread so as to not block the test code. + """ + + def __handle_requests(sock, pipe, buf, lock, addr, error_no): + while True: + (ready, _, _) = select.select([sock, pipe], [], []) + if pipe in ready: + break + + # Read a command from the socket and service it. + _, msg_id, cmd, payload = get_msg_fds(sock, + VFIO_USER_F_TYPE_COMMAND) + assert cmd in [VFIO_USER_DMA_READ, VFIO_USER_DMA_WRITE] + access, data = vfio_user_dma_region_access.pop_from_buffer(payload) + + assert access.addr >= addr + assert access.addr + access.count <= addr + len(buf) + + offset = access.addr - addr + with lock: + if cmd == VFIO_USER_DMA_READ: + data = buf[offset:offset + access.count] + else: + buf[offset:offset + access.count] = data + data = bytearray() + + send_msg(sock, + cmd, + VFIO_USER_F_TYPE_REPLY, + payload=payload[:c.sizeof(access)] + data, + msg_id=msg_id, + error_no=error_no) + + os.close(pipe) + sock.close() + + def __init__(self, sock, addr, size, error_no=0): + self.data = bytearray(size) + self.data_lock = threading.Lock() + self.addr = addr + (pipe_r, self.pipe_w) = os.pipe() + # Duplicate the socket file descriptor so the thread can own it and + # make sure it gets closed only when terminating the thread. + sock = socket.socket(fileno=os.dup(sock.fileno())) + thread = threading.Thread( + target=DMARegionHandler.__handle_requests, + args=[sock, pipe_r, self.data, self.data_lock, addr, error_no]) + thread.start() + + def shutdown(self): + # Closing the pipe's write end will signal the thread to terminate. + os.close(self.pipe_w) + + def read(self, addr, size): + offset = addr - self.addr + with self.data_lock: + return self.data[offset:offset + size] + + +def setup_function(function): + global ctx, client, dma_handler + ctx = prepare_ctx_for_dma() + assert ctx is not None + caps = { + "capabilities": { + "max_data_xfer_size": PAGE_SIZE, + "twin_socket": { + "supported": True, + }, + } + } + client = connect_client(ctx, caps) + assert client.client_cmd_socket is not None + + payload = vfio_user_dma_map(argsz=len(vfio_user_dma_map()), + flags=(VFIO_USER_F_DMA_REGION_READ + | VFIO_USER_F_DMA_REGION_WRITE), + offset=0, + addr=MAP_ADDR, + size=MAP_SIZE) + + msg(ctx, client.sock, VFIO_USER_DMA_MAP, payload) + + dma_handler = DMARegionHandler(client.client_cmd_socket, payload.addr, + payload.size) + + +def teardown_function(function): + dma_handler.shutdown() + client.disconnect(ctx) + vfu_destroy_ctx(ctx) + + +def test_dma_read_write(): + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=64, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + data = bytearray([x & 0xff for x in range(0, sg[0].length)]) + assert vfu_sgl_write(ctx, sg, 1, data) == 0 + + assert vfu_sgl_read(ctx, sg, 1) == (0, data) + + assert dma_handler.read(sg[0].dma_addr + sg[0].offset, + sg[0].length) == data + + +def test_dma_read_write_large(): + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=2 * PAGE_SIZE, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + data = bytearray([x & 0xff for x in range(0, sg[0].length)]) + assert vfu_sgl_write(ctx, sg, 1, data) == 0 + + assert vfu_sgl_read(ctx, sg, 1) == (0, data) + + assert dma_handler.read(sg[0].dma_addr + sg[0].offset, + sg[0].length) == data + + +def test_dma_read_write_error(): + # Reinitialize the handler to return EIO. + global dma_handler + dma_handler.shutdown() + dma_handler = DMARegionHandler(client.client_cmd_socket, MAP_ADDR, + MAP_SIZE, error_no=errno.EIO) + + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=64, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + ret, _ = vfu_sgl_read(ctx, sg, 1) + assert ret == -1 + assert c.get_errno() == errno.EIO + + +# ex: set tabstop=4 shiftwidth=4 softtabstop=4 expandtab: # -- cgit v1.1