#include <uapi/linux/udp.h>
 #include <uapi/linux/virtio_net.h>
 
+static inline bool virtio_net_hdr_match_proto(__be16 protocol, __u8 gso_type)
+{
+       switch (gso_type & ~VIRTIO_NET_HDR_GSO_ECN) {
+       case VIRTIO_NET_HDR_GSO_TCPV4:
+               return protocol == cpu_to_be16(ETH_P_IP);
+       case VIRTIO_NET_HDR_GSO_TCPV6:
+               return protocol == cpu_to_be16(ETH_P_IPV6);
+       case VIRTIO_NET_HDR_GSO_UDP:
+               return protocol == cpu_to_be16(ETH_P_IP) ||
+                      protocol == cpu_to_be16(ETH_P_IPV6);
+       default:
+               return false;
+       }
+}
+
 static inline int virtio_net_hdr_set_proto(struct sk_buff *skb,
                                           const struct virtio_net_hdr *hdr)
 {
                        if (!skb->protocol) {
                                __be16 protocol = dev_parse_header_protocol(skb);
 
-                               virtio_net_hdr_set_proto(skb, hdr);
-                               if (protocol && protocol != skb->protocol)
+                               if (!protocol)
+                                       virtio_net_hdr_set_proto(skb, hdr);
+                               else if (!virtio_net_hdr_match_proto(protocol, hdr->gso_type))
                                        return -EINVAL;
+                               else
+                                       skb->protocol = protocol;
                        }
 retry:
                        if (!skb_flow_dissect_flow_keys_basic(NULL, skb, &keys,