return 0;
 }
 
-struct rdma_cm_id *rdma_create_id(rdma_cm_event_handler event_handler,
+struct rdma_cm_id *rdma_create_id(struct net *net,
+                                 rdma_cm_event_handler event_handler,
                                  void *context, enum rdma_port_space ps,
                                  enum ib_qp_type qp_type)
 {
        INIT_LIST_HEAD(&id_priv->listen_list);
        INIT_LIST_HEAD(&id_priv->mc_list);
        get_random_bytes(&id_priv->seq_num, sizeof id_priv->seq_num);
-       id_priv->id.route.addr.dev_addr.net = &init_net;
+       id_priv->id.route.addr.dev_addr.net = get_net(net);
 
        return &id_priv->id;
 }
                       cma_protocol_roce(&id_priv->id);
 
        return !addr->dev_addr.bound_dev_if ||
-              (net_eq(dev_net(net_dev), &init_net) &&
+              (net_eq(dev_net(net_dev), addr->dev_addr.net) &&
                addr->dev_addr.bound_dev_if == net_dev->ifindex);
 }
 
                }
        }
 
-       bind_list = cma_ps_find(&init_net,
+       bind_list = cma_ps_find(*net_dev ? dev_net(*net_dev) : &init_net,
                                rdma_ps_from_service_id(req.service_id),
                                cma_port_from_service_id(req.service_id));
        id_priv = cma_find_listener(bind_list, cm_id, ib_event, &req, *net_dev);
 static void cma_release_port(struct rdma_id_private *id_priv)
 {
        struct rdma_bind_list *bind_list = id_priv->bind_list;
+       struct net *net = id_priv->id.route.addr.dev_addr.net;
 
        if (!bind_list)
                return;
        mutex_lock(&lock);
        hlist_del(&id_priv->node);
        if (hlist_empty(&bind_list->owners)) {
-               cma_ps_remove(&init_net, bind_list->ps, bind_list->port);
+               cma_ps_remove(net, bind_list->ps, bind_list->port);
                kfree(bind_list);
        }
        mutex_unlock(&lock);
                cma_deref_id(id_priv->id.context);
 
        kfree(id_priv->id.route.path_rec);
+       put_net(id_priv->id.route.addr.dev_addr.net);
        kfree(id_priv);
 }
 EXPORT_SYMBOL(rdma_destroy_id);
                      ib_event->param.req_rcvd.primary_path->service_id;
        int ret;
 
-       id = rdma_create_id(listen_id->event_handler, listen_id->context,
+       id = rdma_create_id(listen_id->route.addr.dev_addr.net,
+                           listen_id->event_handler, listen_id->context,
                            listen_id->ps, ib_event->param.req_rcvd.qp_type);
        if (IS_ERR(id))
                return NULL;
        struct rdma_id_private *id_priv;
        struct rdma_cm_id *id;
        const sa_family_t ss_family = listen_id->route.addr.src_addr.ss_family;
+       struct net *net = listen_id->route.addr.dev_addr.net;
        int ret;
 
-       id = rdma_create_id(listen_id->event_handler, listen_id->context,
+       id = rdma_create_id(net, listen_id->event_handler, listen_id->context,
                            listen_id->ps, IB_QPT_UD);
        if (IS_ERR(id))
                return NULL;
                return -ECONNABORTED;
 
        /* Create a new RDMA id for the new IW CM ID */
-       new_cm_id = rdma_create_id(listen_id->id.event_handler,
+       new_cm_id = rdma_create_id(listen_id->id.route.addr.dev_addr.net,
+                                  listen_id->id.event_handler,
                                   listen_id->id.context,
                                   RDMA_PS_TCP, IB_QPT_RC);
        if (IS_ERR(new_cm_id)) {
 {
        struct rdma_id_private *dev_id_priv;
        struct rdma_cm_id *id;
+       struct net *net = id_priv->id.route.addr.dev_addr.net;
        int ret;
 
        if (cma_family(id_priv) == AF_IB && !rdma_cap_ib_cm(cma_dev->device, 1))
                return;
 
-       id = rdma_create_id(cma_listen_handler, id_priv, id_priv->id.ps,
+       id = rdma_create_id(net, cma_listen_handler, id_priv, id_priv->id.ps,
                            id_priv->id.qp_type);
        if (IS_ERR(id))
                return;
        if (!bind_list)
                return -ENOMEM;
 
-       ret = cma_ps_alloc(&init_net, ps, bind_list, snum);
+       ret = cma_ps_alloc(id_priv->id.route.addr.dev_addr.net, ps, bind_list,
+                          snum);
        if (ret < 0)
                goto err;
 
        static unsigned int last_used_port;
        int low, high, remaining;
        unsigned int rover;
+       struct net *net = id_priv->id.route.addr.dev_addr.net;
 
-       inet_get_local_port_range(&init_net, &low, &high);
+       inet_get_local_port_range(net, &low, &high);
        remaining = (high - low) + 1;
        rover = prandom_u32() % remaining + low;
 retry:
        if (last_used_port != rover &&
-           !cma_ps_find(&init_net, ps, (unsigned short)rover)) {
+           !cma_ps_find(net, ps, (unsigned short)rover)) {
                int ret = cma_alloc_port(ps, id_priv, rover);
                /*
                 * Remember previously used port number in order to avoid
        if (snum < PROT_SOCK && !capable(CAP_NET_BIND_SERVICE))
                return -EACCES;
 
-       bind_list = cma_ps_find(&init_net, ps, snum);
+       bind_list = cma_ps_find(id_priv->id.route.addr.dev_addr.net, ps, snum);
        if (!bind_list) {
                ret = cma_alloc_port(ps, id_priv, snum);
        } else {
                if (addr->sa_family == AF_INET)
                        id_priv->afonly = 1;
 #if IS_ENABLED(CONFIG_IPV6)
-               else if (addr->sa_family == AF_INET6)
-                       id_priv->afonly = init_net.ipv6.sysctl.bindv6only;
+               else if (addr->sa_family == AF_INET6) {
+                       struct net *net = id_priv->id.route.addr.dev_addr.net;
+
+                       id_priv->afonly = net->ipv6.sysctl.bindv6only;
+               }
 #endif
        }
        ret = cma_get_port(id_priv);
        dev_addr = &id_priv->id.route.addr.dev_addr;
 
        if ((dev_addr->bound_dev_if == ndev->ifindex) &&
+           (net_eq(dev_net(ndev), dev_addr->net)) &&
            memcmp(dev_addr->src_dev_addr, ndev->dev_addr, ndev->addr_len)) {
                printk(KERN_INFO "RDMA CM addr change for ndev %s used by id %p\n",
                       ndev->name, &id_priv->id);
        struct rdma_id_private *id_priv;
        int ret = NOTIFY_DONE;
 
-       if (dev_net(ndev) != &init_net)
-               return NOTIFY_DONE;
-
        if (event != NETDEV_BONDING_FAILOVER)
                return NOTIFY_DONE;