#include <linux/vmalloc.h>
 #include <linux/irq.h>
 
+#if IS_ENABLED(CONFIG_IPV6)
+#include <net/ip6_checksum.h>
+#endif
+
 #include "mlx4_en.h"
 
 static int mlx4_alloc_pages(struct mlx4_en_priv *priv,
        }
 }
 
+/* When hardware doesn't strip the vlan, we need to calculate the checksum
+ * over it and add it to the hardware's checksum calculation
+ */
+static inline __wsum get_fixed_vlan_csum(__wsum hw_checksum,
+                                        struct vlan_hdr *vlanh)
+{
+       return csum_add(hw_checksum, *(__wsum *)vlanh);
+}
+
+/* Although the stack expects checksum which doesn't include the pseudo
+ * header, the HW adds it. To address that, we are subtracting the pseudo
+ * header checksum from the checksum value provided by the HW.
+ */
+static void get_fixed_ipv4_csum(__wsum hw_checksum, struct sk_buff *skb,
+                               struct iphdr *iph)
+{
+       __u16 length_for_csum = 0;
+       __wsum csum_pseudo_header = 0;
+
+       length_for_csum = (be16_to_cpu(iph->tot_len) - (iph->ihl << 2));
+       csum_pseudo_header = csum_tcpudp_nofold(iph->saddr, iph->daddr,
+                                               length_for_csum, iph->protocol, 0);
+       skb->csum = csum_sub(hw_checksum, csum_pseudo_header);
+}
+
+#if IS_ENABLED(CONFIG_IPV6)
+/* In IPv6 packets, besides subtracting the pseudo header checksum,
+ * we also compute/add the IP header checksum which
+ * is not added by the HW.
+ */
+static int get_fixed_ipv6_csum(__wsum hw_checksum, struct sk_buff *skb,
+                              struct ipv6hdr *ipv6h)
+{
+       __wsum csum_pseudo_hdr = 0;
+
+       if (ipv6h->nexthdr == IPPROTO_FRAGMENT || ipv6h->nexthdr == IPPROTO_HOPOPTS)
+               return -1;
+       hw_checksum = csum_add(hw_checksum, (__force __wsum)(ipv6h->nexthdr << 8));
+
+       csum_pseudo_hdr = csum_partial(&ipv6h->saddr,
+                                      sizeof(ipv6h->saddr) + sizeof(ipv6h->daddr), 0);
+       csum_pseudo_hdr = csum_add(csum_pseudo_hdr, (__force __wsum)ipv6h->payload_len);
+       csum_pseudo_hdr = csum_add(csum_pseudo_hdr, (__force __wsum)ntohs(ipv6h->nexthdr));
+
+       skb->csum = csum_sub(hw_checksum, csum_pseudo_hdr);
+       skb->csum = csum_add(skb->csum, csum_partial(ipv6h, sizeof(struct ipv6hdr), 0));
+       return 0;
+}
+#endif
+static int check_csum(struct mlx4_cqe *cqe, struct sk_buff *skb, void *va,
+                     int hwtstamp_rx_filter)
+{
+       __wsum hw_checksum = 0;
+
+       void *hdr = (u8 *)va + sizeof(struct ethhdr);
+
+       hw_checksum = csum_unfold((__force __sum16)cqe->checksum);
+
+       if (((struct ethhdr *)va)->h_proto == htons(ETH_P_8021Q) &&
+           hwtstamp_rx_filter != HWTSTAMP_FILTER_NONE) {
+               /* next protocol non IPv4 or IPv6 */
+               if (((struct vlan_hdr *)hdr)->h_vlan_encapsulated_proto
+                   != htons(ETH_P_IP) &&
+                   ((struct vlan_hdr *)hdr)->h_vlan_encapsulated_proto
+                   != htons(ETH_P_IPV6))
+                       return -1;
+               hw_checksum = get_fixed_vlan_csum(hw_checksum, hdr);
+               hdr += sizeof(struct vlan_hdr);
+       }
+
+       if (cqe->status & cpu_to_be16(MLX4_CQE_STATUS_IPV4))
+               get_fixed_ipv4_csum(hw_checksum, skb, hdr);
+#if IS_ENABLED(CONFIG_IPV6)
+       else if (cqe->status & cpu_to_be16(MLX4_CQE_STATUS_IPV6))
+               if (get_fixed_ipv6_csum(hw_checksum, skb, hdr))
+                       return -1;
+#endif
+       return 0;
+}
+
 int mlx4_en_process_rx_cq(struct net_device *dev, struct mlx4_en_cq *cq, int budget)
 {
        struct mlx4_en_priv *priv = netdev_priv(dev);
                        (cqe->vlan_my_qpn & cpu_to_be32(MLX4_CQE_L2_TUNNEL));
 
                if (likely(dev->features & NETIF_F_RXCSUM)) {
-                       if ((cqe->status & cpu_to_be16(MLX4_CQE_STATUS_IPOK)) &&
-                           (cqe->checksum == cpu_to_be16(0xffff))) {
-                               ring->csum_ok++;
-                               ip_summed = CHECKSUM_UNNECESSARY;
+                       if (cqe->status & cpu_to_be16(MLX4_CQE_STATUS_TCP |
+                                                     MLX4_CQE_STATUS_UDP)) {
+                               if ((cqe->status & cpu_to_be16(MLX4_CQE_STATUS_IPOK)) &&
+                                   cqe->checksum == cpu_to_be16(0xffff)) {
+                                       ip_summed = CHECKSUM_UNNECESSARY;
+                                       ring->csum_ok++;
+                               } else {
+                                       ip_summed = CHECKSUM_NONE;
+                                       ring->csum_none++;
+                               }
                        } else {
-                               ip_summed = CHECKSUM_NONE;
-                               ring->csum_none++;
+                               if (priv->flags & MLX4_EN_FLAG_RX_CSUM_NON_TCP_UDP &&
+                                   (cqe->status & cpu_to_be16(MLX4_CQE_STATUS_IPV4 |
+                                                              MLX4_CQE_STATUS_IPV6))) {
+                                       ip_summed = CHECKSUM_COMPLETE;
+                                       ring->csum_complete++;
+                               } else {
+                                       ip_summed = CHECKSUM_NONE;
+                                       ring->csum_none++;
+                               }
                        }
                } else {
                        ip_summed = CHECKSUM_NONE;
                        if (!nr)
                                goto next;
 
+                       if (ip_summed == CHECKSUM_COMPLETE) {
+                               void *va = skb_frag_address(skb_shinfo(gro_skb)->frags);
+                               if (check_csum(cqe, gro_skb, va, ring->hwtstamp_rx_filter)) {
+                                       ip_summed = CHECKSUM_NONE;
+                                       ring->csum_none++;
+                                       ring->csum_complete--;
+                               }
+                       }
+
                        skb_shinfo(gro_skb)->nr_frags = nr;
                        gro_skb->len = length;
                        gro_skb->data_len = length;
                        goto next;
                }
 
+               if (ip_summed == CHECKSUM_COMPLETE) {
+                       if (check_csum(cqe, skb, skb->data, ring->hwtstamp_rx_filter)) {
+                               ip_summed = CHECKSUM_NONE;
+                               ring->csum_complete--;
+                               ring->csum_none++;
+                       }
+               }
+
                skb->ip_summed = ip_summed;
                skb->protocol = eth_type_trans(skb, dev);
                skb_record_rx_queue(skb, cq->ring);