#include "en.h"
 #include <linux/indirect_call_wrapper.h>
+#include <net/ip6_checksum.h>
+#include <net/tcp.h>
 
 #define MLX5E_TX_WQE_EMPTY_DS_COUNT (sizeof(struct mlx5e_tx_wqe) / MLX5_SEND_WQE_DS)
 
        }
 }
 
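+/* For GSO skbs in SWP (software parser) offload mode, pre-seed the L4
+ * checksum field with the pseudo-header checksum of a single segment
+ * (gso_size plus the L4 header length) so the HW can complete the csum
+ * of each segment it emits. No-op unless the device reports the
+ * swp_csum_l4_partial capability.
+ */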
+static inline void
+mlx5e_swp_encap_csum_partial(struct mlx5_core_dev *mdev, struct sk_buff *skb, bool tunnel)
+{
+       const struct iphdr *ip = tunnel ? inner_ip_hdr(skb) : ip_hdr(skb);
+       const struct ipv6hdr *ip6;
+       struct tcphdr *th;
+       struct udphdr *uh;
+       int len;
+
+       if (!MLX5_CAP_ETH(mdev, swp_csum_l4_partial) || !skb_is_gso(skb))
+               return;
+
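+       /* Same seed that tcp_v4_send_check() writes for CHECKSUM_PARTIAL,
+        * but computed over a single segment rather than the whole skb.
+        */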
+       if (skb_is_gso_tcp(skb)) {
+               th = inner_tcp_hdr(skb);
+               len = skb_shinfo(skb)->gso_size + inner_tcp_hdrlen(skb);
+
+               if (ip->version == 4) {
+                       th->check = ~tcp_v4_check(len, ip->saddr, ip->daddr, 0);
+               } else {
+                       ip6 = tunnel ? inner_ipv6_hdr(skb) : ipv6_hdr(skb);
+                       th->check = ~tcp_v6_check(len, &ip6->saddr, &ip6->daddr, 0);
+               }
+       } else if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4) {
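+               /* Each UDP GSO segment carries gso_size bytes of payload
+                * behind its own UDP header; seed uh->check to match.
+                */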
+               uh = (struct udphdr *)skb_inner_transport_header(skb);
+               len = skb_shinfo(skb)->gso_size + sizeof(struct udphdr);
+
+               if (ip->version == 4) {
+                       uh->check = ~udp_v4_check(len, ip->saddr, ip->daddr, 0);
+               } else {
+                       ip6 = tunnel ? inner_ipv6_hdr(skb) : ipv6_hdr(skb);
+                       uh->check = ~udp_v6_check(len, &ip6->saddr, &ip6->daddr, 0);
+               }
+       }
+}
+
 #define MLX5E_STOP_ROOM(wqebbs) ((wqebbs) * 2 - 1)
 
 static inline u16 mlx5e_stop_room_for_wqe(struct mlx5_core_dev *mdev, u16 wqe_size)
 
 static inline void
 mlx5e_ipsec_txwqe_build_eseg_csum(struct mlx5e_txqsq *sq, struct sk_buff *skb,
                                  struct mlx5_wqe_eth_seg *eseg)
 {
+       struct mlx5_core_dev *mdev = sq->mdev;
        u8 inner_ipproto;
 
        if (!mlx5e_ipsec_eseg_meta(eseg))
                return;
 
        eseg->cs_flags = MLX5_ETH_WQE_L3_CSUM;
        inner_ipproto = xfrm_offload(skb)->inner_ipproto;
        if (inner_ipproto) {
                eseg->cs_flags |= MLX5_ETH_WQE_L3_INNER_CSUM;
-               if (inner_ipproto == IPPROTO_TCP || inner_ipproto == IPPROTO_UDP)
+               if (inner_ipproto == IPPROTO_TCP || inner_ipproto == IPPROTO_UDP) {
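+                       /* GSO over the tunnel: re-seed the inner L4 csum
+                        * with a per-segment pseudo-header checksum.
+                        */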
+                       mlx5e_swp_encap_csum_partial(mdev, skb, true);
                        eseg->cs_flags |= MLX5_ETH_WQE_L4_INNER_CSUM;
+               }
        } else if (likely(skb->ip_summed == CHECKSUM_PARTIAL)) {
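+               /* No inner protocol: seed against the outer IP header */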
+               mlx5e_swp_encap_csum_partial(mdev, skb, false);
                eseg->cs_flags |= MLX5_ETH_WQE_L4_CSUM;
                sq->stats->csum_partial_inner++;
        }
 
        u8         tunnel_stateless_ip_over_ip_tx[0x1];
        u8         reserved_at_2e[0x2];
        u8         max_vxlan_udp_ports[0x8];
-       u8         reserved_at_38[0x6];
+       u8         swp_csum_l4_partial[0x1];
+       u8         reserved_at_39[0x5];
        u8         max_geneve_opt_len[0x1];
        u8         tunnel_stateless_geneve_rx[0x1];