}
 
                ndst = &rt->dst;
-               err = skb_tunnel_check_pmtu(skb, ndst, VXLAN_HEADROOM,
+               err = skb_tunnel_check_pmtu(skb, ndst, vxlan_headroom(flags & VXLAN_F_GPE),
                                            netif_is_any_bridge_port(dev));
                if (err < 0) {
                        goto tx_error;
                                goto out_unlock;
                }
 
-               err = skb_tunnel_check_pmtu(skb, ndst, VXLAN6_HEADROOM,
+               err = skb_tunnel_check_pmtu(skb, ndst,
+                                           vxlan_headroom((flags & VXLAN_F_GPE) | VXLAN_F_IPV6),
                                            netif_is_any_bridge_port(dev));
                if (err < 0) {
                        goto tx_error;
        struct vxlan_rdst *dst = &vxlan->default_dst;
        struct net_device *lowerdev = __dev_get_by_index(vxlan->net,
                                                         dst->remote_ifindex);
-       bool use_ipv6 = !!(vxlan->cfg.flags & VXLAN_F_IPV6);
 
        /* This check is different than dev->max_mtu, because it looks at
         * the lowerdev->mtu, rather than the static dev->max_mtu
         */
        if (lowerdev) {
-               int max_mtu = lowerdev->mtu -
-                             (use_ipv6 ? VXLAN6_HEADROOM : VXLAN_HEADROOM);
+               int max_mtu = lowerdev->mtu - vxlan_headroom(vxlan->cfg.flags);
                if (new_mtu > max_mtu)
                        return -EINVAL;
        }
        struct vxlan_dev *vxlan = netdev_priv(dev);
        struct vxlan_rdst *dst = &vxlan->default_dst;
        unsigned short needed_headroom = ETH_HLEN;
-       bool use_ipv6 = !!(conf->flags & VXLAN_F_IPV6);
        int max_mtu = ETH_MAX_MTU;
+       u32 flags = conf->flags;
 
        if (!changelink) {
-               if (conf->flags & VXLAN_F_GPE)
+               if (flags & VXLAN_F_GPE)
                        vxlan_raw_setup(dev);
                else
                        vxlan_ether_setup(dev);
 
                dev->needed_tailroom = lowerdev->needed_tailroom;
 
-               max_mtu = lowerdev->mtu - (use_ipv6 ? VXLAN6_HEADROOM :
-                                          VXLAN_HEADROOM);
+               max_mtu = lowerdev->mtu - vxlan_headroom(flags);
                if (max_mtu < ETH_MIN_MTU)
                        max_mtu = ETH_MIN_MTU;
 
        if (dev->mtu > max_mtu)
                dev->mtu = max_mtu;
 
-       if (use_ipv6 || conf->flags & VXLAN_F_COLLECT_METADATA)
-               needed_headroom += VXLAN6_HEADROOM;
-       else
-               needed_headroom += VXLAN_HEADROOM;
+       if (flags & VXLAN_F_COLLECT_METADATA)
+               flags |= VXLAN_F_IPV6;
+       needed_headroom += vxlan_headroom(flags);
        dev->needed_headroom = needed_headroom;
 
        memcpy(&vxlan->cfg, conf, sizeof(*conf));
 
        return features;
 }
 
-/* IP header + UDP + VXLAN + Ethernet header */
-#define VXLAN_HEADROOM (20 + 8 + 8 + 14)
-/* IPv6 header + UDP + VXLAN + Ethernet header */
-#define VXLAN6_HEADROOM (40 + 8 + 8 + 14)
+static inline int vxlan_headroom(u32 flags)
+{
+       /* VXLAN:     IP4/6 header + UDP + VXLAN + Ethernet header */
+       /* VXLAN-GPE: IP4/6 header + UDP + VXLAN */
+       return (flags & VXLAN_F_IPV6 ? sizeof(struct ipv6hdr) :
+                                      sizeof(struct iphdr)) +
+              sizeof(struct udphdr) + sizeof(struct vxlanhdr) +
+              (flags & VXLAN_F_GPE ? 0 : ETH_HLEN);
+}
 
 static inline struct vxlanhdr *vxlan_hdr(struct sk_buff *skb)
 {