#include <linux/tcp.h>
 #include <linux/bpf_trace.h>
 #include <net/busy_poll.h>
+#include <net/ip6_checksum.h>
 #include "en.h"
 #include "en_tc.h"
 #include "eswitch.h"
        return true;
 }
 
+static void mlx5e_lro_update_tcp_hdr(struct mlx5_cqe64 *cqe, struct tcphdr *tcp)
+{
+       u8 l4_hdr_type = get_cqe_l4_hdr_type(cqe);
+       u8 tcp_ack     = (l4_hdr_type == CQE_L4_HDR_TYPE_TCP_ACK_NO_DATA) ||
+                        (l4_hdr_type == CQE_L4_HDR_TYPE_TCP_ACK_AND_DATA);
+
+       tcp->check                      = 0;
+       tcp->psh                        = get_cqe_lro_tcppsh(cqe);
+
+       if (tcp_ack) {
+               tcp->ack                = 1;
+               tcp->ack_seq            = cqe->lro_ack_seq_num;
+               tcp->window             = cqe->lro_tcp_win;
+       }
+}
+
 static void mlx5e_lro_update_hdr(struct sk_buff *skb, struct mlx5_cqe64 *cqe,
                                 u32 cqe_bcnt)
 {
        struct ethhdr   *eth = (struct ethhdr *)(skb->data);
        struct tcphdr   *tcp;
        int network_depth = 0;
+       __wsum check;
        __be16 proto;
        u16 tot_len;
        void *ip_p;
 
-       u8 l4_hdr_type = get_cqe_l4_hdr_type(cqe);
-       u8 tcp_ack = (l4_hdr_type == CQE_L4_HDR_TYPE_TCP_ACK_NO_DATA) ||
-               (l4_hdr_type == CQE_L4_HDR_TYPE_TCP_ACK_AND_DATA);
-
        proto = __vlan_get_protocol(skb, eth->h_proto, &network_depth);
 
        tot_len = cqe_bcnt - network_depth;
                ipv4->check             = 0;
                ipv4->check             = ip_fast_csum((unsigned char *)ipv4,
                                                       ipv4->ihl);
+
+               mlx5e_lro_update_tcp_hdr(cqe, tcp);
+               check = csum_partial(tcp, tcp->doff * 4,
+                                    csum_unfold((__force __sum16)cqe->check_sum));
+               /* Almost done, don't forget the pseudo header */
+               tcp->check = csum_tcpudp_magic(ipv4->saddr, ipv4->daddr,
+                                              tot_len - sizeof(struct iphdr),
+                                              IPPROTO_TCP, check);
        } else {
+               u16 payload_len = tot_len - sizeof(struct ipv6hdr);
                struct ipv6hdr *ipv6 = ip_p;
 
                tcp = ip_p + sizeof(struct ipv6hdr);
                skb_shinfo(skb)->gso_type = SKB_GSO_TCPV6;
 
                ipv6->hop_limit         = cqe->lro_min_ttl;
-               ipv6->payload_len       = cpu_to_be16(tot_len -
-                                                     sizeof(struct ipv6hdr));
-       }
-
-       tcp->psh = get_cqe_lro_tcppsh(cqe);
-
-       if (tcp_ack) {
-               tcp->ack                = 1;
-               tcp->ack_seq            = cqe->lro_ack_seq_num;
-               tcp->window             = cqe->lro_tcp_win;
+               ipv6->payload_len       = cpu_to_be16(payload_len);
+
+               mlx5e_lro_update_tcp_hdr(cqe, tcp);
+               check = csum_partial(tcp, tcp->doff * 4,
+                                    csum_unfold((__force __sum16)cqe->check_sum));
+               /* Almost done, don't forget the pseudo header */
+               tcp->check = csum_ipv6_magic(&ipv6->saddr, &ipv6->daddr, payload_len,
+                                            IPPROTO_TCP, check);
        }
 }