#include <linux/skbuff.h>
 #include <net/ip.h>
 #include <net/icmp.h>
+#include <net/netfilter/nf_reject.h>
 
 void nf_send_unreach(struct sk_buff *skb_in, int code, int hook);
 void nf_send_reset(struct net *net, struct sk_buff *oldskb, int hook);
 
 #define _IPV6_NF_REJECT_H
 
 #include <linux/icmpv6.h>
+#include <net/netfilter/nf_reject.h>
 
 void nf_send_unreach6(struct net *net, struct sk_buff *skb_in, unsigned char code,
                      unsigned int hooknum);
 
--- /dev/null
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef _NF_REJECT_H
+#define _NF_REJECT_H
+
+static inline bool nf_reject_verify_csum(__u8 proto)
+{
+       /* Skip protocols that don't use 16-bit one's complement checksum
+        * of the entire payload.
+        */
+       switch (proto) {
+               /* Protocols with other integrity checks. */
+               case IPPROTO_AH:
+               case IPPROTO_ESP:
+               case IPPROTO_SCTP:
+
+               /* Protocols with partial checksums. */
+               case IPPROTO_UDPLITE:
+               case IPPROTO_DCCP:
+
+               /* Protocols with optional checksums. */
+               case IPPROTO_GRE:
+                       return false;
+       }
+       return true;
+}
+
+#endif /* _NF_REJECT_H */
 
        if (pskb_trim_rcsum(oldskb, ntohs(ip_hdr(oldskb)->tot_len)))
                return;
 
-       if (ip_hdr(oldskb)->protocol == IPPROTO_TCP ||
-           ip_hdr(oldskb)->protocol == IPPROTO_UDP)
-               proto = ip_hdr(oldskb)->protocol;
-       else
-               proto = 0;
+       proto = ip_hdr(oldskb)->protocol;
 
        if (!skb_csum_unnecessary(oldskb) &&
+           nf_reject_verify_csum(proto) &&
            nf_ip_checksum(oldskb, hook, ip_hdrlen(oldskb), proto))
                return;
 
        if (thoff < 0 || thoff >= skb->len || (fo & htons(~0x7)) != 0)
                return false;
 
+       if (!nf_reject_verify_csum(proto))
+               return true;
+
        return nf_ip6_checksum(skb, hook, thoff, proto) == 0;
 }
 
 
 void nf_send_unreach(struct sk_buff *skb_in, int code, int hook)
 {
        struct iphdr *iph = ip_hdr(skb_in);
-       u8 proto;
+       u8 proto = iph->protocol;
 
        if (iph->frag_off & htons(IP_OFFSET))
                return;
 
-       if (skb_csum_unnecessary(skb_in)) {
+       if (skb_csum_unnecessary(skb_in) || !nf_reject_verify_csum(proto)) {
                icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0);
                return;
        }
 
-       if (iph->protocol == IPPROTO_TCP || iph->protocol == IPPROTO_UDP)
-               proto = iph->protocol;
-       else
-               proto = 0;
-
        if (nf_ip_checksum(skb_in, hook, ip_hdrlen(skb_in), proto) == 0)
                icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0);
 }
 
        if (thoff < 0 || thoff >= skb->len || (fo & htons(~0x7)) != 0)
                return false;
 
+       if (!nf_reject_verify_csum(proto))
+               return true;
+
        return nf_ip6_checksum(skb, hook, thoff, proto) == 0;
 }