aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/libvfio-user.c17
-rw-r--r--test/py/test_dirty_pages.py12
2 files changed, 24 insertions, 5 deletions
diff --git a/lib/libvfio-user.c b/lib/libvfio-user.c
index b922c2e..b2ffac6 100644
--- a/lib/libvfio-user.c
+++ b/lib/libvfio-user.c
@@ -893,8 +893,9 @@ handle_dirty_pages_get(vfu_ctx_t *vfu_ctx, vfu_msg_t *msg)
dirty_pages_in = msg->in_data;
- if (msg->in_size < sizeof(*dirty_pages_in) + sizeof(*range_in)
- || dirty_pages_in->argsz < sizeof(*dirty_pages_out)) {
+ if (msg->in_size < sizeof(*dirty_pages_in) + sizeof(*range_in) ||
+ dirty_pages_in->argsz > SERVER_MAX_DATA_XFER_SIZE ||
+ dirty_pages_in->argsz < sizeof(*dirty_pages_out)) {
vfu_log(vfu_ctx, LOG_ERR, "invalid message size=%zu argsz=%u",
msg->in_size, dirty_pages_in->argsz);
return ERROR_INT(EINVAL);
@@ -902,9 +903,15 @@ handle_dirty_pages_get(vfu_ctx_t *vfu_ctx, vfu_msg_t *msg)
range_in = msg->in_data + sizeof(*dirty_pages_in);
- /* NB: this is bound by MAX_DMA_SIZE. */
- argsz = sizeof(*dirty_pages_out) + sizeof(*range_out) +
- range_in->bitmap.size;
+ /*
+ * range_in is client-controlled, but we only need to protect against
+ * overflow here: we'll take MIN() against a validated value next, and
+ * dma_controller_dirty_page_get() will validate the actual ->bitmap.size
+ * value later, anyway.
+ */
+ argsz = satadd_u64(sizeof(*dirty_pages_out) + sizeof(*range_out),
+ range_in->bitmap.size);
+
msg->out_size = MIN(dirty_pages_in->argsz, argsz);
msg->out_data = malloc(msg->out_size);
if (msg->out_data == NULL) {
diff --git a/test/py/test_dirty_pages.py b/test/py/test_dirty_pages.py
index a5b85dc..9baf6cd 100644
--- a/test/py/test_dirty_pages.py
+++ b/test/py/test_dirty_pages.py
@@ -182,6 +182,18 @@ def test_dirty_pages_get_bad_bitmap_size():
msg(ctx, sock, VFIO_USER_DIRTY_PAGES, payload, expect=errno.EINVAL)
+def test_dirty_pages_get_bad_argsz():
+ dirty_pages = vfio_user_dirty_pages(argsz=SERVER_MAX_DATA_XFER_SIZE + 8,
+ flags=VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP)
+ bitmap = vfio_user_bitmap(pgsize=0x1000,
+ size=SERVER_MAX_DATA_XFER_SIZE + 8)
+ br = vfio_user_bitmap_range(iova=0x10000, size=0x10000, bitmap=bitmap)
+
+ payload = bytes(dirty_pages) + bytes(br)
+
+ msg(ctx, sock, VFIO_USER_DIRTY_PAGES, payload, expect=errno.EINVAL)
+
+
def test_dirty_pages_get_short_reply():
dirty_pages = vfio_user_dirty_pages(argsz=len(vfio_user_dirty_pages()),
flags=VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP)