aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMattias Nissler <122288598+mnissler-rivos@users.noreply.github.com>2023-09-15 12:33:37 +0200
committerGitHub <noreply@github.com>2023-09-15 11:33:37 +0100
commit1569a37a54ecb63bd4008708c76339ccf7d06115 (patch)
tree249f6e0bd2dd5b8620bc55637e60c8834d5cf630
parent478ddb5f87ea257c8682c5288761606d5fa216ad (diff)
downloadlibvfio-user-1569a37a54ecb63bd4008708c76339ccf7d06115.zip
libvfio-user-1569a37a54ecb63bd4008708c76339ccf7d06115.tar.gz
libvfio-user-1569a37a54ecb63bd4008708c76339ccf7d06115.tar.bz2
Pass server->client command over a separate socket pair (#762)
Use separate socket for server->client commands This change adds support for a separate socket to carry commands in the server-to-client direction. It has proven problematic to send commands in both directions over a single socket, since matching replies to commands can become non-trivial when both sides send commands at the same time and adds significant complexity. See issue #279 for details. To set up the reverse communication channel, the client indicates support for it via a new capability flag in the version message. The server will then create a fresh pair of sockets and pass one end to the client in its version reply. When the server wishes to send commands to the client at a later point, it now uses its end of the new socket pair rather than the main socket. Corresponding replies are also passed back over the new socket pair. Signed-off-by: Mattias Nissler <mnissler@rivosinc.com>
-rw-r--r--include/libvfio-user.h2
-rw-r--r--lib/tran.c95
-rw-r--r--lib/tran.h5
-rw-r--r--lib/tran_pipe.c2
-rw-r--r--lib/tran_sock.c31
-rw-r--r--samples/client.c2
-rw-r--r--test/py/libvfio_user.py66
-rw-r--r--test/py/meson.build1
-rw-r--r--test/py/test_sgl_read_write.py192
9 files changed, 373 insertions, 23 deletions
diff --git a/include/libvfio-user.h b/include/libvfio-user.h
index 72369b6..21cb99a 100644
--- a/include/libvfio-user.h
+++ b/include/libvfio-user.h
@@ -61,7 +61,7 @@ extern "C" {
#endif
#define LIB_VFIO_USER_MAJOR 0
-#define LIB_VFIO_USER_MINOR 1
+#define LIB_VFIO_USER_MINOR 2
/* DMA addresses cannot be directly de-referenced. */
typedef void *vfu_dma_addr_t;
diff --git a/lib/tran.c b/lib/tran.c
index 3c8b25a..46f5874 100644
--- a/lib/tran.c
+++ b/lib/tran.c
@@ -36,6 +36,7 @@
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
+#include <sys/socket.h>
#include <json.h>
@@ -52,9 +53,13 @@
* {
* "capabilities": {
* "max_msg_fds": 32,
- * "max_data_xfer_size": 1048576
+ * "max_data_xfer_size": 1048576,
* "migration": {
* "pgsize": 4096
+ * },
+ * "twin_socket": {
+ * "supported": true,
+ * "fd_index": 0
* }
* }
* }
@@ -64,7 +69,8 @@
*/
int
tran_parse_version_json(const char *json_str, int *client_max_fdsp,
- size_t *client_max_data_xfer_sizep, size_t *pgsizep)
+ size_t *client_max_data_xfer_sizep, size_t *pgsizep,
+ bool *twin_socket_supportedp)
{
struct json_object *jo_caps = NULL;
struct json_object *jo_top = NULL;
@@ -130,6 +136,27 @@ tran_parse_version_json(const char *json_str, int *client_max_fdsp,
}
}
+ if (json_object_object_get_ex(jo_caps, "twin_socket", &jo)) {
+ struct json_object *jo2 = NULL;
+
+ if (json_object_get_type(jo) != json_type_object) {
+ goto out;
+ }
+
+ if (json_object_object_get_ex(jo, "supported", &jo2)) {
+ if (json_object_get_type(jo2) != json_type_boolean) {
+ goto out;
+ }
+
+ errno = 0;
+ *twin_socket_supportedp = json_object_get_boolean(jo2);
+
+ if (errno != 0) {
+ goto out;
+ }
+ }
+ }
+
ret = 0;
out:
@@ -143,7 +170,7 @@ out:
static int
recv_version(vfu_ctx_t *vfu_ctx, uint16_t *msg_idp,
- struct vfio_user_version **versionp)
+ struct vfio_user_version **versionp, bool *twin_socket_supportedp)
{
struct vfio_user_version *cversion = NULL;
vfu_msg_t msg = { { 0 } };
@@ -208,7 +235,7 @@ recv_version(vfu_ctx_t *vfu_ctx, uint16_t *msg_idp,
ret = tran_parse_version_json(json_str, &vfu_ctx->client_max_fds,
&vfu_ctx->client_max_data_xfer_size,
- &pgsize);
+ &pgsize, twin_socket_supportedp);
if (ret < 0) {
/* No client-supplied strings in the log for release build. */
@@ -312,8 +339,9 @@ json_add_uint64(struct json_object *jso, const char *key, uint64_t value)
* be freed by the caller.
*/
static char *
-format_server_capabilities(vfu_ctx_t *vfu_ctx)
+format_server_capabilities(vfu_ctx_t *vfu_ctx, int twin_socket_fd_index)
{
+ struct json_object *jo_twin_socket = NULL;
struct json_object *jo_migration = NULL;
struct json_object *jo_caps = NULL;
struct json_object *jo_top = NULL;
@@ -347,6 +375,25 @@ format_server_capabilities(vfu_ctx_t *vfu_ctx)
}
}
+ if (twin_socket_fd_index >= 0) {
+ struct json_object *jo_supported = NULL;
+
+ if ((jo_twin_socket = json_object_new_object()) == NULL) {
+ goto out;
+ }
+
+ if ((jo_supported = json_object_new_boolean(true)) == NULL ||
+ json_add(jo_twin_socket, "supported", &jo_supported) < 0 ||
+ json_add_uint64(jo_twin_socket, "fd_index",
+ twin_socket_fd_index) < 0) {
+ goto out;
+ }
+
+ if (json_add(jo_caps, "twin_socket", &jo_twin_socket) < 0) {
+ goto out;
+ }
+ }
+
if ((jo_top = json_object_new_object()) == NULL ||
json_add(jo_top, "capabilities", &jo_caps) < 0) {
goto out;
@@ -355,6 +402,7 @@ format_server_capabilities(vfu_ctx_t *vfu_ctx)
caps_str = strdup(json_object_to_json_string(jo_top));
out:
+ json_object_put(jo_twin_socket);
json_object_put(jo_migration);
json_object_put(jo_caps);
json_object_put(jo_top);
@@ -363,15 +411,17 @@ out:
static int
send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
- struct vfio_user_version *cversion)
+ struct vfio_user_version *cversion, int client_cmd_socket_fd)
{
+ int twin_socket_fd_index = client_cmd_socket_fd >= 0 ? 0 : -1;
struct vfio_user_version sversion = { 0 };
struct iovec iovecs[2] = { { 0 } };
vfu_msg_t msg = { { 0 } };
char *server_caps = NULL;
int ret;
- if ((server_caps = format_server_capabilities(vfu_ctx)) == NULL) {
+ server_caps = format_server_capabilities(vfu_ctx, twin_socket_fd_index);
+ if (server_caps == NULL) {
errno = ENOMEM;
return -1;
}
@@ -391,6 +441,11 @@ send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
msg.hdr.msg_id = msg_id;
msg.out_iovecs = iovecs;
msg.nr_out_iovecs = 2;
+ if (client_cmd_socket_fd >= 0) {
+ msg.out.fds = &client_cmd_socket_fd;
+ msg.out.nr_fds = 1;
+ assert(msg.out.fds[twin_socket_fd_index] == client_cmd_socket_fd);
+ }
ret = vfu_ctx->tran->reply(vfu_ctx, &msg, 0);
free(server_caps);
@@ -398,25 +453,45 @@ send_version(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
}
int
-tran_negotiate(vfu_ctx_t *vfu_ctx)
+tran_negotiate(vfu_ctx_t *vfu_ctx, int *client_cmd_socket_fdp)
{
struct vfio_user_version *client_version = NULL;
+ int client_cmd_socket_fds[2] = { -1, -1 };
+ bool twin_socket_supported = false;
uint16_t msg_id = 0x0bad;
int ret;
- ret = recv_version(vfu_ctx, &msg_id, &client_version);
+ ret = recv_version(vfu_ctx, &msg_id, &client_version,
+ &twin_socket_supported);
if (ret < 0) {
vfu_log(vfu_ctx, LOG_ERR, "failed to recv version: %m");
return ret;
}
- ret = send_version(vfu_ctx, msg_id, client_version);
+ if (twin_socket_supported && client_cmd_socket_fdp != NULL &&
+ vfu_ctx->client_max_fds > 0) {
+ if (socketpair(AF_UNIX, SOCK_STREAM, 0, client_cmd_socket_fds) == -1) {
+ vfu_log(vfu_ctx, LOG_ERR, "failed to create cmd socket: %m");
+ return -1;
+ }
+ }
+
+ ret = send_version(vfu_ctx, msg_id, client_version,
+ client_cmd_socket_fds[0]);
free(client_version);
+ /*
+ * The remote end of the client command socket pair is no longer needed.
+ * The local end is kept only if passed to the caller on successful return.
+ */
+ close_safely(&client_cmd_socket_fds[0]);
if (ret < 0) {
vfu_log(vfu_ctx, LOG_ERR, "failed to send version: %m");
+ close_safely(&client_cmd_socket_fds[1]);
+ } else if (client_cmd_socket_fdp != NULL) {
+ *client_cmd_socket_fdp = client_cmd_socket_fds[1];
}
return ret;
diff --git a/lib/tran.h b/lib/tran.h
index fee96e8..9ad1203 100644
--- a/lib/tran.h
+++ b/lib/tran.h
@@ -72,10 +72,11 @@ struct transport_ops {
*/
int
tran_parse_version_json(const char *json_str, int *client_max_fdsp,
- size_t *client_max_data_xfer_sizep, size_t *pgsizep);
+ size_t *client_max_data_xfer_sizep, size_t *pgsizep,
+ bool *twin_socket_supportedp);
int
-tran_negotiate(vfu_ctx_t *vfu_ctx);
+tran_negotiate(vfu_ctx_t *vfu_ctx, int *client_cmd_socket_fdp);
#endif /* LIB_VFIO_USER_TRAN_H */
diff --git a/lib/tran_pipe.c b/lib/tran_pipe.c
index e7aa84d..8fb605c 100644
--- a/lib/tran_pipe.c
+++ b/lib/tran_pipe.c
@@ -285,7 +285,7 @@ tran_pipe_attach(vfu_ctx_t *vfu_ctx)
tp->in_fd = STDIN_FILENO;
tp->out_fd = STDOUT_FILENO;
- ret = tran_negotiate(vfu_ctx);
+ ret = tran_negotiate(vfu_ctx, NULL);
if (ret < 0) {
ret = errno;
tp->in_fd = -1;
diff --git a/lib/tran_sock.c b/lib/tran_sock.c
index 3f4c8c3..8a652c7 100644
--- a/lib/tran_sock.c
+++ b/lib/tran_sock.c
@@ -46,6 +46,7 @@
typedef struct {
int listen_fd;
int conn_fd;
+ int client_cmd_socket_fd;
} tran_sock_t;
int
@@ -380,6 +381,7 @@ tran_sock_init(vfu_ctx_t *vfu_ctx)
ts->listen_fd = -1;
ts->conn_fd = -1;
+ ts->client_cmd_socket_fd = -1;
if ((ts->listen_fd = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) {
ret = errno;
@@ -464,7 +466,7 @@ tran_sock_attach(vfu_ctx_t *vfu_ctx)
return -1;
}
- ret = tran_negotiate(vfu_ctx);
+ ret = tran_negotiate(vfu_ctx, &ts->client_cmd_socket_fd);
if (ret < 0) {
close_safely(&ts->conn_fd);
return -1;
@@ -607,6 +609,21 @@ tran_sock_reply(vfu_ctx_t *vfu_ctx, vfu_msg_t *msg, int err)
return ret;
}
+static void maybe_print_cmd_collision_warning(vfu_ctx_t *vfu_ctx) {
+ static bool warning_printed = false;
+ static const char *warning_msg =
+ "You are using libvfio-user in a configuration that issues "
+ "client-to-server commands, but without the twin_socket feature "
+ "enabled. This is known to break when client and server send a command "
+ "at the same time. See "
+ "https://github.com/nutanix/libvfio-user/issues/279 for details.";
+
+ if (!warning_printed) {
+ vfu_log(vfu_ctx, LOG_WARNING, "%s", warning_msg);
+ warning_printed = true;
+ }
+}
+
static int
tran_sock_send_msg(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
enum vfio_user_command cmd,
@@ -615,14 +632,21 @@ tran_sock_send_msg(vfu_ctx_t *vfu_ctx, uint16_t msg_id,
void *recv_data, size_t recv_len)
{
tran_sock_t *ts;
+ int fd;
assert(vfu_ctx != NULL);
assert(vfu_ctx->tran_data != NULL);
ts = vfu_ctx->tran_data;
- return tran_sock_msg(ts->conn_fd, msg_id, cmd, send_data, send_len,
- hdr, recv_data, recv_len);
+ fd = ts->client_cmd_socket_fd;
+ if (fd == -1) {
+ maybe_print_cmd_collision_warning(vfu_ctx);
+ fd = ts->conn_fd;
+ }
+
+ return tran_sock_msg(fd, msg_id, cmd, send_data, send_len, hdr, recv_data,
+ recv_len);
}
static void
@@ -636,6 +660,7 @@ tran_sock_detach(vfu_ctx_t *vfu_ctx)
if (ts != NULL) {
close_safely(&ts->conn_fd);
+ close_safely(&ts->client_cmd_socket_fd);
}
}
diff --git a/samples/client.c b/samples/client.c
index 0086fd6..ed66a30 100644
--- a/samples/client.c
+++ b/samples/client.c
@@ -197,7 +197,7 @@ recv_version(int sock, int *server_max_fds, size_t *server_max_data_xfer_size,
}
ret = tran_parse_version_json(json_str, server_max_fds,
- server_max_data_xfer_size, pgsize);
+ server_max_data_xfer_size, pgsize, NULL);
if (ret < 0) {
err(EXIT_FAILURE, "failed to parse server JSON \"%s\"", json_str);
diff --git a/test/py/libvfio_user.py b/test/py/libvfio_user.py
index 86c8cbd..a701d1b 100644
--- a/test/py/libvfio_user.py
+++ b/test/py/libvfio_user.py
@@ -32,6 +32,7 @@
#
from types import SimpleNamespace
+import collections.abc
import ctypes as c
import array
import errno
@@ -484,6 +485,15 @@ class vfio_user_dma_unmap(Structure):
]
+class vfio_user_dma_region_access(Structure):
+ """Payload for VFIO_USER_DMA_READ and VFIO_USER_DMA_WRITE."""
+ _pack_ = 1
+ _fields_ = [
+ ("addr", c.c_uint64),
+ ("count", c.c_uint64),
+ ]
+
+
class vfu_dma_info_t(Structure):
_fields_ = [
("iova", iovec_t),
@@ -642,6 +652,10 @@ lib.vfu_sgl_get.argtypes = (c.c_void_p, c.POINTER(dma_sg_t),
c.POINTER(iovec_t), c.c_size_t, c.c_int)
lib.vfu_sgl_put.argtypes = (c.c_void_p, c.POINTER(dma_sg_t),
c.POINTER(iovec_t), c.c_size_t)
+lib.vfu_sgl_read.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t,
+ c.c_void_p)
+lib.vfu_sgl_write.argtypes = (c.c_void_p, c.POINTER(dma_sg_t), c.c_size_t,
+ c.c_void_p)
lib.vfu_create_ioeventfd.argtypes = (c.c_void_p, c.c_uint32, c.c_int,
c.c_size_t, c.c_uint32, c.c_uint32,
@@ -695,22 +709,52 @@ class Client:
self.sock = sock
self.client_cmd_socket = None
- def connect(self, ctx):
+ def connect(self, ctx, capabilities={}):
self.sock = connect_sock()
- json = b'{ "capabilities": { "max_msg_fds": 8 } }'
+ client_caps = {
+ "capabilities": {
+ "max_data_xfer_size": VFIO_USER_DEFAULT_MAX_DATA_XFER_SIZE,
+ "max_msg_fds": 8,
+ },
+ }
+
+ def update(target, overrides):
+ for k, v in overrides.items():
+ if isinstance(v, collections.abc.Mapping):
+ target[k] = target.get(k, {})
+ update(target[k], v)
+ else:
+ target[k] = v
+
+ update(client_caps, capabilities)
+ caps_json = json.dumps(client_caps)
+
# struct vfio_user_version
- payload = struct.pack("HH%dsc" % len(json), LIBVFIO_USER_MAJOR,
- LIBVFIO_USER_MINOR, json, b'\0')
+ payload = struct.pack("HH%dsc" % len(caps_json), LIBVFIO_USER_MAJOR,
+ LIBVFIO_USER_MINOR, caps_json.encode(), b'\0')
hdr = vfio_user_header(VFIO_USER_VERSION, size=len(payload))
self.sock.send(hdr + payload)
vfu_attach_ctx(ctx, expect=0)
- payload = get_reply(self.sock, expect=0)
+ fds, payload = get_reply_fds(self.sock, expect=0)
+
+ server_caps = json.loads(payload[struct.calcsize("HH"):-1].decode())
+ try:
+ if (client_caps["capabilities"]["twin_socket"]["supported"] and
+ server_caps["capabilities"]["twin_socket"]["supported"]):
+ index = server_caps["capabilities"]["twin_socket"]["fd_index"]
+ self.client_cmd_socket = socket.socket(fileno=fds[index])
+ except KeyError:
+ pass
+
return self.sock
def disconnect(self, ctx):
self.sock.close()
self.sock = None
+ if self.client_cmd_socket is not None:
+ self.client_cmd_socket.close()
+ self.client_cmd_socket = None
# notice client closed connection
vfu_run_ctx(ctx, errno.ENOTCONN)
@@ -1274,6 +1318,18 @@ def vfu_sgl_put(ctx, sg, iovec, cnt=1):
return lib.vfu_sgl_put(ctx, sg, iovec, cnt)
+def vfu_sgl_read(ctx, sg, cnt=1):
+ data = bytearray(sum([sge.length for sge in sg]))
+ buf = (c.c_byte * len(data)).from_buffer(data)
+ return lib.vfu_sgl_read(ctx, sg, cnt, buf), data
+
+
+def vfu_sgl_write(ctx, sg, cnt=1, data=bytearray()):
+ assert len(data) == sum([sge.length for sge in sg])
+ buf = (c.c_byte * len(data)).from_buffer(data)
+ return lib.vfu_sgl_write(ctx, sg, cnt, buf)
+
+
def vfu_create_ioeventfd(ctx, region_idx, fd, gpa_offset, size, flags,
datamatch, shadow_fd=-1, shadow_offset=0):
assert ctx is not None
diff --git a/test/py/meson.build b/test/py/meson.build
index 0ea9f08..ecd2fe2 100644
--- a/test/py/meson.build
+++ b/test/py/meson.build
@@ -45,6 +45,7 @@ python_tests = [
'test_request_errors.py',
'test_setup_region.py',
'test_sgl_get_put.py',
+ 'test_sgl_read_write.py',
'test_vfu_create_ctx.py',
'test_vfu_realize_ctx.py',
]
diff --git a/test/py/test_sgl_read_write.py b/test/py/test_sgl_read_write.py
new file mode 100644
index 0000000..2f4e992
--- /dev/null
+++ b/test/py/test_sgl_read_write.py
@@ -0,0 +1,192 @@
+#
+# Copyright (c) 2023 Nutanix Inc. All rights reserved.
+# Copyright (c) 2023 Rivos Inc. All rights reserved.
+#
+# Authors: Mattias Nissler <mnissler@rivosinc.com>
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of Nutanix nor the names of its contributors may be
+# used to endorse or promote products derived from this software without
+# specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL <COPYRIGHT HOLDER> BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+# DAMAGE.
+#
+
+from libvfio_user import *
+import select
+import threading
+
+MAP_ADDR = 0x10000000
+MAP_SIZE = 16 << PAGE_SHIFT
+
+ctx = None
+client = None
+
+
+class DMARegionHandler:
+ """
+ A helper to service DMA region accesses arriving over a socket. Accesses
+ are performed against an internal bytearray buffer. DMA request processing
+ takes place on a separate thread so as to not block the test code.
+ """
+
+ def __handle_requests(sock, pipe, buf, lock, addr, error_no):
+ while True:
+ (ready, _, _) = select.select([sock, pipe], [], [])
+ if pipe in ready:
+ break
+
+ # Read a command from the socket and service it.
+ _, msg_id, cmd, payload = get_msg_fds(sock,
+ VFIO_USER_F_TYPE_COMMAND)
+ assert cmd in [VFIO_USER_DMA_READ, VFIO_USER_DMA_WRITE]
+ access, data = vfio_user_dma_region_access.pop_from_buffer(payload)
+
+ assert access.addr >= addr
+ assert access.addr + access.count <= addr + len(buf)
+
+ offset = access.addr - addr
+ with lock:
+ if cmd == VFIO_USER_DMA_READ:
+ data = buf[offset:offset + access.count]
+ else:
+ buf[offset:offset + access.count] = data
+ data = bytearray()
+
+ send_msg(sock,
+ cmd,
+ VFIO_USER_F_TYPE_REPLY,
+ payload=payload[:c.sizeof(access)] + data,
+ msg_id=msg_id,
+ error_no=error_no)
+
+ os.close(pipe)
+ sock.close()
+
+ def __init__(self, sock, addr, size, error_no=0):
+ self.data = bytearray(size)
+ self.data_lock = threading.Lock()
+ self.addr = addr
+ (pipe_r, self.pipe_w) = os.pipe()
+ # Duplicate the socket file descriptor so the thread can own it and
+ # make sure it gets closed only when terminating the thread.
+ sock = socket.socket(fileno=os.dup(sock.fileno()))
+ thread = threading.Thread(
+ target=DMARegionHandler.__handle_requests,
+ args=[sock, pipe_r, self.data, self.data_lock, addr, error_no])
+ thread.start()
+
+ def shutdown(self):
+ # Closing the pipe's write end will signal the thread to terminate.
+ os.close(self.pipe_w)
+
+ def read(self, addr, size):
+ offset = addr - self.addr
+ with self.data_lock:
+ return self.data[offset:offset + size]
+
+
+def setup_function(function):
+ global ctx, client, dma_handler
+ ctx = prepare_ctx_for_dma()
+ assert ctx is not None
+ caps = {
+ "capabilities": {
+ "max_data_xfer_size": PAGE_SIZE,
+ "twin_socket": {
+ "supported": True,
+ },
+ }
+ }
+ client = connect_client(ctx, caps)
+ assert client.client_cmd_socket is not None
+
+ payload = vfio_user_dma_map(argsz=len(vfio_user_dma_map()),
+ flags=(VFIO_USER_F_DMA_REGION_READ
+ | VFIO_USER_F_DMA_REGION_WRITE),
+ offset=0,
+ addr=MAP_ADDR,
+ size=MAP_SIZE)
+
+ msg(ctx, client.sock, VFIO_USER_DMA_MAP, payload)
+
+ dma_handler = DMARegionHandler(client.client_cmd_socket, payload.addr,
+ payload.size)
+
+
+def teardown_function(function):
+ dma_handler.shutdown()
+ client.disconnect(ctx)
+ vfu_destroy_ctx(ctx)
+
+
+def test_dma_read_write():
+ ret, sg = vfu_addr_to_sgl(ctx,
+ dma_addr=MAP_ADDR + 0x1000,
+ length=64,
+ max_nr_sgs=1,
+ prot=mmap.PROT_READ | mmap.PROT_WRITE)
+ assert ret == 1
+
+ data = bytearray([x & 0xff for x in range(0, sg[0].length)])
+ assert vfu_sgl_write(ctx, sg, 1, data) == 0
+
+ assert vfu_sgl_read(ctx, sg, 1) == (0, data)
+
+ assert dma_handler.read(sg[0].dma_addr + sg[0].offset,
+ sg[0].length) == data
+
+
+def test_dma_read_write_large():
+ ret, sg = vfu_addr_to_sgl(ctx,
+ dma_addr=MAP_ADDR + 0x1000,
+ length=2 * PAGE_SIZE,
+ max_nr_sgs=1,
+ prot=mmap.PROT_READ | mmap.PROT_WRITE)
+ assert ret == 1
+
+ data = bytearray([x & 0xff for x in range(0, sg[0].length)])
+ assert vfu_sgl_write(ctx, sg, 1, data) == 0
+
+ assert vfu_sgl_read(ctx, sg, 1) == (0, data)
+
+ assert dma_handler.read(sg[0].dma_addr + sg[0].offset,
+ sg[0].length) == data
+
+
+def test_dma_read_write_error():
+ # Reinitialize the handler to return EIO.
+ global dma_handler
+ dma_handler.shutdown()
+ dma_handler = DMARegionHandler(client.client_cmd_socket, MAP_ADDR,
+ MAP_SIZE, error_no=errno.EIO)
+
+ ret, sg = vfu_addr_to_sgl(ctx,
+ dma_addr=MAP_ADDR + 0x1000,
+ length=64,
+ max_nr_sgs=1,
+ prot=mmap.PROT_READ | mmap.PROT_WRITE)
+ assert ret == 1
+
+ ret, _ = vfu_sgl_read(ctx, sg, 1)
+ assert ret == -1
+ assert c.get_errno() == errno.EIO
+
+
+# ex: set tabstop=4 shiftwidth=4 softtabstop=4 expandtab: #