return 0;
 }
 
+#define MAX_TCP_HDR_LEN (15 * 4)
+
+static __sum16 *skb_checksum_setup_ip(struct sk_buff *skb,
+                                     typeof(IPPROTO_IP) proto,
+                                     unsigned int off)
+{
+       switch (proto) {
+               int err;
+
+       case IPPROTO_TCP:
+               err = skb_maybe_pull_tail(skb, off + sizeof(struct tcphdr),
+                                         off + MAX_TCP_HDR_LEN);
+               if (!err && !skb_partial_csum_set(skb, off,
+                                                 offsetof(struct tcphdr,
+                                                          check)))
+                       err = -EPROTO;
+               return err ? ERR_PTR(err) : &tcp_hdr(skb)->check;
+
+       case IPPROTO_UDP:
+               err = skb_maybe_pull_tail(skb, off + sizeof(struct udphdr),
+                                         off + sizeof(struct udphdr));
+               if (!err && !skb_partial_csum_set(skb, off,
+                                                 offsetof(struct udphdr,
+                                                          check)))
+                       err = -EPROTO;
+               return err ? ERR_PTR(err) : &udp_hdr(skb)->check;
+       }
+
+       return ERR_PTR(-EPROTO);
+}
+
 /* This value should be large enough to cover a tagged ethernet header plus
  * maximally sized IP and TCP or UDP headers.
  */
 #define MAX_IP_HDR_LEN 128
 
-static int skb_checksum_setup_ip(struct sk_buff *skb, bool recalculate)
+static int skb_checksum_setup_ipv4(struct sk_buff *skb, bool recalculate)
 {
        unsigned int off;
        bool fragment;
+       __sum16 *csum;
        int err;
 
        fragment = false;
        if (fragment)
                goto out;
 
-       switch (ip_hdr(skb)->protocol) {
-       case IPPROTO_TCP:
-               err = skb_maybe_pull_tail(skb,
-                                         off + sizeof(struct tcphdr),
-                                         MAX_IP_HDR_LEN);
-               if (err < 0)
-                       goto out;
-
-               if (!skb_partial_csum_set(skb, off,
-                                         offsetof(struct tcphdr, check))) {
-                       err = -EPROTO;
-                       goto out;
-               }
-
-               if (recalculate)
-                       tcp_hdr(skb)->check =
-                               ~csum_tcpudp_magic(ip_hdr(skb)->saddr,
-                                                  ip_hdr(skb)->daddr,
-                                                  skb->len - off,
-                                                  IPPROTO_TCP, 0);
-               break;
-       case IPPROTO_UDP:
-               err = skb_maybe_pull_tail(skb,
-                                         off + sizeof(struct udphdr),
-                                         MAX_IP_HDR_LEN);
-               if (err < 0)
-                       goto out;
-
-               if (!skb_partial_csum_set(skb, off,
-                                         offsetof(struct udphdr, check))) {
-                       err = -EPROTO;
-                       goto out;
-               }
-
-               if (recalculate)
-                       udp_hdr(skb)->check =
-                               ~csum_tcpudp_magic(ip_hdr(skb)->saddr,
-                                                  ip_hdr(skb)->daddr,
-                                                  skb->len - off,
-                                                  IPPROTO_UDP, 0);
-               break;
-       default:
-               goto out;
-       }
+       csum = skb_checksum_setup_ip(skb, ip_hdr(skb)->protocol, off);
+       if (IS_ERR(csum))
+               return PTR_ERR(csum);
 
+       if (recalculate)
+               *csum = ~csum_tcpudp_magic(ip_hdr(skb)->saddr,
+                                          ip_hdr(skb)->daddr,
+                                          skb->len - off,
+                                          ip_hdr(skb)->protocol, 0);
        err = 0;
 
 out:
        unsigned int len;
        bool fragment;
        bool done;
+       __sum16 *csum;
 
        fragment = false;
        done = false;
        if (!done || fragment)
                goto out;
 
-       switch (nexthdr) {
-       case IPPROTO_TCP:
-               err = skb_maybe_pull_tail(skb,
-                                         off + sizeof(struct tcphdr),
-                                         MAX_IPV6_HDR_LEN);
-               if (err < 0)
-                       goto out;
-
-               if (!skb_partial_csum_set(skb, off,
-                                         offsetof(struct tcphdr, check))) {
-                       err = -EPROTO;
-                       goto out;
-               }
-
-               if (recalculate)
-                       tcp_hdr(skb)->check =
-                               ~csum_ipv6_magic(&ipv6_hdr(skb)->saddr,
-                                                &ipv6_hdr(skb)->daddr,
-                                                skb->len - off,
-                                                IPPROTO_TCP, 0);
-               break;
-       case IPPROTO_UDP:
-               err = skb_maybe_pull_tail(skb,
-                                         off + sizeof(struct udphdr),
-                                         MAX_IPV6_HDR_LEN);
-               if (err < 0)
-                       goto out;
-
-               if (!skb_partial_csum_set(skb, off,
-                                         offsetof(struct udphdr, check))) {
-                       err = -EPROTO;
-                       goto out;
-               }
-
-               if (recalculate)
-                       udp_hdr(skb)->check =
-                               ~csum_ipv6_magic(&ipv6_hdr(skb)->saddr,
-                                                &ipv6_hdr(skb)->daddr,
-                                                skb->len - off,
-                                                IPPROTO_UDP, 0);
-               break;
-       default:
-               goto out;
-       }
+       csum = skb_checksum_setup_ip(skb, nexthdr, off);
+       if (IS_ERR(csum))
+               return PTR_ERR(csum);
 
+       if (recalculate)
+               *csum = ~csum_ipv6_magic(&ipv6_hdr(skb)->saddr,
+                                        &ipv6_hdr(skb)->daddr,
+                                        skb->len - off, nexthdr, 0);
        err = 0;
 
 out:
 
        switch (skb->protocol) {
        case htons(ETH_P_IP):
-               err = skb_checksum_setup_ip(skb, recalculate);
+               err = skb_checksum_setup_ipv4(skb, recalculate);
                break;
 
        case htons(ETH_P_IPV6):