--- /dev/null
+// SPDX-License-Identifier: GPL-2.0
+
+#define _GNU_SOURCE
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <error.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <linux/filter.h>
+#include <linux/bpf.h>
+#include <linux/if_packet.h>
+#include <linux/if_vlan.h>
+#include <linux/virtio_net.h>
+#include <net/if.h>
+#include <net/ethernet.h>
+#include <netinet/ip.h>
+#include <netinet/udp.h>
+#include <poll.h>
+#include <sched.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "psock_lib.h"
+
+static bool    cfg_use_bind;
+static bool    cfg_use_csum_off;
+static bool    cfg_use_csum_off_bad;
+static bool    cfg_use_dgram;
+static bool    cfg_use_gso;
+static bool    cfg_use_qdisc_bypass;
+static bool    cfg_use_vlan;
+static bool    cfg_use_vnet;
+
+static char    *cfg_ifname = "lo";
+static int     cfg_mtu = 1500;
+static int     cfg_payload_len = DATA_LEN;
+static int     cfg_truncate_len = INT_MAX;
+static uint16_t        cfg_port = 8000;
+
+/* test sending up to max mtu + 1 */
+#define TEST_SZ        (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)
+
+static char tbuf[TEST_SZ], rbuf[TEST_SZ];
+
+static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
+{
+       unsigned long sum = 0;
+       int i;
+
+       for (i = 0; i < num_u16; i++)
+               sum += start[i];
+
+       return sum;
+}
+
+static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
+                             unsigned long sum)
+{
+       sum += add_csum_hword(start, num_u16);
+
+       while (sum >> 16)
+               sum = (sum & 0xffff) + (sum >> 16);
+
+       return ~sum;
+}
+
+static int build_vnet_header(void *header)
+{
+       struct virtio_net_hdr *vh = header;
+
+       vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
+
+       if (cfg_use_csum_off) {
+               vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
+               vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
+               vh->csum_offset = __builtin_offsetof(struct udphdr, check);
+
+               /* position check field exactly one byte beyond end of packet */
+               if (cfg_use_csum_off_bad)
+                       vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
+                                         vh->csum_offset - 1;
+       }
+
+       if (cfg_use_gso) {
+               vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
+               vh->gso_size = cfg_mtu - sizeof(struct iphdr);
+       }
+
+       return sizeof(*vh);
+}
+
+static int build_eth_header(void *header)
+{
+       struct ethhdr *eth = header;
+
+       if (cfg_use_vlan) {
+               uint16_t *tag = header + ETH_HLEN;
+
+               eth->h_proto = htons(ETH_P_8021Q);
+               tag[1] = htons(ETH_P_IP);
+               return ETH_HLEN + 4;
+       }
+
+       eth->h_proto = htons(ETH_P_IP);
+       return ETH_HLEN;
+}
+
+static int build_ipv4_header(void *header, int payload_len)
+{
+       struct iphdr *iph = header;
+
+       iph->ihl = 5;
+       iph->version = 4;
+       iph->ttl = 8;
+       iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
+       iph->id = htons(1337);
+       iph->protocol = IPPROTO_UDP;
+       iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
+       iph->daddr = htonl((172 << 24) | (17 << 16) | 1);
+       iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);
+
+       return iph->ihl << 2;
+}
+
+static int build_udp_header(void *header, int payload_len)
+{
+       const int alen = sizeof(uint32_t);
+       struct udphdr *udph = header;
+       int len = sizeof(*udph) + payload_len;
+
+       udph->source = htons(9);
+       udph->dest = htons(cfg_port);
+       udph->len = htons(len);
+
+       if (cfg_use_csum_off)
+               udph->check = build_ip_csum(header - (2 * alen), alen,
+                                           htons(IPPROTO_UDP) + udph->len);
+       else
+               udph->check = 0;
+
+       return sizeof(*udph);
+}
+
+static int build_packet(int payload_len)
+{
+       int off = 0;
+
+       off += build_vnet_header(tbuf);
+       off += build_eth_header(tbuf + off);
+       off += build_ipv4_header(tbuf + off, payload_len);
+       off += build_udp_header(tbuf + off, payload_len);
+
+       if (off + payload_len > sizeof(tbuf))
+               error(1, 0, "payload length exceeds max");
+
+       memset(tbuf + off, DATA_CHAR, payload_len);
+
+       return off + payload_len;
+}
+
+static void do_bind(int fd)
+{
+       struct sockaddr_ll laddr = {0};
+
+       laddr.sll_family = AF_PACKET;
+       laddr.sll_protocol = htons(ETH_P_IP);
+       laddr.sll_ifindex = if_nametoindex(cfg_ifname);
+       if (!laddr.sll_ifindex)
+               error(1, errno, "if_nametoindex");
+
+       if (bind(fd, (void *)&laddr, sizeof(laddr)))
+               error(1, errno, "bind");
+}
+
+static void do_send(int fd, char *buf, int len)
+{
+       int ret;
+
+       if (!cfg_use_vnet) {
+               buf += sizeof(struct virtio_net_hdr);
+               len -= sizeof(struct virtio_net_hdr);
+       }
+       if (cfg_use_dgram) {
+               buf += ETH_HLEN;
+               len -= ETH_HLEN;
+       }
+
+       if (cfg_use_bind) {
+               ret = write(fd, buf, len);
+       } else {
+               struct sockaddr_ll laddr = {0};
+
+               laddr.sll_protocol = htons(ETH_P_IP);
+               laddr.sll_ifindex = if_nametoindex(cfg_ifname);
+               if (!laddr.sll_ifindex)
+                       error(1, errno, "if_nametoindex");
+
+               ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
+       }
+
+       if (ret == -1)
+               error(1, errno, "write");
+       if (ret != len)
+               error(1, 0, "write: %u %u", ret, len);
+
+       fprintf(stderr, "tx: %u\n", ret);
+}
+
+static int do_tx(void)
+{
+       const int one = 1;
+       int fd, len;
+
+       fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
+       if (fd == -1)
+               error(1, errno, "socket t");
+
+       if (cfg_use_bind)
+               do_bind(fd);
+
+       if (cfg_use_qdisc_bypass &&
+           setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
+               error(1, errno, "setsockopt qdisc bypass");
+
+       if (cfg_use_vnet &&
+           setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
+               error(1, errno, "setsockopt vnet");
+
+       len = build_packet(cfg_payload_len);
+
+       if (cfg_truncate_len < len)
+               len = cfg_truncate_len;
+
+       do_send(fd, tbuf, len);
+
+       if (close(fd))
+               error(1, errno, "close t");
+
+       return len;
+}
+
+static int setup_rx(void)
+{
+       struct timeval tv = { .tv_usec = 100 * 1000 };
+       struct sockaddr_in raddr = {0};
+       int fd;
+
+       fd = socket(PF_INET, SOCK_DGRAM, 0);
+       if (fd == -1)
+               error(1, errno, "socket r");
+
+       if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
+               error(1, errno, "setsockopt rcv timeout");
+
+       raddr.sin_family = AF_INET;
+       raddr.sin_port = htons(cfg_port);
+       raddr.sin_addr.s_addr = htonl(INADDR_ANY);
+
+       if (bind(fd, (void *)&raddr, sizeof(raddr)))
+               error(1, errno, "bind r");
+
+       return fd;
+}
+
+static void do_rx(int fd, int expected_len, char *expected)
+{
+       int ret;
+
+       ret = recv(fd, rbuf, sizeof(rbuf), 0);
+       if (ret == -1)
+               error(1, errno, "recv");
+       if (ret != expected_len)
+               error(1, 0, "recv: %u != %u", ret, expected_len);
+
+       if (memcmp(rbuf, expected, ret))
+               error(1, 0, "recv: data mismatch");
+
+       fprintf(stderr, "rx: %u\n", ret);
+}
+
+static int setup_sniffer(void)
+{
+       struct timeval tv = { .tv_usec = 100 * 1000 };
+       int fd;
+
+       fd = socket(PF_PACKET, SOCK_RAW, 0);
+       if (fd == -1)
+               error(1, errno, "socket p");
+
+       if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
+               error(1, errno, "setsockopt rcv timeout");
+
+       pair_udp_setfilter(fd);
+       do_bind(fd);
+
+       return fd;
+}
+
+static void parse_opts(int argc, char **argv)
+{
+       int c;
+
+       while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
+               switch (c) {
+               case 'b':
+                       cfg_use_bind = true;
+                       break;
+               case 'c':
+                       cfg_use_csum_off = true;
+                       break;
+               case 'C':
+                       cfg_use_csum_off_bad = true;
+                       break;
+               case 'd':
+                       cfg_use_dgram = true;
+                       break;
+               case 'g':
+                       cfg_use_gso = true;
+                       break;
+               case 'l':
+                       cfg_payload_len = strtoul(optarg, NULL, 0);
+                       break;
+               case 'q':
+                       cfg_use_qdisc_bypass = true;
+                       break;
+               case 't':
+                       cfg_truncate_len = strtoul(optarg, NULL, 0);
+                       break;
+               case 'v':
+                       cfg_use_vnet = true;
+                       break;
+               case 'V':
+                       cfg_use_vlan = true;
+                       break;
+               default:
+                       error(1, 0, "%s: parse error", argv[0]);
+               }
+       }
+
+       if (cfg_use_vlan && cfg_use_dgram)
+               error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
+
+       if (cfg_use_csum_off && !cfg_use_vnet)
+               error(1, 0, "option csum offload (-c) requires vnet (-v)");
+
+       if (cfg_use_csum_off_bad && !cfg_use_csum_off)
+               error(1, 0, "option csum bad (-C) requires csum offload (-c)");
+
+       if (cfg_use_gso && !cfg_use_csum_off)
+               error(1, 0, "option gso (-g) requires csum offload (-c)");
+}
+
+static void run_test(void)
+{
+       int fdr, fds, total_len;
+
+       fdr = setup_rx();
+       fds = setup_sniffer();
+
+       total_len = do_tx();
+
+       /* BPF filter accepts only this length, vlan changes MAC */
+       if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
+               do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
+                     tbuf + sizeof(struct virtio_net_hdr));
+
+       do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
+
+       if (close(fds))
+               error(1, errno, "close s");
+       if (close(fdr))
+               error(1, errno, "close r");
+}
+
+int main(int argc, char **argv)
+{
+       parse_opts(argc, argv);
+
+       if (system("ip link set dev lo mtu 1500"))
+               error(1, errno, "ip link set mtu");
+       if (system("ip addr add dev lo 172.17.0.1/24"))
+               error(1, errno, "ip addr add");
+
+       run_test();
+
+       fprintf(stderr, "OK\n\n");
+       return 0;
+}
 
--- /dev/null
+#!/bin/bash
+# SPDX-License-Identifier: GPL-2.0
+#
+# Run a series of packet socket send regression tests
+
+set -e
+
+readonly mtu=1500
+readonly iphlen=20
+readonly udphlen=8
+
+readonly vnet_hlen=10
+readonly eth_hlen=14
+
+readonly mss="$((${mtu} - ${iphlen} - ${udphlen}))"
+readonly mss_exceeds="$((${mss} + 1))"
+
+readonly max_mtu=65535
+readonly max_mss="$((${max_mtu} - ${iphlen} - ${udphlen}))"
+readonly max_mss_exceeds="$((${max_mss} + 1))"
+
+# functional checks (not a full cross-product)
+
+echo "dgram"
+./in_netns.sh ./psock_snd -d
+
+echo "dgram bind"
+./in_netns.sh ./psock_snd -d -b
+
+echo "raw"
+./in_netns.sh ./psock_snd
+
+echo "raw bind"
+./in_netns.sh ./psock_snd -b
+
+echo "raw qdisc bypass"
+./in_netns.sh ./psock_snd -q
+
+echo "raw vlan"
+./in_netns.sh ./psock_snd -V
+
+echo "raw vnet hdr"
+./in_netns.sh ./psock_snd -v
+
+echo "raw csum_off"
+./in_netns.sh ./psock_snd -v -c
+
+echo "raw csum_off with bad offset (fails)"
+(! ./in_netns.sh ./psock_snd -v -c -C)
+
+
+# bounds check: send {max, max + 1, min, min - 1} lengths
+
+echo "raw min size"
+./in_netns.sh ./psock_snd -l 0
+
+echo "raw mtu size"
+./in_netns.sh ./psock_snd -l "${mss}"
+
+echo "raw mtu size + 1 (fails)"
+(! ./in_netns.sh ./psock_snd -l "${mss_exceeds}")
+
+# fails due to ARPHRD_ETHER check in packet_extra_vlan_len_allowed
+#
+# echo "raw vlan mtu size"
+# ./in_netns.sh ./psock_snd -V -l "${mss}"
+
+echo "raw vlan mtu size + 1 (fails)"
+(! ./in_netns.sh ./psock_snd -V -l "${mss_exceeds}")
+
+echo "dgram mtu size"
+./in_netns.sh ./psock_snd -d -l "${mss}"
+
+echo "dgram mtu size + 1 (fails)"
+(! ./in_netns.sh ./psock_snd -d -l "${mss_exceeds}")
+
+echo "raw truncate hlen (fails: does not arrive)"
+(! ./in_netns.sh ./psock_snd -t "$((${vnet_hlen} + ${eth_hlen}))")
+
+echo "raw truncate hlen - 1 (fails: EINVAL)"
+(! ./in_netns.sh ./psock_snd -t "$((${vnet_hlen} + ${eth_hlen} - 1))")
+
+
+# gso checks: implies -l, because with gso len must exceed gso_size
+
+echo "raw gso min size"
+./in_netns.sh ./psock_snd -v -c -g -l "${mss_exceeds}"
+
+echo "raw gso min size - 1 (fails)"
+(! ./in_netns.sh ./psock_snd -v -c -g -l "${mss}")
+
+echo "raw gso max size"
+./in_netns.sh ./psock_snd -v -c -g -l "${max_mss}"
+
+echo "raw gso max size + 1 (fails)"
+(! ./in_netns.sh ./psock_snd -v -c -g -l "${max_mss_exceeds}")
+
+echo "OK. All tests passed"