bool tcp_ao_ignore_icmp(const struct sock *sk, int family, int type, int code);
 int tcp_ao_get_mkts(struct sock *sk, sockptr_t optval, sockptr_t optlen);
 int tcp_ao_get_sock_info(struct sock *sk, sockptr_t optval, sockptr_t optlen);
+int tcp_ao_get_repair(struct sock *sk, sockptr_t optval, sockptr_t optlen);
+int tcp_ao_set_repair(struct sock *sk, sockptr_t optval, unsigned int optlen);
 enum skb_drop_reason tcp_inbound_ao_hash(struct sock *sk,
                        const struct sk_buff *skb, unsigned short int family,
                        const struct request_sock *req, int l3index,
 {
        return -ENOPROTOOPT;
 }
+
+static inline int tcp_ao_get_repair(struct sock *sk,
+                                   sockptr_t optval, sockptr_t optlen)
+{
+       return -ENOPROTOOPT;
+}
+
+static inline int tcp_ao_set_repair(struct sock *sk,
+                                   sockptr_t optval, unsigned int optlen)
+{
+       return -ENOPROTOOPT;
+}
 #endif
 
 #if defined(CONFIG_TCP_MD5SIG) || defined(CONFIG_TCP_AO)
 
 #define TCP_AO_DEL_KEY         39      /* Delete MKT */
 #define TCP_AO_INFO            40      /* Set/list TCP-AO per-socket options */
 #define TCP_AO_GET_KEYS                41      /* List MKT(s) */
+#define TCP_AO_REPAIR          42      /* Get/Set SNEs and ISNs */
 
 #define TCP_REPAIR_ON          1
 #define TCP_REPAIR_OFF         0
        __u64   pkt_bad;                /* out: segments that failed verification */
 } __attribute__((aligned(8)));
 
+struct tcp_ao_repair { /* {s,g}etsockopt(TCP_AO_REPAIR) */
+       __be32                  snt_isn;
+       __be32                  rcv_isn;
+       __u32                   snd_sne;
+       __u32                   rcv_sne;
+} __attribute__((aligned(8)));
+
 /* setsockopt(fd, IPPROTO_TCP, TCP_ZEROCOPY_RECEIVE, ...) */
 
 #define TCP_RECEIVE_ZEROCOPY_FLAG_TLB_CLEAN_HINT 0x1
 
                __tcp_sock_set_quickack(sk, val);
                break;
 
+       case TCP_AO_REPAIR:
+               err = tcp_ao_set_repair(sk, optval, optlen);
+               break;
 #ifdef CONFIG_TCP_AO
        case TCP_AO_ADD_KEY:
        case TCP_AO_DEL_KEY:
        case TCP_AO_INFO: {
                /* If this is the first TCP-AO setsockopt() on the socket,
-                * sk_state has to be LISTEN or CLOSE
+                * sk_state has to be LISTEN or CLOSE. Allow TCP_REPAIR
+                * in any state.
                 */
-               if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) ||
-                   rcu_dereference_protected(tcp_sk(sk)->ao_info,
+               if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE))
+                       goto ao_parse;
+               if (rcu_dereference_protected(tcp_sk(sk)->ao_info,
                                              lockdep_sock_is_held(sk)))
-                       err = tp->af_specific->ao_parse(sk, optname, optval,
-                                                       optlen);
-               else
-                       err = -EISCONN;
+                       goto ao_parse;
+               if (tp->repair)
+                       goto ao_parse;
+               err = -EISCONN;
+               break;
+ao_parse:
+               err = tp->af_specific->ao_parse(sk, optname, optval, optlen);
                break;
        }
 #endif
                return err;
        }
 #endif
+       case TCP_AO_REPAIR:
+               return tcp_ao_get_repair(sk, optval, optlen);
        case TCP_AO_GET_KEYS:
        case TCP_AO_INFO: {
                int err;
 
        return ERR_PTR(-ESOCKTNOSUPPORT);
 }
 
+static struct tcp_ao_info *getsockopt_ao_info(struct sock *sk)
+{
+       if (sk_fullsock(sk))
+               return rcu_dereference(tcp_sk(sk)->ao_info);
+       else if (sk->sk_state == TCP_TIME_WAIT)
+               return rcu_dereference(tcp_twsk(sk)->ao_info);
+
+       return ERR_PTR(-ESOCKTNOSUPPORT);
+}
+
 #define TCP_AO_KEYF_ALL (TCP_AO_KEYF_IFINDEX | TCP_AO_KEYF_EXCLUDE_OPT)
 #define TCP_AO_GET_KEYF_VALID  (TCP_AO_KEYF_IFINDEX)
 
        if (ret < 0)
                goto err_free_sock;
 
-       /* Change this condition if we allow adding keys in states
-        * like close_wait, syn_sent or fin_wait...
-        */
-       if (sk->sk_state == TCP_ESTABLISHED)
+       if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE))) {
                tcp_ao_cache_traffic_keys(sk, ao_info, key);
+               if (first) {
+                       ao_info->current_key = key;
+                       ao_info->rnext_key = key;
+               }
+       }
 
        tcp_ao_link_mkt(ao_info, key);
        if (first) {
        if (IS_ERR(ao_info))
                return PTR_ERR(ao_info);
        if (!ao_info) {
+               if (!((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)))
+                       return -EINVAL;
                ao_info = tcp_ao_alloc_info(GFP_KERNEL);
                if (!ao_info)
                        return -ENOMEM;
        return 0;
 }
 
+int tcp_ao_set_repair(struct sock *sk, sockptr_t optval, unsigned int optlen)
+{
+       struct tcp_sock *tp = tcp_sk(sk);
+       struct tcp_ao_repair cmd;
+       struct tcp_ao_key *key;
+       struct tcp_ao_info *ao;
+       int err;
+
+       if (optlen < sizeof(cmd))
+               return -EINVAL;
+
+       err = copy_struct_from_sockptr(&cmd, sizeof(cmd), optval, optlen);
+       if (err)
+               return err;
+
+       if (!tp->repair)
+               return -EPERM;
+
+       ao = setsockopt_ao_info(sk);
+       if (IS_ERR(ao))
+               return PTR_ERR(ao);
+       if (!ao)
+               return -ENOENT;
+
+       WRITE_ONCE(ao->lisn, cmd.snt_isn);
+       WRITE_ONCE(ao->risn, cmd.rcv_isn);
+       WRITE_ONCE(ao->snd_sne, cmd.snd_sne);
+       WRITE_ONCE(ao->rcv_sne, cmd.rcv_sne);
+
+       hlist_for_each_entry_rcu(key, &ao->head, node)
+               tcp_ao_cache_traffic_keys(sk, ao, key);
+
+       return 0;
+}
+
+int tcp_ao_get_repair(struct sock *sk, sockptr_t optval, sockptr_t optlen)
+{
+       struct tcp_sock *tp = tcp_sk(sk);
+       struct tcp_ao_repair opt;
+       struct tcp_ao_info *ao;
+       int len;
+
+       if (copy_from_sockptr(&len, optlen, sizeof(int)))
+               return -EFAULT;
+
+       if (len <= 0)
+               return -EINVAL;
+
+       if (!tp->repair)
+               return -EPERM;
+
+       rcu_read_lock();
+       ao = getsockopt_ao_info(sk);
+       if (IS_ERR_OR_NULL(ao)) {
+               rcu_read_unlock();
+               return ao ? PTR_ERR(ao) : -ENOENT;
+       }
+
+       opt.snt_isn     = ao->lisn;
+       opt.rcv_isn     = ao->risn;
+       opt.snd_sne     = READ_ONCE(ao->snd_sne);
+       opt.rcv_sne     = READ_ONCE(ao->rcv_sne);
+       rcu_read_unlock();
+
+       if (copy_to_sockptr(optval, &opt, min_t(int, len, sizeof(opt))))
+               return -EFAULT;
+       return 0;
+}