struct net_device *dev, struct in6_addr *saddr,
                           struct in6_addr *daddr, __u8 prio, __u8 ttl,
                           __be16 src_port, __be16 dst_port, __be32 vni,
-                          struct vxlan_metadata *md, bool xnet, u32 vxflags)
+                          struct vxlan_metadata *md, bool xnet, u32 vxflags,
+                          bool udp_sum)
 {
        struct vxlanhdr *vxh;
        int min_headroom;
        int err;
-       bool udp_sum = !(vxflags & VXLAN_F_UDP_ZERO_CSUM6_TX);
+       bool nocheck = !udp_sum;
        int type = udp_sum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
        u16 hdrlen = sizeof(struct vxlanhdr);
 
        skb_set_inner_protocol(skb, htons(ETH_P_TEB));
 
        udp_tunnel6_xmit_skb(dst, sk, skb, dev, saddr, daddr, prio,
-                            ttl, src_port, dst_port,
-                            !!(vxflags & VXLAN_F_UDP_ZERO_CSUM6_TX));
+                            ttl, src_port, dst_port, nocheck);
        return 0;
 err:
        dst_release(dst);
 static int vxlan_xmit_skb(struct rtable *rt, struct sock *sk, struct sk_buff *skb,
                          __be32 src, __be32 dst, __u8 tos, __u8 ttl, __be16 df,
                          __be16 src_port, __be16 dst_port, __be32 vni,
-                         struct vxlan_metadata *md, bool xnet, u32 vxflags)
+                         struct vxlan_metadata *md, bool xnet, u32 vxflags,
+                         bool udp_sum)
 {
        struct vxlanhdr *vxh;
        int min_headroom;
        int err;
-       bool udp_sum = !!(vxflags & VXLAN_F_UDP_CSUM);
+       bool nocheck = !udp_sum;
        int type = udp_sum ? SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL;
        u16 hdrlen = sizeof(struct vxlanhdr);
 
        skb_set_inner_protocol(skb, htons(ETH_P_TEB));
 
        udp_tunnel_xmit_skb(rt, sk, skb, src, dst, tos, ttl, df,
-                           src_port, dst_port, xnet,
-                           !(vxflags & VXLAN_F_UDP_CSUM));
+                           src_port, dst_port, xnet, nocheck);
        return 0;
 }
 
        __u8 tos, ttl;
        int err;
        u32 flags = vxlan->flags;
+       bool udp_sum = false;
 
        info = skb_tunnel_info(skb);
 
        if (info) {
                ttl = info->key.ttl;
                tos = info->key.tos;
+               udp_sum = !!(info->key.tun_flags & TUNNEL_CSUM);
 
                if (info->options_len)
                        md = ip_tunnel_info_opts(info);
                if (info) {
                        if (info->key.tun_flags & TUNNEL_DONT_FRAGMENT)
                                df = htons(IP_DF);
-
-                       if (info->key.tun_flags & TUNNEL_CSUM)
-                               flags |= VXLAN_F_UDP_CSUM;
-                       else
-                               flags &= ~VXLAN_F_UDP_CSUM;
+               } else {
+                       udp_sum = !!(flags & VXLAN_F_UDP_CSUM);
                }
 
                rt = vxlan_get_route(vxlan, skb,
                                     dst->sin.sin_addr.s_addr, tos, ttl, df,
                                     src_port, dst_port, htonl(vni << 8), md,
                                     !net_eq(vxlan->net, dev_net(vxlan->dev)),
-                                    flags);
+                                    flags, udp_sum);
                if (err < 0) {
                        /* skb is already freed. */
                        skb = NULL;
                        return;
                }
 
-               if (info) {
-                       if (info->key.tun_flags & TUNNEL_CSUM)
-                               flags &= ~VXLAN_F_UDP_ZERO_CSUM6_TX;
-                       else
-                               flags |= VXLAN_F_UDP_ZERO_CSUM6_TX;
-               }
+               if (!info)
+                       udp_sum = !(flags & VXLAN_F_UDP_ZERO_CSUM6_TX);
 
                ttl = ttl ? : ip6_dst_hoplimit(ndst);
                err = vxlan6_xmit_skb(ndst, sk, skb, dev, &saddr, &dst->sin6.sin6_addr,
                                      0, ttl, src_port, dst_port, htonl(vni << 8), md,
                                      !net_eq(vxlan->net, dev_net(vxlan->dev)),
-                                     flags);
+                                     flags, udp_sum);
 #endif
        }