!xa_is_err(entry);                                                \
             (index)++, entry = xan_find_marked(xa, &(index), filter))
 
+static void free_netdevs(struct ib_device *ib_dev);
 static int ib_security_change(struct notifier_block *nb, unsigned long event,
                              void *lsm_data);
 static void ib_policy_change_task(struct work_struct *work);
 {
        struct ib_device *dev = container_of(device, struct ib_device, dev);
 
+       free_netdevs(dev);
        WARN_ON(refcount_read(&dev->refcount));
        ib_cache_release_one(dev);
        ib_security_release_port_pkey_list(dev);
  */
 void ib_dealloc_device(struct ib_device *device)
 {
+       /* Expedite releasing netdev references */
+       free_netdevs(device);
+
        WARN_ON(!xa_empty(&device->client_data));
        WARN_ON(refcount_read(&device->refcount));
        rdma_restrack_clean(device);
        up_read(&device->client_data_rwsem);
 }
 
-static int verify_immutable(const struct ib_device *dev, u8 port)
-{
-       return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
-                           rdma_max_mad_size(dev, port) != 0);
-}
-
-static int setup_port_data(struct ib_device *device)
+static int alloc_port_data(struct ib_device *device)
 {
        unsigned int port;
-       int ret;
+
+       if (device->port_data)
+               return 0;
+
+       /* This can only be called once the physical port range is defined */
+       if (WARN_ON(!device->phys_port_cnt))
+               return -EINVAL;
 
        /*
         * device->port_data is indexed directly by the port number to make
 
                spin_lock_init(&pdata->pkey_list_lock);
                INIT_LIST_HEAD(&pdata->pkey_list);
+               spin_lock_init(&pdata->netdev_lock);
+       }
+       return 0;
+}
+
+static int verify_immutable(const struct ib_device *dev, u8 port)
+{
+       return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
+                           rdma_max_mad_size(dev, port) != 0);
+}
+
+static int setup_port_data(struct ib_device *device)
+{
+       unsigned int port;
+       int ret;
+
+       ret = alloc_port_data(device);
+       if (ret)
+               return ret;
+
+       rdma_for_each_port (device, port) {
+               struct ib_port_data *pdata = &device->port_data[port];
 
                ret = device->ops.get_port_immutable(device, port,
                                                     &pdata->immutable);
        /* Pairs with refcount_set in enable_device */
        ib_device_put(device);
        wait_for_completion(&device->unreg_completion);
+
+       /* Expedite removing unregistered pointers from the hash table */
+       free_netdevs(device);
 }
 
 /*
 }
 EXPORT_SYMBOL(ib_query_port);
 
+/**
+ * ib_device_set_netdev - Associate the ib_dev with an underlying net_device
+ * @ib_dev: Device to modify
+ * @ndev: net_device to affiliate, may be NULL
+ * @port: IB port the net_device is connected to
+ *
+ * Drivers should use this to link the ib_device to a netdev so the netdev
+ * shows up in interfaces like ib_enum_roce_netdev. Only one netdev may be
+ * affiliated with any port.
+ *
+ * The caller must ensure that the given ndev is not unregistered or
+ * unregistering, and that either the ib_device is unregistered or
+ * ib_device_set_netdev() is called with NULL when the ndev sends a
+ * NETDEV_UNREGISTER event.
+ */
+int ib_device_set_netdev(struct ib_device *ib_dev, struct net_device *ndev,
+                        unsigned int port)
+{
+       struct net_device *old_ndev;
+       struct ib_port_data *pdata;
+       unsigned long flags;
+       int ret;
+
+       /*
+        * Drivers wish to call this before ib_register_driver, so we have to
+        * setup the port data early.
+        */
+       ret = alloc_port_data(ib_dev);
+       if (ret)
+               return ret;
+
+       if (!rdma_is_port_valid(ib_dev, port))
+               return -EINVAL;
+
+       pdata = &ib_dev->port_data[port];
+       spin_lock_irqsave(&pdata->netdev_lock, flags);
+       if (pdata->netdev == ndev) {
+               spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+               return 0;
+       }
+       old_ndev = pdata->netdev;
+
+       if (ndev)
+               dev_hold(ndev);
+       pdata->netdev = ndev;
+       spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+
+       if (old_ndev)
+               dev_put(old_ndev);
+
+       return 0;
+}
+EXPORT_SYMBOL(ib_device_set_netdev);
+
+static void free_netdevs(struct ib_device *ib_dev)
+{
+       unsigned long flags;
+       unsigned int port;
+
+       rdma_for_each_port (ib_dev, port) {
+               struct ib_port_data *pdata = &ib_dev->port_data[port];
+
+               spin_lock_irqsave(&pdata->netdev_lock, flags);
+               if (pdata->netdev) {
+                       dev_put(pdata->netdev);
+                       pdata->netdev = NULL;
+               }
+               spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+       }
+}
+
+struct net_device *ib_device_get_netdev(struct ib_device *ib_dev,
+                                       unsigned int port)
+{
+       struct ib_port_data *pdata;
+       struct net_device *res;
+
+       if (!rdma_is_port_valid(ib_dev, port))
+               return NULL;
+
+       pdata = &ib_dev->port_data[port];
+
+       /*
+        * New drivers should use ib_device_set_netdev() not the legacy
+        * get_netdev().
+        */
+       if (ib_dev->ops.get_netdev)
+               res = ib_dev->ops.get_netdev(ib_dev, port);
+       else {
+               spin_lock(&pdata->netdev_lock);
+               res = pdata->netdev;
+               if (res)
+                       dev_hold(res);
+               spin_unlock(&pdata->netdev_lock);
+       }
+
+       /*
+        * If we are starting to unregister expedite things by preventing
+        * propagation of an unregistering netdev.
+        */
+       if (res && res->reg_state != NETREG_REGISTERED) {
+               dev_put(res);
+               return NULL;
+       }
+
+       return res;
+}
+
 /**
  * ib_enum_roce_netdev - enumerate all RoCE ports
  * @ib_dev : IB device we want to query
 
        rdma_for_each_port (ib_dev, port)
                if (rdma_protocol_roce(ib_dev, port)) {
-                       struct net_device *idev = NULL;
-
-                       if (ib_dev->ops.get_netdev)
-                               idev = ib_dev->ops.get_netdev(ib_dev, port);
-
-                       if (idev &&
-                           idev->reg_state >= NETREG_UNREGISTERED) {
-                               dev_put(idev);
-                               idev = NULL;
-                       }
+                       struct net_device *idev =
+                               ib_device_get_netdev(ib_dev, port);
 
                        if (filter(ib_dev, port, idev, filter_cookie))
                                cb(ib_dev, port, idev, cookie);