From: David Woodhouse <dwmw@amazon.co.uk>
Date: Wed, 23 Jun 2021 22:25:51 +0000 (+0100)
Subject: net: add header len parameter to tun_get_socket(), tap_get_socket()
X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=3ab3af720af8cfab34248424a44a7e4786a954d2;p=users%2Fdwmw2%2Flinux.git

net: add header len parameter to tun_get_socket(), tap_get_socket()

The vhost-net driver was making wild assumptions about the header length
of the underlying tun/tap socket. Then it was discarding packets if
the number of bytes it got from sock_recvmsg() didn't precisely match
its guess.

Fix it to get the correct information along with the socket itself.
As a side-effect, this means that tun_get_socket() won't work if the
tun file isn't actually connected to a device, since there's no 'tun'
yet in that case to get the information from.

On the receive side, where the tun device generates the virtio_net_hdr
but VIRITO_NET_F_MSG_RXBUF was negotiated and vhost-net needs to fill
in the 'num_buffers' field on top of the existing virtio_net_hdr, fix
that to use 'sock_hlen - 2' as the location, which means that it goes
in the right place regardless of whether the tun device is using an
additional tun_pi header or not. In this case, the user should have
configured the tun device with a vnet hdr size of 12, to make room.

Fixes: 8dd014adfea6f ("vhost-net: mergeable buffers support")
Signed-off-by: David Woodhouse <dwmw@amazon.co.uk>
---

diff --git a/drivers/net/tap.c b/drivers/net/tap.c
index 8e3a28ba6b282..2170a0d3d34ce 100644
--- a/drivers/net/tap.c
+++ b/drivers/net/tap.c
@@ -1246,7 +1246,7 @@ static const struct proto_ops tap_socket_ops = {
  * attached to a device.  The returned object works like a packet socket, it
  * can be used for sock_sendmsg/sock_recvmsg.  The caller is responsible for
  * holding a reference to the file for as long as the socket is in use. */
-struct socket *tap_get_socket(struct file *file)
+struct socket *tap_get_socket(struct file *file, size_t *hlen)
 {
 	struct tap_queue *q;
 	if (file->f_op != &tap_fops)
@@ -1254,6 +1254,9 @@ struct socket *tap_get_socket(struct file *file)
 	q = file->private_data;
 	if (!q)
 		return ERR_PTR(-EBADFD);
+	if (hlen)
+		*hlen = (q->flags & IFF_VNET_HDR) ? q->vnet_hdr_sz : 0;
+
 	return &q->sock;
 }
 EXPORT_SYMBOL_GPL(tap_get_socket);
diff --git a/drivers/net/tun.c b/drivers/net/tun.c
index 4cf38be26dc99..67b406fa08819 100644
--- a/drivers/net/tun.c
+++ b/drivers/net/tun.c
@@ -3649,7 +3649,7 @@ static void tun_cleanup(void)
  * attached to a device.  The returned object works like a packet socket, it
  * can be used for sock_sendmsg/sock_recvmsg.  The caller is responsible for
  * holding a reference to the file for as long as the socket is in use. */
-struct socket *tun_get_socket(struct file *file)
+struct socket *tun_get_socket(struct file *file, size_t *hlen)
 {
 	struct tun_file *tfile;
 	if (file->f_op != &tun_fops)
@@ -3657,6 +3657,20 @@ struct socket *tun_get_socket(struct file *file)
 	tfile = file->private_data;
 	if (!tfile)
 		return ERR_PTR(-EBADFD);
+
+	if (hlen) {
+		struct tun_struct *tun = tun_get(tfile);
+		size_t len = 0;
+
+		if (!tun)
+			return ERR_PTR(-ENOTCONN);
+		if (tun->flags & IFF_VNET_HDR)
+			len += READ_ONCE(tun->vnet_hdr_sz);
+		if (!(tun->flags & IFF_NO_PI))
+			len += sizeof(struct tun_pi);
+		tun_put(tun);
+		*hlen = len;
+	}
 	return &tfile->socket;
 }
 EXPORT_SYMBOL_GPL(tun_get_socket);
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index df82b124170ec..b92a7144ed90d 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -1143,7 +1143,8 @@ static void handle_rx(struct vhost_net *net)
 
 	vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
 		vq->log : NULL;
-	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
+	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF) &&
+		(vhost_hlen || sock_hlen >= sizeof(num_buffers));
 
 	do {
 		sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
@@ -1213,9 +1214,10 @@ static void handle_rx(struct vhost_net *net)
 			}
 		} else {
 			/* Header came from socket; we'll need to patch
-			 * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
+			 * ->num_buffers over the last two bytes if
+			 * VIRTIO_NET_F_MRG_RXBUF is enabled.
 			 */
-			iov_iter_advance(&fixup, sizeof(hdr));
+			iov_iter_advance(&fixup, sock_hlen - 2);
 		}
 		/* TODO: Should check and handle checksum. */
 
@@ -1420,7 +1422,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
 	return 0;
 }
 
-static struct socket *get_raw_socket(int fd)
+static struct socket *get_raw_socket(int fd, size_t *hlen)
 {
 	int r;
 	struct socket *sock = sockfd_lookup(fd, &r);
@@ -1438,6 +1440,7 @@ static struct socket *get_raw_socket(int fd)
 		r = -EPFNOSUPPORT;
 		goto err;
 	}
+	*hlen = 0;
 	return sock;
 err:
 	sockfd_put(sock);
@@ -1463,33 +1466,33 @@ out:
 	return ring;
 }
 
-static struct socket *get_tap_socket(int fd)
+static struct socket *get_tap_socket(int fd, size_t *hlen)
 {
 	struct file *file = fget(fd);
 	struct socket *sock;
 
 	if (!file)
 		return ERR_PTR(-EBADF);
-	sock = tun_get_socket(file);
+	sock = tun_get_socket(file, hlen);
 	if (!IS_ERR(sock))
 		return sock;
-	sock = tap_get_socket(file);
+	sock = tap_get_socket(file, hlen);
 	if (IS_ERR(sock))
 		fput(file);
 	return sock;
 }
 
-static struct socket *get_socket(int fd)
+static struct socket *get_socket(int fd, size_t *hlen)
 {
 	struct socket *sock;
 
 	/* special case to disable backend */
 	if (fd == -1)
 		return NULL;
-	sock = get_raw_socket(fd);
+	sock = get_raw_socket(fd, hlen);
 	if (!IS_ERR(sock))
 		return sock;
-	sock = get_tap_socket(fd);
+	sock = get_tap_socket(fd, hlen);
 	if (!IS_ERR(sock))
 		return sock;
 	return ERR_PTR(-ENOTSOCK);
@@ -1521,7 +1524,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 		r = -EFAULT;
 		goto err_vq;
 	}
-	sock = get_socket(fd);
+	sock = get_socket(fd, &nvq->sock_hlen);
 	if (IS_ERR(sock)) {
 		r = PTR_ERR(sock);
 		goto err_vq;
@@ -1621,7 +1624,7 @@ done:
 
 static int vhost_net_set_features(struct vhost_net *n, u64 features)
 {
-	size_t vhost_hlen, sock_hlen, hdr_len;
+	size_t vhost_hlen, hdr_len;
 	int i;
 
 	hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
@@ -1631,11 +1634,8 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 	if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
 		/* vhost provides vnet_hdr */
 		vhost_hlen = hdr_len;
-		sock_hlen = 0;
 	} else {
-		/* socket provides vnet_hdr */
 		vhost_hlen = 0;
-		sock_hlen = hdr_len;
 	}
 	mutex_lock(&n->dev.mutex);
 	if ((features & (1 << VHOST_F_LOG_ALL)) &&
@@ -1651,7 +1651,6 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features)
 		mutex_lock(&n->vqs[i].vq.mutex);
 		n->vqs[i].vq.acked_features = features;
 		n->vqs[i].vhost_hlen = vhost_hlen;
-		n->vqs[i].sock_hlen = sock_hlen;
 		mutex_unlock(&n->vqs[i].vq.mutex);
 	}
 	mutex_unlock(&n->dev.mutex);
diff --git a/include/linux/if_tap.h b/include/linux/if_tap.h
index 915a187cfabda..b460ba98f34e5 100644
--- a/include/linux/if_tap.h
+++ b/include/linux/if_tap.h
@@ -3,14 +3,14 @@
 #define _LINUX_IF_TAP_H_
 
 #if IS_ENABLED(CONFIG_TAP)
-struct socket *tap_get_socket(struct file *);
+struct socket *tap_get_socket(struct file *, size_t *);
 struct ptr_ring *tap_get_ptr_ring(struct file *file);
 #else
 #include <linux/err.h>
 #include <linux/errno.h>
 struct file;
 struct socket;
-static inline struct socket *tap_get_socket(struct file *f)
+static inline struct socket *tap_get_socket(struct file *f, size_t *)
 {
 	return ERR_PTR(-EINVAL);
 }
diff --git a/include/linux/if_tun.h b/include/linux/if_tun.h
index 2a7660843444d..8a7debd3f6639 100644
--- a/include/linux/if_tun.h
+++ b/include/linux/if_tun.h
@@ -25,7 +25,7 @@ struct tun_xdp_hdr {
 };
 
 #if defined(CONFIG_TUN) || defined(CONFIG_TUN_MODULE)
-struct socket *tun_get_socket(struct file *);
+struct socket *tun_get_socket(struct file *, size_t *);
 struct ptr_ring *tun_get_tx_ring(struct file *file);
 static inline bool tun_is_xdp_frame(void *ptr)
 {
@@ -45,7 +45,7 @@ void tun_ptr_free(void *ptr);
 #include <linux/errno.h>
 struct file;
 struct socket;
-static inline struct socket *tun_get_socket(struct file *f)
+static inline struct socket *tun_get_socket(struct file *f, size_t *)
 {
 	return ERR_PTR(-EINVAL);
 }