}
 EXPORT_SYMBOL(reuseport_has_conns_set);
 
+static void __reuseport_get_incoming_cpu(struct sock_reuseport *reuse)
+{
+       /* Paired with READ_ONCE() in reuseport_select_sock_by_hash(). */
+       WRITE_ONCE(reuse->incoming_cpu, reuse->incoming_cpu + 1);
+}
+
+static void __reuseport_put_incoming_cpu(struct sock_reuseport *reuse)
+{
+       /* Paired with READ_ONCE() in reuseport_select_sock_by_hash(). */
+       WRITE_ONCE(reuse->incoming_cpu, reuse->incoming_cpu - 1);
+}
+
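+/* reuse->incoming_cpu is the number of sockets in the group with a
+ * valid sk_incoming_cpu (>= 0).  When it is zero, the selection fast
+ * path can skip the per-CPU check entirely.
+ */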
+static void reuseport_get_incoming_cpu(struct sock *sk, struct sock_reuseport *reuse)
+{
+       if (sk->sk_incoming_cpu >= 0)
+               __reuseport_get_incoming_cpu(reuse);
+}
+
+static void reuseport_put_incoming_cpu(struct sock *sk, struct sock_reuseport *reuse)
+{
+       if (sk->sk_incoming_cpu >= 0)
+               __reuseport_put_incoming_cpu(reuse);
+}
+
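+/* Keep sk->sk_incoming_cpu and the group-wide counter in sync.  This
+ * is the writer side for the SO_INCOMING_CPU setsockopt(), e.g.:
+ *
+ *   int cpu = 0;
+ *
+ *   setsockopt(fd, SOL_SOCKET, SO_INCOMING_CPU, &cpu, sizeof(cpu));
+ */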
+void reuseport_update_incoming_cpu(struct sock *sk, int val)
+{
+       struct sock_reuseport *reuse;
+       int old_sk_incoming_cpu;
+
+       if (unlikely(!rcu_access_pointer(sk->sk_reuseport_cb))) {
+               /* Paired with READ_ONCE() in sk_incoming_cpu_update()
+                * and compute_score().
+                */
+               WRITE_ONCE(sk->sk_incoming_cpu, val);
+               return;
+       }
+
+       spin_lock_bh(&reuseport_lock);
+
+       /* This must be done under reuseport_lock to avoid a race with
+        * reuseport_grow(), which accesses sk->sk_incoming_cpu without
+        * lock_sock() when detaching a shutdown()ed sk.
+        *
+        * Paired with READ_ONCE() in reuseport_select_sock_by_hash().
+        */
+       old_sk_incoming_cpu = sk->sk_incoming_cpu;
+       WRITE_ONCE(sk->sk_incoming_cpu, val);
+
+       reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
+                                         lockdep_is_held(&reuseport_lock));
+
+       /* reuseport_grow() has detached a closed sk. */
+       if (!reuse)
+               goto out;
+
+       if (old_sk_incoming_cpu < 0 && val >= 0)
+               __reuseport_get_incoming_cpu(reuse);
+       else if (old_sk_incoming_cpu >= 0 && val < 0)
+               __reuseport_put_incoming_cpu(reuse);
+
+out:
+       spin_unlock_bh(&reuseport_lock);
+}
+
 static int reuseport_sock_index(struct sock *sk,
                                const struct sock_reuseport *reuse,
                                bool closed)
        /* paired with smp_rmb() in reuseport_(select|migrate)_sock() */
        smp_wmb();
        reuse->num_socks++;
+       reuseport_get_incoming_cpu(sk, reuse);
 }
 
 static bool __reuseport_detach_sock(struct sock *sk,
 
        reuse->socks[i] = reuse->socks[reuse->num_socks - 1];
        reuse->num_socks--;
+       reuseport_put_incoming_cpu(sk, reuse);
 
        return true;
 }
        reuse->socks[reuse->max_socks - reuse->num_closed_socks - 1] = sk;
        /* paired with READ_ONCE() in inet_csk_bind_conflict() */
        WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks + 1);
+       reuseport_get_incoming_cpu(sk, reuse);
 }
 
 static bool __reuseport_detach_closed_sock(struct sock *sk,
        reuse->socks[i] = reuse->socks[reuse->max_socks - reuse->num_closed_socks];
        /* paired with READ_ONCE() in inet_csk_bind_conflict() */
        WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks - 1);
+       reuseport_put_incoming_cpu(sk, reuse);
 
        return true;
 }
        reuse->bind_inany = bind_inany;
        reuse->socks[0] = sk;
        reuse->num_socks = 1;
+       reuseport_get_incoming_cpu(sk, reuse);
        rcu_assign_pointer(sk->sk_reuseport_cb, reuse);
 
 out:
        more_reuse->reuseport_id = reuse->reuseport_id;
        more_reuse->bind_inany = reuse->bind_inany;
        more_reuse->has_conns = reuse->has_conns;
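+       /* more_reuse is not yet visible to readers, so a plain copy is fine. */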
+       more_reuse->incoming_cpu = reuse->incoming_cpu;
 
        memcpy(more_reuse->socks, reuse->socks,
               reuse->num_socks * sizeof(struct sock *));
 static struct sock *reuseport_select_sock_by_hash(struct sock_reuseport *reuse,
                                                  u32 hash, u16 num_socks)
 {
+       struct sock *first_valid_sk = NULL;
        int i, j;
 
        i = j = reciprocal_scale(hash, num_socks);
-       while (reuse->socks[i]->sk_state == TCP_ESTABLISHED) {
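+       /* Scan the group from a hash-chosen index: prefer a socket
+        * bound to the current CPU, and fall back to the first
+        * non-TCP_ESTABLISHED socket found.
+        */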
+       do {
+               struct sock *sk = reuse->socks[i];
+
+               if (sk->sk_state != TCP_ESTABLISHED) {
+                       /* Paired with WRITE_ONCE() in __reuseport_(get|put)_incoming_cpu(). */
+                       if (!READ_ONCE(reuse->incoming_cpu))
+                               return sk;
+
+                       /* Paired with WRITE_ONCE() in reuseport_update_incoming_cpu(). */
+                       if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
+                               return sk;
+
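+                       /* Remember the first non-established socket as a fallback. */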
+                       if (!first_valid_sk)
+                               first_valid_sk = sk;
+               }
+
                i++;
                if (i >= num_socks)
                        i = 0;
-               if (i == j)
-                       return NULL;
-       }
+       } while (i != j);
 
-       return reuse->socks[i];
+       return first_valid_sk;
 }
 
 /**