l4_offset = l4.hdr - skb->data;
 
                        /* remove payload length from outer checksum */
-                       paylen = (__force u16)l4.udp->check;
-                       paylen += ntohs((__force __be16)1) *
-                                       (u16)~(skb->len - l4_offset);
-                       l4.udp->check = ~csum_fold((__force __wsum)paylen);
+                       paylen = skb->len - l4_offset;
+                       csum_replace_by_diff(&l4.udp->check, htonl(paylen));
                }
 
                /* reset pointers to inner headers */
        l4_offset = l4.hdr - skb->data;
 
        /* remove payload length from inner checksum */
-       paylen = (__force u16)l4.tcp->check;
-       paylen += ntohs((__force __be16)1) * (u16)~(skb->len - l4_offset);
-       l4.tcp->check = ~csum_fold((__force __wsum)paylen);
+       paylen = skb->len - l4_offset;
+       csum_replace_by_diff(&l4.tcp->check, htonl(paylen));
 
        /* compute length of segmentation header */
        *hdr_len = (l4.tcp->doff * 4) + l4_offset;
 
                        l4_offset = l4.hdr - skb->data;
 
                        /* remove payload length from outer checksum */
-                       paylen = (__force u16)l4.udp->check;
-                       paylen += ntohs((__force __be16)1) *
-                                       (u16)~(skb->len - l4_offset);
-                       l4.udp->check = ~csum_fold((__force __wsum)paylen);
+                       paylen = skb->len - l4_offset;
+                       csum_replace_by_diff(&l4.udp->check, htonl(paylen));
                }
 
                /* reset pointers to inner headers */
        l4_offset = l4.hdr - skb->data;
 
        /* remove payload length from inner checksum */
-       paylen = (__force u16)l4.tcp->check;
-       paylen += ntohs((__force __be16)1) * (u16)~(skb->len - l4_offset);
-       l4.tcp->check = ~csum_fold((__force __wsum)paylen);
+       paylen = skb->len - l4_offset;
+       csum_replace_by_diff(&l4.tcp->check, htonl(paylen));
 
        /* compute length of segmentation header */
        *hdr_len = (l4.tcp->doff * 4) + l4_offset;