void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev,
                    const struct iphdr *tnl_params, const u8 protocol);
 void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev,
-                      const u8 proto);
+                      const u8 proto, int tunnel_hlen);
 int ip_tunnel_ioctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd);
 int __ip_tunnel_change_mtu(struct net_device *dev, int new_mtu, bool strict);
 int ip_tunnel_change_mtu(struct net_device *dev, int new_mtu);
 
 
 static int tnl_update_pmtu(struct net_device *dev, struct sk_buff *skb,
                            struct rtable *rt, __be16 df,
-                           const struct iphdr *inner_iph)
+                           const struct iphdr *inner_iph,
+                           int tunnel_hlen, __be32 dst, bool md)
 {
        struct ip_tunnel *tunnel = netdev_priv(dev);
-       int pkt_size = skb->len - tunnel->hlen - dev->hard_header_len;
+       int pkt_size;
        int mtu;
 
+       tunnel_hlen = md ? tunnel_hlen : tunnel->hlen;
+       pkt_size = skb->len - tunnel_hlen - dev->hard_header_len;
+
        if (df)
                mtu = dst_mtu(&rt->dst) - dev->hard_header_len
-                                       - sizeof(struct iphdr) - tunnel->hlen;
+                                       - sizeof(struct iphdr) - tunnel_hlen;
        else
                mtu = skb_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu;
 
 #if IS_ENABLED(CONFIG_IPV6)
        else if (skb->protocol == htons(ETH_P_IPV6)) {
                struct rt6_info *rt6 = (struct rt6_info *)skb_dst(skb);
+               __be32 daddr;
+
+               daddr = md ? dst : tunnel->parms.iph.daddr;
 
                if (rt6 && mtu < dst_mtu(skb_dst(skb)) &&
                           mtu >= IPV6_MIN_MTU) {
-                       if ((tunnel->parms.iph.daddr &&
-                           !ipv4_is_multicast(tunnel->parms.iph.daddr)) ||
+                       if ((daddr && !ipv4_is_multicast(daddr)) ||
                            rt6->rt6i_dst.plen == 128) {
                                rt6->rt6i_flags |= RTF_MODIFIED;
                                dst_metric_set(skb_dst(skb), RTAX_MTU, mtu);
        return 0;
 }
 
-void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, u8 proto)
+void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev,
+                      u8 proto, int tunnel_hlen)
 {
        struct ip_tunnel *tunnel = netdev_priv(dev);
        u32 headroom = sizeof(struct iphdr);
                dev->stats.collisions++;
                goto tx_error;
        }
+
+       if (key->tun_flags & TUNNEL_DONT_FRAGMENT)
+               df = htons(IP_DF);
+       if (tnl_update_pmtu(dev, skb, rt, df, inner_iph, tunnel_hlen,
+                           key->u.ipv4.dst, true)) {
+               ip_rt_put(rt);
+               goto tx_error;
+       }
+
        tos = ip_tunnel_ecn_encap(tos, inner_iph, skb);
        ttl = key->ttl;
        if (ttl == 0) {
                else
                        ttl = ip4_dst_hoplimit(&rt->dst);
        }
-       if (key->tun_flags & TUNNEL_DONT_FRAGMENT)
-               df = htons(IP_DF);
-       else if (skb->protocol == htons(ETH_P_IP))
+
+       if (!df && skb->protocol == htons(ETH_P_IP))
                df = inner_iph->frag_off & htons(IP_DF);
+
        headroom += LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len;
        if (headroom > dev->needed_headroom)
                dev->needed_headroom = headroom;
                goto tx_error;
        }
 
-       if (tnl_update_pmtu(dev, skb, rt, tnl_params->frag_off, inner_iph)) {
+       if (tnl_update_pmtu(dev, skb, rt, tnl_params->frag_off, inner_iph,
+                           0, 0, false)) {
                ip_rt_put(rt);
                goto tx_error;
        }