struct vsock_sock *vsk;
 
        list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table)
-               if (vsock_addr_equals_addr_any(addr, &vsk->local_addr))
+               if (addr->svm_port == vsk->local_addr.svm_port)
                        return sk_vsock(vsk);
 
        return NULL;
 
        list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
                            connected_table) {
-               if (vsock_addr_equals_addr(src, &vsk->remote_addr)
-                   && vsock_addr_equals_addr(dst, &vsk->local_addr)) {
+               if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
+                   dst->svm_port == vsk->local_addr.svm_port) {
                        return sk_vsock(vsk);
                }
        }
 
        struct vsock_sock *vlistener;
        struct vsock_sock *vpending;
        struct sock *pending;
+       struct sockaddr_vm src;
+
+       vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
 
        vlistener = vsock_sk(listener);
 
        list_for_each_entry(vpending, &vlistener->pending_links,
                            pending_links) {
-               struct sockaddr_vm src;
-               struct sockaddr_vm dst;
-
-               vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
-               vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port);
-
                if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
-                   vsock_addr_equals_addr(&dst, &vpending->local_addr)) {
+                   pkt->dst_port == vpending->local_addr.svm_port) {
                        pending = sk_vsock(vpending);
                        sock_hold(pending);
                        goto found;
         */
        bh_lock_sock(sk);
 
-       if (!sock_owned_by_user(sk) && sk->sk_state == SS_CONNECTED)
-               vmci_trans(vsk)->notify_ops->handle_notify_pkt(
-                               sk, pkt, true, &dst, &src,
-                               &bh_process_pkt);
+       if (!sock_owned_by_user(sk)) {
+               /* The local context ID may be out of date, update it. */
+               vsk->local_addr.svm_cid = dst.svm_cid;
+
+               if (sk->sk_state == SS_CONNECTED)
+                       vmci_trans(vsk)->notify_ops->handle_notify_pkt(
+                                       sk, pkt, true, &dst, &src,
+                                       &bh_process_pkt);
+       }
 
        bh_unlock_sock(sk);
 
 
        lock_sock(sk);
 
+       /* The local context ID may be out of date. */
+       vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context;
+
        switch (sk->sk_state) {
        case SS_LISTEN:
                vmci_transport_recv_listen(sk, pkt);
        pending = vmci_transport_get_pending(sk, pkt);
        if (pending) {
                lock_sock(pending);
+
+               /* The local context ID may be out of date. */
+               vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context;
+
                switch (pending->sk_state) {
                case SS_CONNECTING:
                        err = vmci_transport_recv_connecting_server(sk,
 
 }
 EXPORT_SYMBOL_GPL(vsock_addr_equals_addr);
 
-bool vsock_addr_equals_addr_any(const struct sockaddr_vm *addr,
-                               const struct sockaddr_vm *other)
-{
-       return (addr->svm_cid == VMADDR_CID_ANY ||
-               other->svm_cid == VMADDR_CID_ANY ||
-               addr->svm_cid == other->svm_cid) &&
-              addr->svm_port == other->svm_port;
-}
-EXPORT_SYMBOL_GPL(vsock_addr_equals_addr_any);
-
 int vsock_addr_cast(const struct sockaddr *addr,
                    size_t len, struct sockaddr_vm **out_addr)
 {
 
 void vsock_addr_unbind(struct sockaddr_vm *addr);
 bool vsock_addr_equals_addr(const struct sockaddr_vm *addr,
                            const struct sockaddr_vm *other);
-bool vsock_addr_equals_addr_any(const struct sockaddr_vm *addr,
-                               const struct sockaddr_vm *other);
 int vsock_addr_cast(const struct sockaddr *addr, size_t len,
                    struct sockaddr_vm **out_addr);