struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
 struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
                                         struct sockaddr_vm *dst);
+void vsock_remove_sock(struct vsock_sock *vsk);
 void vsock_for_each_connected_socket(void (*fn)(struct sock *sk));
 
 #endif /* __AF_VSOCK_H__ */
 
        return ret;
 }
 
+void vsock_remove_sock(struct vsock_sock *vsk)
+{
+       if (vsock_in_bound_table(vsk))
+               vsock_remove_bound(vsk);
+
+       if (vsock_in_connected_table(vsk))
+               vsock_remove_connected(vsk);
+}
+EXPORT_SYMBOL_GPL(vsock_remove_sock);
+
 void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
 {
        int i;
                vsk = vsock_sk(sk);
                pending = NULL; /* Compiler warning. */
 
-               if (vsock_in_bound_table(vsk))
-                       vsock_remove_bound(vsk);
-
-               if (vsock_in_connected_table(vsk))
-                       vsock_remove_connected(vsk);
-
                transport->release(vsk);
 
                lock_sock(sk);
 
 
 static void vmci_transport_release(struct vsock_sock *vsk)
 {
+       vsock_remove_sock(vsk);
+
        if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) {
                vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle);
                vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;