int tnl_hlen = skb_inner_mac_header(skb) - skb_transport_header(skb);
        __be16 protocol = skb->protocol;
        netdev_features_t enc_features;
-       int outer_hlen;
+       int udp_offset, outer_hlen;
+       unsigned int oldlen;
+       bool need_csum;
+
+       oldlen = (u16)~skb->len;
 
        if (unlikely(!pskb_may_pull(skb, tnl_hlen)))
                goto out;
        skb->mac_len = skb_inner_network_offset(skb);
        skb->protocol = htons(ETH_P_TEB);
 
+       need_csum = !!(skb_shinfo(skb)->gso_type & SKB_GSO_UDP_TUNNEL_CSUM);
+       if (need_csum)
+               skb->encap_hdr_csum = 1;
+
        /* segment inner packet. */
        enc_features = skb->dev->hw_enc_features & netif_skb_features(skb);
        segs = skb_mac_gso_segment(skb, enc_features);
        }
 
        outer_hlen = skb_tnl_header_len(skb);
+       udp_offset = outer_hlen - tnl_hlen;
        skb = segs;
        do {
                struct udphdr *uh;
-               int udp_offset = outer_hlen - tnl_hlen;
+               int len;
 
                skb_reset_inner_headers(skb);
                skb->encapsulation = 1;
                skb_reset_mac_header(skb);
                skb_set_network_header(skb, mac_len);
                skb_set_transport_header(skb, udp_offset);
+               len = skb->len - udp_offset;
                uh = udp_hdr(skb);
-               uh->len = htons(skb->len - udp_offset);
-
-               /* csum segment if tunnel sets skb with csum. */
-               if (protocol == htons(ETH_P_IP) && unlikely(uh->check)) {
-                       struct iphdr *iph = ip_hdr(skb);
+               uh->len = htons(len);
 
-                       uh->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr,
-                                                      skb->len - udp_offset,
-                                                      IPPROTO_UDP, 0);
-                       uh->check = csum_fold(skb_checksum(skb, udp_offset,
-                                                          skb->len - udp_offset, 0));
-                       if (uh->check == 0)
-                               uh->check = CSUM_MANGLED_0;
+               if (need_csum) {
+                       __be32 delta = htonl(oldlen + len);
 
-               } else if (protocol == htons(ETH_P_IPV6)) {
-                       struct ipv6hdr *ipv6h = ipv6_hdr(skb);
-                       u32 len = skb->len - udp_offset;
+                       uh->check = ~csum_fold((__force __wsum)
+                                              ((__force u32)uh->check +
+                                               (__force u32)delta));
+                       uh->check = gso_make_checksum(skb, ~uh->check);
 
-                       uh->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
-                                                    len, IPPROTO_UDP, 0);
-                       uh->check = csum_fold(skb_checksum(skb, udp_offset, len, 0));
                        if (uh->check == 0)
                                uh->check = CSUM_MANGLED_0;
-                       skb->ip_summed = CHECKSUM_NONE;
                }
 
                skb->protocol = protocol;
 
        __wsum csum;
 
        if (skb->encapsulation &&
-           skb_shinfo(skb)->gso_type & SKB_GSO_UDP_TUNNEL) {
+           (skb_shinfo(skb)->gso_type &
+            (SKB_GSO_UDP_TUNNEL|SKB_GSO_UDP_TUNNEL_CSUM))) {
                segs = skb_udp_tunnel_segment(skb, features);
                goto out;
        }
 
                if (unlikely(type & ~(SKB_GSO_UDP | SKB_GSO_DODGY |
                                      SKB_GSO_UDP_TUNNEL |
+                                     SKB_GSO_UDP_TUNNEL_CSUM |
                                      SKB_GSO_IPIP |
                                      SKB_GSO_GRE | SKB_GSO_MPLS) ||
                             !(type & (SKB_GSO_UDP))))