const struct sock *sk, const struct sk_buff *skb);
 int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
                   int family, u8 prefixlen, int l3index, u8 flags,
-                  const u8 *newkey, u8 newkeylen, gfp_t gfp);
+                  const u8 *newkey, u8 newkeylen);
+int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
+                    int family, u8 prefixlen, int l3index,
+                    struct tcp_md5sig_key *key);
+
 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
                   int family, u8 prefixlen, int l3index, u8 flags);
 struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
 
 #ifdef CONFIG_TCP_MD5SIG
 #include <linux/jump_label.h>
-extern struct static_key_false tcp_md5_needed;
+extern struct static_key_false_deferred tcp_md5_needed;
 struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
                                           const union tcp_md5_addr *addr,
                                           int family);
 tcp_md5_do_lookup(const struct sock *sk, int l3index,
                  const union tcp_md5_addr *addr, int family)
 {
-       if (!static_branch_unlikely(&tcp_md5_needed))
+       if (!static_branch_unlikely(&tcp_md5_needed.key))
                return NULL;
        return __tcp_md5_do_lookup(sk, l3index, addr, family);
 }
 
  * We need to maintain these in the sk structure.
  */
 
-DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
+DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ);
 EXPORT_SYMBOL(tcp_md5_needed);
 
 static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new)
        struct tcp_sock *tp = tcp_sk(sk);
        struct tcp_md5sig_info *md5sig;
 
-       if (rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)))
-               return 0;
-
        md5sig = kmalloc(sizeof(*md5sig), gfp);
        if (!md5sig)
                return -ENOMEM;
 }
 
 /* This can be called on a newly created socket, from other files */
-int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
-                  int family, u8 prefixlen, int l3index, u8 flags,
-                  const u8 *newkey, u8 newkeylen, gfp_t gfp)
+static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
+                           int family, u8 prefixlen, int l3index, u8 flags,
+                           const u8 *newkey, u8 newkeylen, gfp_t gfp)
 {
        /* Add Key to the list */
        struct tcp_md5sig_key *key;
                return 0;
        }
 
-       if (tcp_md5sig_info_add(sk, gfp))
-               return -ENOMEM;
-
        md5sig = rcu_dereference_protected(tp->md5sig_info,
                                           lockdep_sock_is_held(sk));
 
        hlist_add_head_rcu(&key->node, &md5sig->head);
        return 0;
 }
+
+int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
+                  int family, u8 prefixlen, int l3index, u8 flags,
+                  const u8 *newkey, u8 newkeylen)
+{
+       struct tcp_sock *tp = tcp_sk(sk);
+
+       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
+               if (tcp_md5sig_info_add(sk, GFP_KERNEL))
+                       return -ENOMEM;
+
+               if (!static_branch_inc(&tcp_md5_needed.key)) {
+                       struct tcp_md5sig_info *md5sig;
+
+                       md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk));
+                       rcu_assign_pointer(tp->md5sig_info, NULL);
+                       kfree_rcu(md5sig);
+                       return -EUSERS;
+               }
+       }
+
+       return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
+                               newkey, newkeylen, GFP_KERNEL);
+}
 EXPORT_SYMBOL(tcp_md5_do_add);
 
+int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
+                    int family, u8 prefixlen, int l3index,
+                    struct tcp_md5sig_key *key)
+{
+       struct tcp_sock *tp = tcp_sk(sk);
+
+       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
+               if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
+                       return -ENOMEM;
+
+               if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) {
+                       struct tcp_md5sig_info *md5sig;
+
+                       md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk));
+                       net_warn_ratelimited("Too many TCP-MD5 keys in the system\n");
+                       rcu_assign_pointer(tp->md5sig_info, NULL);
+                       kfree_rcu(md5sig);
+                       return -EUSERS;
+               }
+       }
+
+       return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index,
+                               key->flags, key->key, key->keylen,
+                               sk_gfp_mask(sk, GFP_ATOMIC));
+}
+EXPORT_SYMBOL(tcp_md5_key_copy);
+
 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
                   u8 prefixlen, int l3index, u8 flags)
 {
                return -EINVAL;
 
        return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
-                             cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
+                             cmd.tcpm_key, cmd.tcpm_keylen);
 }
 
 static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
                 * memory, then we end up not copying the key
                 * across. Shucks.
                 */
-               tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
-                              key->key, key->keylen, GFP_ATOMIC);
+               tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key);
                sk_gso_disable(newsk);
        }
 #endif
                tcp_clear_md5_list(sk);
                kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu);
                tp->md5sig_info = NULL;
+               static_branch_slow_dec_deferred(&tcp_md5_needed);
        }
 #endif
 
 
                 */
                do {
                        tcptw->tw_md5_key = NULL;
-                       if (static_branch_unlikely(&tcp_md5_needed)) {
+                       if (static_branch_unlikely(&tcp_md5_needed.key)) {
                                struct tcp_md5sig_key *key;
 
                                key = tp->af_specific->md5_lookup(sk, sk);
                                if (key) {
                                        tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
-                                       BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
+                                       if (!tcptw->tw_md5_key)
+                                               break;
+                                       BUG_ON(!tcp_alloc_md5sig_pool());
+                                       if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) {
+                                               kfree(tcptw->tw_md5_key);
+                                               tcptw->tw_md5_key = NULL;
+                                       }
                                }
                        }
                } while (0);
 void tcp_twsk_destructor(struct sock *sk)
 {
 #ifdef CONFIG_TCP_MD5SIG
-       if (static_branch_unlikely(&tcp_md5_needed)) {
+       if (static_branch_unlikely(&tcp_md5_needed.key)) {
                struct tcp_timewait_sock *twsk = tcp_twsk(sk);
 
-               if (twsk->tw_md5_key)
+               if (twsk->tw_md5_key) {
                        kfree_rcu(twsk->tw_md5_key, rcu);
+                       static_branch_slow_dec_deferred(&tcp_md5_needed);
+               }
        }
 #endif
 }
 
        if (ipv6_addr_v4mapped(&sin6->sin6_addr))
                return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
                                      AF_INET, prefixlen, l3index, flags,
-                                     cmd.tcpm_key, cmd.tcpm_keylen,
-                                     GFP_KERNEL);
+                                     cmd.tcpm_key, cmd.tcpm_keylen);
 
        return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
                              AF_INET6, prefixlen, l3index, flags,
-                             cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
+                             cmd.tcpm_key, cmd.tcpm_keylen);
 }
 
 static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
                 * memory, then we end up not copying the key
                 * across. Shucks.
                 */
-               tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
-                              AF_INET6, 128, l3index, key->flags, key->key, key->keylen,
-                              sk_gfp_mask(sk, GFP_ATOMIC));
+               tcp_md5_key_copy(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
+                                AF_INET6, 128, l3index, key);
        }
 #endif