 #define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256)
 #define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128
 
-static const struct vsock_transport *transport_single;
+/* Transport used for host->guest communication */
+static const struct vsock_transport *transport_h2g;
+/* Transport used for guest->host communication */
+static const struct vsock_transport *transport_g2h;
+/* Transport used for DGRAM communication */
+static const struct vsock_transport *transport_dgram;
 static DEFINE_MUTEX(vsock_register_mutex);
 
 /**** UTILS ****/
        return __vsock_bind(sk, &local_addr);
 }
 
-static int __init vsock_init_tables(void)
+static void vsock_init_tables(void)
 {
        int i;
 
 
        for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++)
                INIT_LIST_HEAD(&vsock_connected_table[i]);
-       return 0;
 }
 
 static void __vsock_insert_bound(struct list_head *list,
 }
 EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
 
+/* Assign a transport to a socket and call the .init transport callback.
+ *
+ * Note: for stream sockets this must be called when vsk->remote_addr is set
+ * (e.g. during connect() or when a connection request is received on a
+ * listener socket).
+ * The vsk->remote_addr is used to decide which transport to use:
+ *  - remote CID <= VMADDR_CID_HOST will use the guest->host transport;
+ *  - remote CID == local CID of the guest->host transport will also use
+ *    the guest->host transport, for loopback (host->guest transports don't
+ *    support loopback);
+ *  - remote CID > VMADDR_CID_HOST will use the host->guest transport;
+ */
+int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
+{
+       const struct vsock_transport *new_transport;
+       struct sock *sk = sk_vsock(vsk);
+       unsigned int remote_cid = vsk->remote_addr.svm_cid;
+
+       switch (sk->sk_type) {
+       case SOCK_DGRAM:
+               new_transport = transport_dgram;
+               break;
+       case SOCK_STREAM:
+               if (remote_cid <= VMADDR_CID_HOST ||
+                   (transport_g2h &&
+                    remote_cid == transport_g2h->get_local_cid()))
+                       new_transport = transport_g2h;
+               else
+                       new_transport = transport_h2g;
+               break;
+       default:
+               return -ESOCKTNOSUPPORT;
+       }
+
+       if (vsk->transport) {
+               if (vsk->transport == new_transport)
+                       return 0;
+
+               vsk->transport->release(vsk);
+               vsk->transport->destruct(vsk);
+       }
+
+       if (!new_transport)
+               return -ENODEV;
+
+       vsk->transport = new_transport;
+
+       return vsk->transport->init(vsk, psk);
+}
+EXPORT_SYMBOL_GPL(vsock_assign_transport);
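+
+/* Worked example (illustrative only, the CIDs are hypothetical): in a guest
+ * where transport_g2h reports a local CID of 3, a SOCK_STREAM socket is
+ * assigned as follows:
+ *
+ *     remote CID 2 (VMADDR_CID_HOST) -> transport_g2h
+ *     remote CID 3 (loopback)        -> transport_g2h
+ *     remote CID 42 (nested guest)   -> transport_h2g
+ */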
+
+bool vsock_find_cid(unsigned int cid)
+{
+       if (transport_g2h && cid == transport_g2h->get_local_cid())
+               return true;
+
+       if (transport_h2g && cid == VMADDR_CID_HOST)
+               return true;
+
+       return false;
+}
+EXPORT_SYMBOL_GPL(vsock_find_cid);
+
 static struct sock *vsock_dequeue_accept(struct sock *listener)
 {
        struct vsock_sock *vlistener;
 {
        struct vsock_sock *vsk = vsock_sk(sk);
 
+       if (!vsk->transport)
+               return -ENODEV;
+
        return vsk->transport->shutdown(vsk, mode);
 }
 
 static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
 {
        struct vsock_sock *vsk = vsock_sk(sk);
-       u32 cid;
        int retval;
 
        /* First ensure this socket isn't already bound. */
        /* Now bind to the provided address or select appropriate values if
         * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY).  Note that
         * like AF_INET prevents binding to a non-local IP address (in most
-        * cases), we only allow binding to the local CID.
+        * cases), we only allow binding to a local CID.
         */
-       cid = vsk->transport->get_local_cid();
-       if (addr->svm_cid != cid && addr->svm_cid != VMADDR_CID_ANY)
+       if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid))
                return -EADDRNOTAVAIL;
 
        switch (sk->sk_socket->type) {
                sk->sk_type = type;
 
        vsk = vsock_sk(sk);
-       vsk->transport = transport_single;
        vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
        vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY);
 
                vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE;
        }
 
-       if (vsk->transport->init(vsk, psk) < 0) {
-               sk_free(sk);
-               return NULL;
-       }
-
        return sk;
 }
 
                /* The release call is supposed to use lock_sock_nested()
                 * rather than lock_sock(), if a sock lock should be acquired.
                 */
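+               /* The transport may not be assigned yet, e.g. for a stream
+                * socket that was created but never connected; in that case
+                * it only needs to be removed from the socket tables.
+                */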
-               vsk->transport->release(vsk);
+               if (vsk->transport)
+                       vsk->transport->release(vsk);
+               else if (sk->sk_type == SOCK_STREAM)
+                       vsock_remove_sock(vsk);
 
                /* When "level" is SINGLE_DEPTH_NESTING, use the nested
                 * version to avoid the warning "possible recursive locking
 {
        struct vsock_sock *vsk = vsock_sk(sk);
 
-       vsk->transport->destruct(vsk);
+       if (vsk->transport)
+               vsk->transport->destruct(vsk);
 
        /* When clearing these addresses, there's no need to set the family and
         * possibly register the address family with the kernel.
                        mask |= EPOLLIN | EPOLLRDNORM;
 
                /* If there is something in the queue then we can read. */
-               if (transport->stream_is_active(vsk) &&
+               if (transport && transport->stream_is_active(vsk) &&
                    !(sk->sk_shutdown & RCV_SHUTDOWN)) {
                        bool data_ready_now = false;
                        int ret = transport->notify_poll_in(
        err = 0;
        sk = sock->sk;
        vsk = vsock_sk(sk);
-       transport = vsk->transport;
 
        lock_sock(sk);
 
                        goto out;
                }
 
+               /* Set the remote address that we are connecting to. */
+               memcpy(&vsk->remote_addr, remote_addr,
+                      sizeof(vsk->remote_addr));
+
+               err = vsock_assign_transport(vsk, NULL);
+               if (err)
+                       goto out;
+
+               transport = vsk->transport;
+
                /* The hypervisor and well-known contexts do not have socket
                 * endpoints.
                 */
-               if (!transport->stream_allow(remote_addr->svm_cid,
+               if (!transport ||
+                   !transport->stream_allow(remote_addr->svm_cid,
                                             remote_addr->svm_port)) {
                        err = -ENETUNREACH;
                        goto out;
                }
 
-               /* Set the remote address that we are connecting to. */
-               memcpy(&vsk->remote_addr, remote_addr,
-                      sizeof(vsk->remote_addr));
-
                err = vsock_auto_bind(vsk);
                if (err)
                        goto out;
                goto out;
        }
 
-       if (sk->sk_state != TCP_ESTABLISHED ||
+       if (!transport || sk->sk_state != TCP_ESTABLISHED ||
            !vsock_addr_bound(&vsk->local_addr)) {
                err = -ENOTCONN;
                goto out;
 
        lock_sock(sk);
 
-       if (sk->sk_state != TCP_ESTABLISHED) {
+       if (!transport || sk->sk_state != TCP_ESTABLISHED) {
                /* Recvmsg is supposed to return 0 if a peer performs an
                 * orderly shutdown. Differentiate between that case and when a
                 * peer has not connected or a local shutdown occurred with the
 static int vsock_create(struct net *net, struct socket *sock,
                        int protocol, int kern)
 {
+       struct vsock_sock *vsk;
        struct sock *sk;
+       int ret;
 
        if (!sock)
                return -EINVAL;
        if (!sk)
                return -ENOMEM;
 
-       vsock_insert_unbound(vsock_sk(sk));
+       vsk = vsock_sk(sk);
+
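+       /* Unlike stream sockets, SOCK_DGRAM does not need the remote address
+        * to choose a transport, so it can be assigned at creation time.
+        */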
+       if (sock->type == SOCK_DGRAM) {
+               ret = vsock_assign_transport(vsk, NULL);
+               if (ret < 0) {
+                       sock_put(sk);
+                       return ret;
+               }
+       }
+
+       vsock_insert_unbound(vsk);
 
        return 0;
 }
                               unsigned int cmd, void __user *ptr)
 {
        u32 __user *p = ptr;
+       u32 cid = VMADDR_CID_ANY;
        int retval = 0;
 
        switch (cmd) {
        case IOCTL_VM_SOCKETS_GET_LOCAL_CID:
-               if (put_user(transport_single->get_local_cid(), p) != 0)
+               /* To be compatible with the VMCI behavior, we prioritize the
+                * guest CID instead of the well-known host CID
+                * (VMADDR_CID_HOST).
+                */
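+               /* Worked example (illustrative, the guest CID is
+                * hypothetical): on a nested host that is itself a guest
+                * with CID 5, the ioctl returns 5, not VMADDR_CID_HOST.
+                */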
+               if (transport_g2h)
+                       cid = transport_g2h->get_local_cid();
+               else if (transport_h2g)
+                       cid = transport_h2g->get_local_cid();
+
+               if (put_user(cid, p) != 0)
                        retval = -EFAULT;
                break;
 
        .fops           = &vsock_device_ops,
 };
 
-int __vsock_core_init(const struct vsock_transport *t, struct module *owner)
+static int __init vsock_init(void)
 {
-       int err = mutex_lock_interruptible(&vsock_register_mutex);
+       int err = 0;
 
-       if (err)
-               return err;
-
-       if (transport_single) {
-               err = -EBUSY;
-               goto err_busy;
-       }
-
-       /* Transport must be the owner of the protocol so that it can't
-        * unload while there are open sockets.
-        */
-       vsock_proto.owner = owner;
-       transport_single = t;
+       vsock_init_tables();
 
+       vsock_proto.owner = THIS_MODULE;
        vsock_device.minor = MISC_DYNAMIC_MINOR;
        err = misc_register(&vsock_device);
        if (err) {
                goto err_unregister_proto;
        }
 
-       mutex_unlock(&vsock_register_mutex);
        return 0;
 
 err_unregister_proto:
 err_deregister_misc:
        misc_deregister(&vsock_device);
 err_reset_transport:
-       transport_single = NULL;
-err_busy:
-       mutex_unlock(&vsock_register_mutex);
        return err;
 }
-EXPORT_SYMBOL_GPL(__vsock_core_init);
 
-void vsock_core_exit(void)
+static void __exit vsock_exit(void)
 {
-       mutex_lock(&vsock_register_mutex);
-
        misc_deregister(&vsock_device);
        sock_unregister(AF_VSOCK);
        proto_unregister(&vsock_proto);
-
-       /* We do not want the assignment below re-ordered. */
-       mb();
-       transport_single = NULL;
-
-       mutex_unlock(&vsock_register_mutex);
 }
-EXPORT_SYMBOL_GPL(vsock_core_exit);
 
 const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk)
 {
 }
 EXPORT_SYMBOL_GPL(vsock_core_get_transport);
 
-static void __exit vsock_exit(void)
+int vsock_core_register(const struct vsock_transport *t, int features)
+{
+       const struct vsock_transport *t_h2g, *t_g2h, *t_dgram;
+       int err = mutex_lock_interruptible(&vsock_register_mutex);
+
+       if (err)
+               return err;
+
+       t_h2g = transport_h2g;
+       t_g2h = transport_g2h;
+       t_dgram = transport_dgram;
+
+       if (features & VSOCK_TRANSPORT_F_H2G) {
+               if (t_h2g) {
+                       err = -EBUSY;
+                       goto err_busy;
+               }
+               t_h2g = t;
+       }
+
+       if (features & VSOCK_TRANSPORT_F_G2H) {
+               if (t_g2h) {
+                       err = -EBUSY;
+                       goto err_busy;
+               }
+               t_g2h = t;
+       }
+
+       if (features & VSOCK_TRANSPORT_F_DGRAM) {
+               if (t_dgram) {
+                       err = -EBUSY;
+                       goto err_busy;
+               }
+               t_dgram = t;
+       }
+
+       transport_h2g = t_h2g;
+       transport_g2h = t_g2h;
+       transport_dgram = t_dgram;
+
+err_busy:
+       mutex_unlock(&vsock_register_mutex);
+       return err;
+}
+EXPORT_SYMBOL_GPL(vsock_core_register);
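+
+/* Registration sketch (illustrative; my_transport is a placeholder): a
+ * transport serving host->guest and datagram sockets, as the VMCI
+ * transport does when running on a host, would register with:
+ *
+ *     vsock_core_register(&my_transport,
+ *                         VSOCK_TRANSPORT_F_H2G | VSOCK_TRANSPORT_F_DGRAM);
+ */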
+
+void vsock_core_unregister(const struct vsock_transport *t)
 {
-       /* Do nothing.  This function makes this module removable. */
+       mutex_lock(&vsock_register_mutex);
+
+       if (transport_h2g == t)
+               transport_h2g = NULL;
+
+       if (transport_g2h == t)
+               transport_g2h = NULL;
+
+       if (transport_dgram == t)
+               transport_dgram = NULL;
+
+       mutex_unlock(&vsock_register_mutex);
 }
+EXPORT_SYMBOL_GPL(vsock_core_unregister);
 
-module_init(vsock_init_tables);
+module_init(vsock_init);
 module_exit(vsock_exit);
 
 MODULE_AUTHOR("VMware, Inc.");
 
 
        vsk->trans = vvs;
        vvs->vsk = vsk;
-       if (psk) {
+       if (psk && psk->trans) {
                struct virtio_vsock_sock *ptrans = psk->trans;
 
                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
        return virtio_transport_send_pkt_info(vsk, &info);
 }
 
+static bool virtio_transport_space_update(struct sock *sk,
+                                         struct virtio_vsock_pkt *pkt)
+{
+       struct vsock_sock *vsk = vsock_sk(sk);
+       struct virtio_vsock_sock *vvs = vsk->trans;
+       bool space_available;
+
+       /* Listener sockets are not associated with any transport, so we
+        * cannot check whether the remote peer has space available. Since
+        * they are only used to receive connection requests, we can assume
+        * that there is always space available in the other peer.
+        */
+       if (!vvs)
+               return true;
+
+       /* buf_alloc and fwd_cnt are always included in the hdr */
+       spin_lock_bh(&vvs->tx_lock);
+       vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
+       vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
+       space_available = virtio_transport_has_space(vsk);
+       spin_unlock_bh(&vvs->tx_lock);
+       return space_available;
+}
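+
+/* Credit sketch (simplified, using struct virtio_vsock_sock fields): the
+ * sender still has space while
+ *
+ *     peer_buf_alloc - (tx_cnt - peer_fwd_cnt) > 0
+ *
+ * i.e. the buffer the peer advertised minus the bytes still in flight;
+ * this is what virtio_transport_has_space() evaluates.
+ */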
+
 /* Handle server socket */
 static int
-virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
+virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
+                            struct virtio_transport *t)
 {
        struct vsock_sock *vsk = vsock_sk(sk);
        struct vsock_sock *vchild;
        struct sock *child;
+       int ret;
 
        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
                virtio_transport_reset(vsk, pkt);
        vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));
 
+       ret = vsock_assign_transport(vchild, vsk);
+       /* The transport assigned (based on remote_addr) must be the same
+        * one on which we received the request.
+        */
+       if (ret || vchild->transport != &t->transport) {
+               release_sock(child);
+               virtio_transport_reset(vsk, pkt);
+               sock_put(child);
+               return ret;
+       }
+
+       if (virtio_transport_space_update(child, pkt))
+               child->sk_write_space(child);
+
        vsock_insert_connected(vchild);
        vsock_enqueue_accept(sk, child);
        virtio_transport_send_response(vchild, pkt);
        return 0;
 }
 
-static bool virtio_transport_space_update(struct sock *sk,
-                                         struct virtio_vsock_pkt *pkt)
-{
-       struct vsock_sock *vsk = vsock_sk(sk);
-       struct virtio_vsock_sock *vvs = vsk->trans;
-       bool space_available;
-
-       /* buf_alloc and fwd_cnt is always included in the hdr */
-       spin_lock_bh(&vvs->tx_lock);
-       vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
-       vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
-       space_available = virtio_transport_has_space(vsk);
-       spin_unlock_bh(&vvs->tx_lock);
-       return space_available;
-}
-
 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
  * lock.
  */
 
        switch (sk->sk_state) {
        case TCP_LISTEN:
-               virtio_transport_recv_listen(sk, pkt);
+               virtio_transport_recv_listen(sk, pkt, t);
                virtio_transport_free_pkt(pkt);
                break;
        case TCP_SYN_SENT:
                virtio_transport_free_pkt(pkt);
                break;
        }
+
        release_sock(sk);
 
        /* Release refcnt obtained when we fetched this socket out of the
 
 static u16 vmci_transport_new_proto_supported_versions(void);
 static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto,
                                                  bool old_pkt_proto);
+static bool vmci_check_transport(struct vsock_sock *vsk);
 
 struct vmci_transport_recv_pkt_info {
        struct work_struct work;
        vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
                        pkt->src_port);
 
+       err = vsock_assign_transport(vpending, vsock_sk(sk));
+       /* The transport assigned (based on remote_addr) must be the same
+        * one on which we received the request.
+        */
+       if (err || !vmci_check_transport(vpending)) {
+               vmci_transport_send_reset(sk, pkt);
+               sock_put(pending);
+               return err;
+       }
+
        /* If the proposed size fits within our min/max, accept it. Otherwise
         * propose our own size.
         */
        return vmci_get_context_id();
 }
 
-static const struct vsock_transport vmci_transport = {
+static struct vsock_transport vmci_transport = {
        .init = vmci_transport_socket_init,
        .destruct = vmci_transport_destruct,
        .release = vmci_transport_release,
        .get_local_cid = vmci_transport_get_local_cid,
 };
 
+static bool vmci_check_transport(struct vsock_sock *vsk)
+{
+       return vsk->transport == &vmci_transport;
+}
+
 static int __init vmci_transport_init(void)
 {
+       int features = VSOCK_TRANSPORT_F_DGRAM | VSOCK_TRANSPORT_F_H2G;
+       int cid;
        int err;
 
+       cid = vmci_get_context_id();
+
+       if (cid == VMCI_INVALID_ID)
+               return -EINVAL;
+
+       if (cid != VMCI_HOST_CONTEXT_ID)
+               features |= VSOCK_TRANSPORT_F_G2H;
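+       /* A guest context also talks to its host, so G2H is added on top
+        * of H2G and DGRAM; in a nested setup the same transport thus
+        * serves both directions.
+        */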
+
        /* Create the datagram handle that we will use to send and receive all
         * VSocket control messages for this context.
         */
                goto err_destroy_stream_handle;
        }
 
-       err = vsock_core_init(&vmci_transport);
+       err = vsock_core_register(&vmci_transport, features);
        if (err < 0)
                goto err_unsubscribe;
 
                vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
        }
 
-       vsock_core_exit();
+       vsock_core_unregister(&vmci_transport);
 }
 module_exit(vmci_transport_exit);