diff options
-rw-r--r-- | include/libvfio-user.h | 2 | ||||
-rw-r--r-- | lib/tran.c | 95 | ||||
-rw-r--r-- | lib/tran.h | 5 | ||||
-rw-r--r-- | lib/tran_pipe.c | 2 | ||||
-rw-r--r-- | lib/tran_sock.c | 31 | ||||
-rw-r--r-- | samples/client.c | 2 | ||||
-rw-r--r-- | test/py/libvfio_user.py | 66 | ||||
-rw-r--r-- | test/py/meson.build | 1 | ||||
-rw-r--r-- | test/py/test_sgl_read_write.py | 192 |
9 files changed, 373 insertions, 23 deletions
diff --git a/include/libvfio-user.h b/include/libvfio-user.h index 72369b6..21cb99a 100644 --- a/include/libvfio-user.h +++ b/include/libvfio-user.h @@ -61,7 +61,7 @@ extern "C" { #endif #define LIB_VFIO_USER_MAJOR 0 -#define LIB_VFIO_USER_MINOR 1 +#define LIB_VFIO_USER_MINOR 2 /* DMA addresses cannot be directly de-referenced. */ typedef void *vfu_dma_addr_t; @@ -36,6 +36,7 @@ #include <stdlib.h> #include <stdio.h> #include <string.h> +#include <sys/socket.h> #include <json.h> @@ -52,9 +53,13 @@ * { * "capabilities": { * "max_msg_fds": 32, - * "max_data_xfer_size": 1048576 + * "max_data_xfer_size": 1048576, * "migration": { * "pgsize": 4096 + * }, + * "twin_socket": { + * "supported": true, + * "fd_index": 0 * } * } * } @@ -64,7 +69,8 @@ */ int tran_parse_version_json(const char *json_str, int *client_max_fdsp, - size_t *client_max_data_xfer_sizep, size_t *pgsizep) + size_t *client_max_data_xfer_sizep, size_t *pgsizep, + bool *twin_socket_supportedp) { struct json_object *jo_caps = NULL; struct json_object *jo_top = NULL; @@ -130,6 +136,27 @@ tran_parse_version_json(const char *json_str, int *client_max_fdsp, } } + if (json_object_object_get_ex(jo_caps, "twin_socket", &jo)) { + struct json_object *jo2 = NULL; + + if (json_object_get_type(jo) != json_type_object) { + goto out; + } + + if (json_object_object_get_ex(jo, "supported", &jo2)) { + if (json_object_get_type(jo2) != json_type_boolean) { + goto out; + } + + errno = 0; + *twin_socket_supportedp = json_object_get_boolean(jo2); + + if (errno != 0) { + goto out; + } + } + } + ret = 0; out: @@ -143,7 +170,7 @@ out: static int recv_version(vfu_ctx_t *vfu_ctx, uint16_t *msg_idp, - struct vfio_user_version **versionp) + struct vfio_user_version **versionp, bool *twin_socket_supportedp) { struct vfio_user_version *cversion = NULL; vfu_msg_t msg = { { 0 } }; @@ -208,7 +235,7 @@ recv_version(vfu_ctx_t *vfu_ctx, uint16_t *msg_idp, ret = tran_parse_version_json(json_str, &vfu_ctx->client_max_fds, &vfu_ctx->client_max_data_xfer_size, - &pgsize); + &pgsize, twin_socket_supportedp); if (ret < 0) { /* No client-supplied strings in the log for release build. */ @@ -312,8 +339,9 @@ json_add_uint64(struct json_object *jso, const char *key, uint64_t value) * be freed by the caller. */ static char * -format_server_capabilities(vfu_ctx_t *vfu_ctx) +format_server_capabilities(vfu_ctx_t *vfu_ctx, int twin_socket_fd_index) { + struct json_object *jo_twin_socket = NULL; struct json_object *jo_migration = NULL; struct json_object *jo_caps = NULL; struct json_object *jo_top = NULL; @@ -347,6 +375,25 @@ format_server_capabilities(vfu_ctx_t *vfu_ctx) } } + if (twin_socket_fd_index >= 0) { + struct json_object *jo_supported = NULL; + + if ((jo_twin_socket = json_object_new_object()) == NULL) { + goto out; + } + + if ((jo_supported = json_object_new_boolean(true)) == NULL || + json_add(jo_twin_socket, "supported", &jo_supported) < 0 || + json_add_uint64(jo_twin_socket, "fd_index", + twin_socket_fd_index) < 0) { + goto out; + } + + if (json_add(jo_caps, "twin_socket", &jo_twin_socket) < 0) { + goto out; + } + } + if ((jo_top = json_object_new_object()) == NULL || json_add(jo_top, "capabilities", &jo_caps) < 0) { goto out; @@ -355,6 +402,7 @@ format_server_capabilities(vfu_ctx_t *vfu_ctx) caps_str = strdup(json_object_to_json_string(jo_top)); out: + json_object_put(jo_twin_socket); json_object_put(jo_migration); json_object_put(jo_caps); json_object_put(jo_top); @@ -363,15 +411,17 @@ out: static int send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id, - struct vfio_user_version *cversion) + struct vfio_user_version *cversion, int client_cmd_socket_fd) { + int twin_socket_fd_index = client_cmd_socket_fd >= 0 ? 0 : -1; struct vfio_user_version sversion = { 0 }; struct iovec iovecs[2] = { { 0 } }; vfu_msg_t msg = { { 0 } }; char *server_caps = NULL; int ret; - if ((server_caps = format_server_capabilities(vfu_ctx)) == NULL) { + server_caps = format_server_capabilities(vfu_ctx, twin_socket_fd_index); + if (server_caps == NULL) { errno = ENOMEM; return -1; } @@ -391,6 +441,11 @@ send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id, msg.hdr.msg_id = msg_id; msg.out_iovecs = iovecs; msg.nr_out_iovecs = 2; + if (client_cmd_socket_fd >= 0) { + msg.out.fds = &client_cmd_socket_fd; + msg.out.nr_fds = 1; + assert(msg.out.fds[twin_socket_fd_index] == client_cmd_socket_fd); + } ret = vfu_ctx->tran->reply(vfu_ctx, &msg, 0); free(server_caps); @@ -398,25 +453,45 @@ send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id, } int -tran_negotiate(vfu_ctx_t *vfu_ctx) +tran_negotiate(vfu_ctx_t *vfu_ctx, int *client_cmd_socket_fdp) { struct vfio_user_version *client_version = NULL; + int client_cmd_socket_fds[2] = { -1, -1 }; + bool twin_socket_supported = false; uint16_t msg_id = 0x0bad; int ret; - ret = recv_version(vfu_ctx, &msg_id, &client_version); + ret = recv_version(vfu_ctx, &msg_id, &client_version, + &twin_socket_supported); if (ret < 0) { vfu_log(vfu_ctx, LOG_ERR, "failed to recv version: %m"); return ret; } - ret = send_version(vfu_ctx, msg_id, client_version); + if (twin_socket_supported && client_cmd_socket_fdp != NULL && + vfu_ctx->client_max_fds > 0) { + if (socketpair(AF_UNIX, SOCK_STREAM, 0, client_cmd_socket_fds) == -1) { + vfu_log(vfu_ctx, LOG_ERR, "failed to create cmd socket: %m"); + return -1; + } + } + + ret = send_version(vfu_ctx, msg_id, client_version, + client_cmd_socket_fds[0]); free(client_version); + /* + * The remote end of the client command socket pair is no longer needed. + * The local end is kept only if passed to the caller on successful return. + */ + close_safely(&client_cmd_socket_fds[0]); if (ret < 0) { vfu_log(vfu_ctx, LOG_ERR, "failed to send version: %m"); + close_safely(&client_cmd_socket_fds[1]); + } else if (client_cmd_socket_fdp != NULL) { + *client_cmd_socket_fdp = client_cmd_socket_fds[1]; } return ret; @@ -72,10 +72,11 @@ struct transport_ops { */ int tran_parse_version_json(const char *json_str, int *client_max_fdsp, - size_t *client_max_data_xfer_sizep, size_t *pgsizep); + size_t *client_max_data_xfer_sizep, size_t *pgsizep, + bool *twin_socket_supportedp); int -tran_negotiate(vfu_ctx_t *vfu_ctx); +tran_negotiate(vfu_ctx_t *vfu_ctx, int *client_cmd_socket_fdp); #endif /* LIB_VFIO_USER_TRAN_H */ diff --git a/lib/tran_pipe.c b/lib/tran_pipe.c index e7aa84d..8fb605c 100644 --- a/lib/tran_pipe.c +++ b/lib/tran_pipe.c @@ -285,7 +285,7 @@ tran_pipe_attach(vfu_ctx_t *vfu_ctx) tp->in_fd = STDIN_FILENO; tp->out_fd = STDOUT_FILENO; - ret = tran_negotiate(vfu_ctx); + ret = tran_negotiate(vfu_ctx, NULL); if (ret < 0) { ret = errno; tp->in_fd = -1; diff --git a/lib/tran_sock.c b/lib/tran_sock.c index 3f4c8c3..8a652c7 100644 --- a/lib/tran_sock.c +++ b/lib/tran_sock.c @@ -46,6 +46,7 @@ typedef struct { int listen_fd; int conn_fd; + int client_cmd_socket_fd; } tran_sock_t; int @@ -380,6 +381,7 @@ tran_sock_init(vfu_ctx_t *vfu_ctx) ts->listen_fd = -1; ts->conn_fd = -1; + ts->client_cmd_socket_fd = -1; if ((ts->listen_fd = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) { ret = errno; @@ -464,7 +466,7 @@ tran_sock_attach(vfu_ctx_t *vfu_ctx) return -1; } - ret = tran_negotiate(vfu_ctx); + ret = tran_negotiate(vfu_ctx, &ts->client_cmd_socket_fd); if (ret < 0) { close_safely(&ts->conn_fd); return -1; @@ -607,6 +609,21 @@ tran_sock_reply(vfu_ctx_t *vfu_ctx, vfu_msg_t *msg, int err) return ret; } +static void maybe_print_cmd_collision_warning(vfu_ctx_t *vfu_ctx) { + static bool warning_printed = false; + static const char *warning_msg = + "You are using libvfio-user in a configuration that issues " + "client-to-server commands, but without the twin_socket feature " + "enabled. This is known to break when client and server send a command " + "at the same time. See " + "https://github.com/nutanix/libvfio-user/issues/279 for details."; + + if (!warning_printed) { + vfu_log(vfu_ctx, LOG_WARNING, "%s", warning_msg); + warning_printed = true; + } +} + static int tran_sock_send_msg(vfu_ctx_t *vfu_ctx, uint16_t msg_id, enum vfio_user_command cmd, @@ -615,14 +632,21 @@ tran_sock_send_msg(vfu_ctx_t *vfu_ctx, uint16_t msg_id, void *recv_data, size_t recv_len) { tran_sock_t *ts; + int fd; assert(vfu_ctx != NULL); assert(vfu_ctx->tran_data != NULL); ts = vfu_ctx->tran_data; - return tran_sock_msg(ts->conn_fd, msg_id, cmd, send_data, send_len, - hdr, recv_data, recv_len); + fd = ts->client_cmd_socket_fd; + if (fd == -1) { + maybe_print_cmd_collision_warning(vfu_ctx); + fd = ts->conn_fd; + } + + return tran_sock_msg(fd, msg_id, cmd, send_data, send_len, hdr, recv_data, + recv_len); } static void @@ -636,6 +660,7 @@ tran_sock_detach(vfu_ctx_t *vfu_ctx) if (ts != NULL) { close_safely(&ts->conn_fd); + close_safely(&ts->client_cmd_socket_fd); } } diff --git a/samples/client.c b/samples/client.c index 0086fd6..ed66a30 100644 --- a/samples/client.c +++ b/samples/client.c @@ -197,7 +197,7 @@ recv_version(int sock, int *server_max_fds, size_t *server_max_data_xfer_size, } ret = tran_parse_version_json(json_str, server_max_fds, - server_max_data_xfer_size, pgsize); + server_max_data_xfer_size, pgsize, NULL); if (ret < 0) { err(EXIT_FAILURE, "failed to parse server JSON \"%s\"", json_str); diff --git a/test/py/libvfio_user.py b/test/py/libvfio_user.py index 86c8cbd..a701d1b 100644 --- a/test/py/libvfio_user.py +++ b/test/py/libvfio_user.py @@ -32,6 +32,7 @@ # from types import SimpleNamespace +import collections.abc import ctypes as c import array import errno @@ -484,6 +485,15 @@ class vfio_user_dma_unmap(Structure): ] +class vfio_user_dma_region_access(Structure): + """Payload for VFIO_USER_DMA_READ and VFIO_USER_DMA_WRITE.""" + _pack_ = 1 + _fields_ = [ + ("addr", c.c_uint64), + ("count", c.c_uint64), + ] + + class vfu_dma_info_t(Structure): _fields_ = [ ("iova", iovec_t), @@ -642,6 +652,10 @@ lib.vfu_sgl_get.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.POINTER(iovec_t), c.c_size_t, c.c_int) lib.vfu_sgl_put.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.POINTER(iovec_t), c.c_size_t) +lib.vfu_sgl_read.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t, + c.c_void_p) +lib.vfu_sgl_write.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t, + c.c_void_p) lib.vfu_create_ioeventfd.argtypes = (c.c_void_p, c.c_uint32, c.c_int, c.c_size_t, c.c_uint32, c.c_uint32, @@ -695,22 +709,52 @@ class Client: self.sock = sock self.client_cmd_socket = None - def connect(self, ctx): + def connect(self, ctx, capabilities={}): self.sock = connect_sock() - json = b'{ "capabilities": { "max_msg_fds": 8 } }' + client_caps = { + "capabilities": { + "max_data_xfer_size": VFIO_USER_DEFAULT_MAX_DATA_XFER_SIZE, + "max_msg_fds": 8, + }, + } + + def update(target, overrides): + for k, v in overrides.items(): + if isinstance(v, collections.abc.Mapping): + target[k] = target.get(k, {}) + update(target[k], v) + else: + target[k] = v + + update(client_caps, capabilities) + caps_json = json.dumps(client_caps) + # struct vfio_user_version - payload = struct.pack("HH%dsc" % len(json), LIBVFIO_USER_MAJOR, - LIBVFIO_USER_MINOR, json, b'\0') + payload = struct.pack("HH%dsc" % len(caps_json), LIBVFIO_USER_MAJOR, + LIBVFIO_USER_MINOR, caps_json.encode(), b'\0') hdr = vfio_user_header(VFIO_USER_VERSION, size=len(payload)) self.sock.send(hdr + payload) vfu_attach_ctx(ctx, expect=0) - payload = get_reply(self.sock, expect=0) + fds, payload = get_reply_fds(self.sock, expect=0) + + server_caps = json.loads(payload[struct.calcsize("HH"):-1].decode()) + try: + if (client_caps["capabilities"]["twin_socket"]["supported"] and + server_caps["capabilities"]["twin_socket"]["supported"]): + index = server_caps["capabilities"]["twin_socket"]["fd_index"] + self.client_cmd_socket = socket.socket(fileno=fds[index]) + except KeyError: + pass + return self.sock def disconnect(self, ctx): self.sock.close() self.sock = None + if self.client_cmd_socket is not None: + self.client_cmd_socket.close() + self.client_cmd_socket = None # notice client closed connection vfu_run_ctx(ctx, errno.ENOTCONN) @@ -1274,6 +1318,18 @@ def vfu_sgl_put(ctx, sg, iovec, cnt=1): return lib.vfu_sgl_put(ctx, sg, iovec, cnt) +def vfu_sgl_read(ctx, sg, cnt=1): + data = bytearray(sum([sge.length for sge in sg])) + buf = (c.c_byte * len(data)).from_buffer(data) + return lib.vfu_sgl_read(ctx, sg, cnt, buf), data + + +def vfu_sgl_write(ctx, sg, cnt=1, data=bytearray()): + assert len(data) == sum([sge.length for sge in sg]) + buf = (c.c_byte * len(data)).from_buffer(data) + return lib.vfu_sgl_write(ctx, sg, cnt, buf) + + def vfu_create_ioeventfd(ctx, region_idx, fd, gpa_offset, size, flags, datamatch, shadow_fd=-1, shadow_offset=0): assert ctx is not None diff --git a/test/py/meson.build b/test/py/meson.build index 0ea9f08..ecd2fe2 100644 --- a/test/py/meson.build +++ b/test/py/meson.build @@ -45,6 +45,7 @@ python_tests = [ 'test_request_errors.py', 'test_setup_region.py', 'test_sgl_get_put.py', + 'test_sgl_read_write.py', 'test_vfu_create_ctx.py', 'test_vfu_realize_ctx.py', ] diff --git a/test/py/test_sgl_read_write.py b/test/py/test_sgl_read_write.py new file mode 100644 index 0000000..2f4e992 --- /dev/null +++ b/test/py/test_sgl_read_write.py @@ -0,0 +1,192 @@ +# +# Copyright (c) 2023 Nutanix Inc. All rights reserved. +# Copyright (c) 2023 Rivos Inc. All rights reserved. +# +# Authors: Mattias Nissler <mnissler@rivosinc.com> +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Nutanix nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY +# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH +# DAMAGE. +# + +from libvfio_user import * +import select +import threading + +MAP_ADDR = 0x10000000 +MAP_SIZE = 16 << PAGE_SHIFT + +ctx = None +client = None + + +class DMARegionHandler: + """ + A helper to service DMA region accesses arriving over a socket. Accesses + are performed against an internal bytearray buffer. DMA request processing + takes place on a separate thread so as to not block the test code. + """ + + def __handle_requests(sock, pipe, buf, lock, addr, error_no): + while True: + (ready, _, _) = select.select([sock, pipe], [], []) + if pipe in ready: + break + + # Read a command from the socket and service it. + _, msg_id, cmd, payload = get_msg_fds(sock, + VFIO_USER_F_TYPE_COMMAND) + assert cmd in [VFIO_USER_DMA_READ, VFIO_USER_DMA_WRITE] + access, data = vfio_user_dma_region_access.pop_from_buffer(payload) + + assert access.addr >= addr + assert access.addr + access.count <= addr + len(buf) + + offset = access.addr - addr + with lock: + if cmd == VFIO_USER_DMA_READ: + data = buf[offset:offset + access.count] + else: + buf[offset:offset + access.count] = data + data = bytearray() + + send_msg(sock, + cmd, + VFIO_USER_F_TYPE_REPLY, + payload=payload[:c.sizeof(access)] + data, + msg_id=msg_id, + error_no=error_no) + + os.close(pipe) + sock.close() + + def __init__(self, sock, addr, size, error_no=0): + self.data = bytearray(size) + self.data_lock = threading.Lock() + self.addr = addr + (pipe_r, self.pipe_w) = os.pipe() + # Duplicate the socket file descriptor so the thread can own it and + # make sure it gets closed only when terminating the thread. + sock = socket.socket(fileno=os.dup(sock.fileno())) + thread = threading.Thread( + target=DMARegionHandler.__handle_requests, + args=[sock, pipe_r, self.data, self.data_lock, addr, error_no]) + thread.start() + + def shutdown(self): + # Closing the pipe's write end will signal the thread to terminate. + os.close(self.pipe_w) + + def read(self, addr, size): + offset = addr - self.addr + with self.data_lock: + return self.data[offset:offset + size] + + +def setup_function(function): + global ctx, client, dma_handler + ctx = prepare_ctx_for_dma() + assert ctx is not None + caps = { + "capabilities": { + "max_data_xfer_size": PAGE_SIZE, + "twin_socket": { + "supported": True, + }, + } + } + client = connect_client(ctx, caps) + assert client.client_cmd_socket is not None + + payload = vfio_user_dma_map(argsz=len(vfio_user_dma_map()), + flags=(VFIO_USER_F_DMA_REGION_READ + | VFIO_USER_F_DMA_REGION_WRITE), + offset=0, + addr=MAP_ADDR, + size=MAP_SIZE) + + msg(ctx, client.sock, VFIO_USER_DMA_MAP, payload) + + dma_handler = DMARegionHandler(client.client_cmd_socket, payload.addr, + payload.size) + + +def teardown_function(function): + dma_handler.shutdown() + client.disconnect(ctx) + vfu_destroy_ctx(ctx) + + +def test_dma_read_write(): + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=64, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + data = bytearray([x & 0xff for x in range(0, sg[0].length)]) + assert vfu_sgl_write(ctx, sg, 1, data) == 0 + + assert vfu_sgl_read(ctx, sg, 1) == (0, data) + + assert dma_handler.read(sg[0].dma_addr + sg[0].offset, + sg[0].length) == data + + +def test_dma_read_write_large(): + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=2 * PAGE_SIZE, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + data = bytearray([x & 0xff for x in range(0, sg[0].length)]) + assert vfu_sgl_write(ctx, sg, 1, data) == 0 + + assert vfu_sgl_read(ctx, sg, 1) == (0, data) + + assert dma_handler.read(sg[0].dma_addr + sg[0].offset, + sg[0].length) == data + + +def test_dma_read_write_error(): + # Reinitialize the handler to return EIO. + global dma_handler + dma_handler.shutdown() + dma_handler = DMARegionHandler(client.client_cmd_socket, MAP_ADDR, + MAP_SIZE, error_no=errno.EIO) + + ret, sg = vfu_addr_to_sgl(ctx, + dma_addr=MAP_ADDR + 0x1000, + length=64, + max_nr_sgs=1, + prot=mmap.PROT_READ | mmap.PROT_WRITE) + assert ret == 1 + + ret, _ = vfu_sgl_read(ctx, sg, 1) + assert ret == -1 + assert c.get_errno() == errno.EIO + + +# ex: set tabstop=4 shiftwidth=4 softtabstop=4 expandtab: # |