]> www.infradead.org Git - users/dwmw2/linux.git/commitdiff
net: add header len parameter to tun_get_socket(), tap_get_socket()
authorDavid Woodhouse <dwmw@amazon.co.uk>
Wed, 23 Jun 2021 22:25:51 +0000 (23:25 +0100)
committerDavid Woodhouse <dwmw@amazon.co.uk>
Thu, 24 Jun 2021 12:14:04 +0000 (13:14 +0100)
The vhost-net driver was making wild assumptions about the header length
of the underlying tun/tap socket. Then it was discarding packets if
the number of bytes it got from sock_recvmsg() didn't precisely match
its guess.

Fix it to get the correct information along with the socket itself.
As a side-effect, this means that tun_get_socket() won't work if the
tun file isn't actually connected to a device, since there's no 'tun'
yet in that case to get the information from.

On the receive side, where the tun device generates the virtio_net_hdr
but VIRITO_NET_F_MSG_RXBUF was negotiated and vhost-net needs to fill
in the 'num_buffers' field on top of the existing virtio_net_hdr, fix
that to use 'sock_hlen - 2' as the location, which means that it goes
in the right place regardless of whether the tun device is using an
additional tun_pi header or not. In this case, the user should have
configured the tun device with a vnet hdr size of 12, to make room.

Fixes: 8dd014adfea6f ("vhost-net: mergeable buffers support")
Signed-off-by: David Woodhouse <dwmw@amazon.co.uk>
drivers/net/tap.c
drivers/net/tun.c
drivers/vhost/net.c
include/linux/if_tap.h
include/linux/if_tun.h

index 8e3a28ba6b28238d88dddd9104921c86e49fff87..2170a0d3d34ce5271e70b531bbbb1473d1e3a5e8 100644 (file)
@@ -1246,7 +1246,7 @@ static const struct proto_ops tap_socket_ops = {
  * attached to a device.  The returned object works like a packet socket, it
  * can be used for sock_sendmsg/sock_recvmsg.  The caller is responsible for
  * holding a reference to the file for as long as the socket is in use. */
-struct socket *tap_get_socket(struct file *file)
+struct socket *tap_get_socket(struct file *file, size_t *hlen)
 {
        struct tap_queue *q;
        if (file->f_op != &tap_fops)
@@ -1254,6 +1254,9 @@ struct socket *tap_get_socket(struct file *file)
        q = file->private_data;
        if (!q)
                return ERR_PTR(-EBADFD);
+       if (hlen)
+               *hlen = (q->flags & IFF_VNET_HDR) ? q->vnet_hdr_sz : 0;
+
        return &q->sock;
 }
 EXPORT_SYMBOL_GPL(tap_get_socket);
index 4cf38be26dc99e455349f22a59eee58002e54424..67b406fa088194203b696a6bd2ca449bac4e2e6e 100644 (file)
@@ -3649,7 +3649,7 @@ static void tun_cleanup(void)
  * attached to a device.  The returned object works like a packet socket, it
  * can be used for sock_sendmsg/sock_recvmsg.  The caller is responsible for
  * holding a reference to the file for as long as the socket is in use. */
-struct socket *tun_get_socket(struct file *file)
+struct socket *tun_get_socket(struct file *file, size_t *hlen)
 {
        struct tun_file *tfile;
        if (file->f_op != &tun_fops)
@@ -3657,6 +3657,20 @@ struct socket *tun_get_socket(struct file *file)
        tfile = file->private_data;
        if (!tfile)
                return ERR_PTR(-EBADFD);
+
+       if (hlen) {
+               struct tun_struct *tun = tun_get(tfile);
+               size_t len = 0;
+
+               if (!tun)
+                       return ERR_PTR(-ENOTCONN);
+               if (tun->flags & IFF_VNET_HDR)
+                       len += READ_ONCE(tun->vnet_hdr_sz);
+               if (!(tun->flags & IFF_NO_PI))
+                       len += sizeof(struct tun_pi);
+               tun_put(tun);
+               *hlen = len;
+       }
        return &tfile->socket;
 }
 EXPORT_SYMBOL_GPL(tun_get_socket);
index df82b124170ec9b3dc8b6e574dc27011b868e4e2..b92a7144ed90de65768bbc140408b2ac14e74f68 100644 (file)
@@ -1143,7 +1143,8 @@ static void handle_rx(struct vhost_net *net)
 
        vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
                vq->log : NULL;
-       mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
+       mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF) &&
+               (vhost_hlen || sock_hlen >= sizeof(num_buffers));
 
        do {
                sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
@@ -1213,9 +1214,10 @@ static void handle_rx(struct vhost_net *net)
                        }
                } else {
                        /* Header came from socket; we'll need to patch
-                        * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
+                        * ->num_buffers over the last two bytes if
+                        * VIRTIO_NET_F_MRG_RXBUF is enabled.
                         */
-                       iov_iter_advance(&fixup, sizeof(hdr));
+                       iov_iter_advance(&fixup, sock_hlen - 2);
                }
                /* TODO: Should check and handle checksum. */
 
@@ -1420,7 +1422,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
        return 0;
 }
 
-static struct socket *get_raw_socket(int fd)
+static struct socket *get_raw_socket(int fd, size_t *hlen)
 {
        int r;
        struct socket *sock = sockfd_lookup(fd, &r);
@@ -1438,6 +1440,7 @@ static struct socket *get_raw_socket(int fd)
                r = -EPFNOSUPPORT;
                goto err;
        }
+       *hlen = 0;
        return sock;
 err:
        sockfd_put(sock);
@@ -1463,33 +1466,33 @@ out:
        return ring;
 }
 
-static struct socket *get_tap_socket(int fd)
+static struct socket *get_tap_socket(int fd, size_t *hlen)
 {
        struct file *file = fget(fd);
        struct socket *sock;
 
        if (!file)
                return ERR_PTR(-EBADF);
-       sock = tun_get_socket(file);
+       sock = tun_get_socket(file, hlen);
        if (!IS_ERR(sock))
                return sock;
-       sock = tap_get_socket(file);
+       sock = tap_get_socket(file, hlen);
        if (IS_ERR(sock))
                fput(file);
        return sock;
 }
 
-static struct socket *get_socket(int fd)
+static struct socket *get_socket(int fd, size_t *hlen)
 {
        struct socket *sock;
 
        /* special case to disable backend */
        if (fd == -1)
                return NULL;
-       sock = get_raw_socket(fd);
+       sock = get_raw_socket(fd, hlen);
        if (!IS_ERR(sock))
                return sock;
-       sock = get_tap_socket(fd);
+       sock = get_tap_socket(fd, hlen);
        if (!IS_ERR(sock))
                return sock;
        return ERR_PTR(-ENOTSOCK);
@@ -1521,7 +1524,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
                r = -EFAULT;
                goto err_vq;
        }
-       sock = get_socket(fd);
+       sock = get_socket(fd, &nvq->sock_hlen);
        if (IS_ERR(sock)) {
                r = PTR_ERR(sock);
                goto err_vq;
@@ -1621,7 +1624,7 @@ done:
 
 static int vhost_net_set_features(struct vhost_net *n, u64 features)
 {
-       size_t vhost_hlen, sock_hlen, hdr_len;
+       size_t vhost_hlen, hdr_len;
        int i;
 
        hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
@@ -1631,11 +1634,8 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
        if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
                /* vhost provides vnet_hdr */
                vhost_hlen = hdr_len;
-               sock_hlen = 0;
        } else {
-               /* socket provides vnet_hdr */
                vhost_hlen = 0;
-               sock_hlen = hdr_len;
        }
        mutex_lock(&n->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
@@ -1651,7 +1651,6 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
                mutex_lock(&n->vqs[i].vq.mutex);
                n->vqs[i].vq.acked_features = features;
                n->vqs[i].vhost_hlen = vhost_hlen;
-               n->vqs[i].sock_hlen = sock_hlen;
                mutex_unlock(&n->vqs[i].vq.mutex);
        }
        mutex_unlock(&n->dev.mutex);
index 915a187cfabdac61da21a58c9ab00862c9deaae8..b460ba98f34e55dfc090b60de567ed18ff1f64a2 100644 (file)
@@ -3,14 +3,14 @@
 #define _LINUX_IF_TAP_H_
 
 #if IS_ENABLED(CONFIG_TAP)
-struct socket *tap_get_socket(struct file *);
+struct socket *tap_get_socket(struct file *, size_t *);
 struct ptr_ring *tap_get_ptr_ring(struct file *file);
 #else
 #include <linux/err.h>
 #include <linux/errno.h>
 struct file;
 struct socket;
-static inline struct socket *tap_get_socket(struct file *f)
+static inline struct socket *tap_get_socket(struct file *f, size_t *)
 {
        return ERR_PTR(-EINVAL);
 }
index 2a7660843444de5059551632b817239630b401b4..8a7debd3f663936877cf01abb716e24cee1b6137 100644 (file)
@@ -25,7 +25,7 @@ struct tun_xdp_hdr {
 };
 
 #if defined(CONFIG_TUN) || defined(CONFIG_TUN_MODULE)
-struct socket *tun_get_socket(struct file *);
+struct socket *tun_get_socket(struct file *, size_t *);
 struct ptr_ring *tun_get_tx_ring(struct file *file);
 static inline bool tun_is_xdp_frame(void *ptr)
 {
@@ -45,7 +45,7 @@ void tun_ptr_free(void *ptr);
 #include <linux/errno.h>
 struct file;
 struct socket;
-static inline struct socket *tun_get_socket(struct file *f)
+static inline struct socket *tun_get_socket(struct file *f, size_t *)
 {
        return ERR_PTR(-EINVAL);
 }