www.infradead.org Git - users/dwmw2/openconnect.git/commitdiff
Support vhost on more than just x86_64
author David Woodhouse <dwmw2@infradead.org>
Mon, 21 Feb 2022 12:02:26 +0000 (12:02 +0000)
committer David Woodhouse <dwmw2@infradead.org>
Mon, 21 Feb 2022 15:01:18 +0000 (15:01 +0000)
The reason we only supported x86_64 is that the vhost interface is
kind of awful and requires that we *know* what TASK_SIZE is.

If we pass in a range with the VHOST_SET_MEM_TABLE ioctl that goes
higher than TASK_SIZE, it doesn't fail immediately, which would at
least be helpful. Instead it succeeds, but then *all* buffer accesses
fail, even the ones at valid addresses, because our mapping *could*
have pointed to an invalid address for which !access_ok().
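
For reference, a minimal sketch of how such a single 1:1 region is
handed to the kernel. This is not the patch itself; set_identity_map(),
vhost_fd and task_size are placeholder names used only for illustration:

    #include <errno.h>
    #include <stdlib.h>
    #include <unistd.h>
    #include <sys/ioctl.h>
    #include <linux/vhost.h>

    /* Sketch: one region mapping guest "physical" addresses 1:1 onto our
     * own virtual addresses, starting one page up so that NULL is never
     * a valid buffer address. */
    static int set_identity_map(int vhost_fd, unsigned long task_size)
    {
            const unsigned long page_size = getpagesize();
            struct vhost_memory *vmem;
            int ret;

            vmem = calloc(1, sizeof(*vmem) + sizeof(vmem->regions[0]));
            if (!vmem)
                    return -ENOMEM;

            vmem->nregions = 1;
            vmem->regions[0].guest_phys_addr = page_size;
            vmem->regions[0].userspace_addr  = page_size;
            vmem->regions[0].memory_size     = task_size - page_size;

            /* This succeeds even if the region reaches past TASK_SIZE;
             * the failure only shows up later, on every buffer access. */
            ret = ioctl(vhost_fd, VHOST_SET_MEM_TABLE, vmem);
            free(vmem);
            return ret;
    }

The whole problem is knowing what value of task_size to pass.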

So... here's a really awful hack that attempts to determine TASK_SIZE
by calling mmap() at different addresses. For addresses which are
invalid, mmap() returns -ENOMEM, but for addresses which are valid it
returns -EINVAL, because we deliberately don't specify either
MAP_PRIVATE or MAP_SHARED as we should. This is better than the criu
version, which tries to *unmap*, and in doing so may actually unmap
something that was supposed to be there!
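
The probe, shown standalone below, is a sketch relying on the ordering
described above (the kernel rejects an out-of-range MAP_FIXED address
with ENOMEM before it gets around to rejecting the missing mapping type
with EINVAL); addr_is_mappable() and the test program are invented here
for illustration, not part of the patch:

    #include <errno.h>
    #include <stdio.h>
    #include <unistd.h>
    #include <sys/mman.h>

    /* Returns 1 if 'addr' lies within the task's mappable range. */
    static int addr_is_mappable(unsigned long addr)
    {
            /* Deliberately omit MAP_PRIVATE/MAP_SHARED so nothing is ever
             * really mapped; a valid address still "fails", but with
             * EINVAL instead of ENOMEM. */
            void *res = mmap((void *)addr, getpagesize(), PROT_NONE,
                             MAP_FIXED | MAP_ANONYMOUS, -1, 0);
            if (res != MAP_FAILED) {
                    /* Shouldn't happen without a mapping type, but be safe. */
                    munmap(res, getpagesize());
                    return 1;
            }
            return errno != ENOMEM;
    }

    int main(void)
    {
            unsigned long page_size = getpagesize();
            int probe;
            unsigned long stack_page = (unsigned long)&probe & ~(page_size - 1);

            printf("0x%lx (stack page): %d\n", stack_page,
                   addr_is_mappable(stack_page));
            printf("0x%lx (top page):   %d\n", -page_size,
                   addr_is_mappable(-page_size));
            return 0;
    }

Because the probe fails before any mapping is created, it never
clobbers existing mappings the way an unmap-based probe can.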

Fixes: #383
Signed-off-by: David Woodhouse <dwmw2@infradead.org>
configure.ac
vhost.c

diff --git a/configure.ac b/configure.ac
index 5b233f3c45c2f20c86e74ba617f0dad4cff5c318..5ac966f51f829c4924e05a6fa7fdd210c326dab7 100644
--- a/configure.ac
+++ b/configure.ac
@@ -38,11 +38,7 @@ AC_PROG_CC_C99
 have_vhost=no
 case $host_os in
  *linux* | *gnu* | *nacl*)
-    case $host_cpu in
-       x86_64|amd64)
-           have_vhost=yes
-           ;;
-    esac
+    have_vhost=yes
     AC_MSG_NOTICE([Applying feature macros for GNU build])
     AC_DEFINE(_GNU_SOURCE, 1, [_GNU_SOURCE])
     ;;
diff --git a/vhost.c b/vhost.c
index e95ea6cbcb8fca46e03b58dd2309e47489548508..c4f0c6ddcc5bec444f1d4bf2b5fe35855bd4820b 100644
--- a/vhost.c
+++ b/vhost.c
@@ -27,6 +27,7 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 #include <sys/wait.h>
+#include <sys/mman.h>
 
 #include <ctype.h>
 #include <errno.h>
@@ -89,9 +90,9 @@ static int setup_vring(struct openconnect_info *vpninfo, int idx)
 
        struct vhost_vring_addr va = { };
        va.index = idx;
-       va.desc_user_addr = (uint64_t)vring->desc;
-       va.avail_user_addr = (uint64_t)vring->avail;
-       va.used_user_addr  = (uint64_t)vring->used;
+       va.desc_user_addr = (unsigned long)vring->desc;
+       va.avail_user_addr = (unsigned long)vring->avail;
+       va.used_user_addr  = (unsigned long)vring->used;
        if (ioctl(vpninfo->vhost_fd, VHOST_SET_VRING_ADDR, &va) < 0) {
                ret = -errno;
                vpn_progress(vpninfo, PRG_ERR, _("Failed to set vring #%d base: %s\n"),
@@ -128,6 +129,113 @@ static int setup_vring(struct openconnect_info *vpninfo, int idx)
        return 0;
 }
 
+/*
+ * This is awful. The kernel doesn't let us just ask for a 1:1 mapping of
+ * our virtual address space; we have to *know* the minimum and maximum
+ * addresses. We can't test it directly with VHOST_SET_MEM_TABLE because
+ * that actually succeeds, and the failure only occurs later when we try
+ * to use a buffer at an address that *is* valid, but our memory table
+ * *could* point to addresses that aren't. Ewww.
+ *
+ * So... attempt to work out what TASK_SIZE is for the kernel we happen
+ * to be running on right now...
+ */
+
+static int testaddr(unsigned long addr)
+{
+       void *res = mmap((void *)addr, getpagesize(), PROT_NONE,
+                        MAP_FIXED|MAP_ANONYMOUS, -1, 0);
+       if (res == MAP_FAILED) {
+               if (errno == EEXIST || errno == EINVAL)
+                       return 1;
+
+               /* We get ENOMEM for a bad virtual address */
+               return 0;
+       }
+       /* It shouldn't actually succeed without either MAP_SHARED or
+        * MAP_PRIVATE in the flags, but just in case... */
+       munmap((void *)addr, getpagesize());
+       return 1;
+}
+
+static int find_vmem_range(struct openconnect_info *vpninfo,
+                          struct vhost_memory *vmem)
+{
+       const unsigned long page_size = getpagesize();
+       unsigned long top;
+       unsigned long bottom;
+
+
+       top = -page_size;
+
+       if (testaddr(top)) {
+               vmem->regions[0].memory_size = top;
+               goto out;
+       }
+
+       /* 'top' is the lowest address known *not* to work */
+       bottom = top;
+       while (1) {
+               bottom >>= 1;
+               bottom &= ~(page_size - 1);
+               if (!bottom) {
+                       vpn_progress(vpninfo, PRG_ERR,
+                                    _("Failed to find virtual task size; search reached zero"));
+                       return -EINVAL;
+               }
+
+               if (testaddr(bottom))
+                       break;
+               top = bottom;
+       }
+
+       /* It's often a page or two below the boundary */
+       top -= page_size;
+       if (testaddr(top)) {
+               vmem->regions[0].memory_size = top;
+               goto out;
+       }
+       top -= page_size;
+       if (testaddr(top)) {
+               vmem->regions[0].memory_size = top;
+               goto out;
+       }
+
+       /* Now, bottom is the highest address known to work,
+          and we must search between it and 'top' which is
+          the lowest address known not to. */
+       while (bottom + page_size != top) {
+               unsigned long test = bottom + (top - bottom) / 2;
+               test &= ~(page_size - 1);
+
+               if (testaddr(test)) {
+                       bottom = test;
+                       continue;
+               }
+               test -= page_size;
+               if (testaddr(test)) {
+                       vmem->regions[0].memory_size = test;
+                       goto out;
+               }
+
+               test -= page_size;
+               if (testaddr(test)) {
+                       vmem->regions[0].memory_size = test;
+                       goto out;
+               }
+               top = test;
+       }
+       vmem->regions[0].memory_size = bottom;
+
+ out:
+       vmem->regions[0].guest_phys_addr = page_size;
+       vmem->regions[0].userspace_addr = page_size;
+       vpn_progress(vpninfo, PRG_DEBUG, _("Detected virtual address range 0x%lx-0x%lx\n"),
+                    page_size,
+                    (unsigned long)(page_size + vmem->regions[0].memory_size));
+       return 0;
+}
+
 #define OC_VHOST_NET_FEATURES ((1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |  \
                               (1ULL << VIRTIO_F_VERSION_1) |           \
                               (1ULL << VIRTIO_RING_F_EVENT_IDX))
@@ -206,13 +314,11 @@ int setup_vhost(struct openconnect_info *vpninfo, int tun_fd)
 
        memset(vmem, 0, sizeof(*vmem) + sizeof(vmem->regions[0]));
        vmem->nregions = 1;
-#ifdef __x86_64__
-       vmem->regions[0].guest_phys_addr = 4096;
-       vmem->regions[0].memory_size = 0x7fffffffe000; /* Why doesn't it allow 0x7fffffff000? */
-       vmem->regions[0].userspace_addr = 4096;
-#else
-#error Need magic vhost numbers for this platform
-#endif
+
+       ret = find_vmem_range(vpninfo, vmem);
+       if (ret)
+               goto err;
+
        if (ioctl(vpninfo->vhost_fd, VHOST_SET_MEM_TABLE, vmem) < 0) {
                ret = -errno;
                vpn_progress(vpninfo, PRG_DEBUG, _("Failed to set vhost memory map: %s\n"),
@@ -340,7 +446,7 @@ static void dump_vring(struct openconnect_info *vpninfo, struct oc_vring *ring)
        for (int i = 0; i < vpninfo->vhost_ring_size + 1; i++)
                vpn_progress(vpninfo, PRG_ERR,
                             "%d %p %x %x\n", i,
-                            (void *)vio64(ring->desc[i].addr),
+                            (void *)(unsigned long)vio64(ring->desc[i].addr),
                             vio16(ring->avail->ring[i]),
                             vio32(ring->used->ring[i].id));
 }
@@ -457,7 +563,7 @@ static inline int process_ring(struct openconnect_info *vpninfo, int tx, uint64_
 
                if (!tx)
                        ring->desc[desc].flags = vio16(VRING_DESC_F_WRITE);
-               ring->desc[desc].addr = vio64((uint64_t)this + pkt_offset(virtio.h));
+               ring->desc[desc].addr = vio64((unsigned long)this + pkt_offset(virtio.h));
                ring->desc[desc].len = vio32(this->len + sizeof(this->virtio.h));
                barrier();