}
 }
 
-int xdp_umem_query(struct net_device *dev, u16 queue_id)
+/* The umem is stored both in the _rx struct and the _tx struct as we do
+ * not know if the device has more tx queues than rx, or the opposite.
+ * This might also change during run time.
+ */
+static void xdp_reg_umem_at_qid(struct net_device *dev, struct xdp_umem *umem,
+                               u16 queue_id)
 {
-       struct netdev_bpf bpf;
+       if (queue_id < dev->real_num_rx_queues)
+               dev->_rx[queue_id].umem = umem;
+       if (queue_id < dev->real_num_tx_queues)
+               dev->_tx[queue_id].umem = umem;
+}
 
-       ASSERT_RTNL();
+static struct xdp_umem *xdp_get_umem_from_qid(struct net_device *dev,
+                                             u16 queue_id)
+{
+       if (queue_id < dev->real_num_rx_queues)
+               return dev->_rx[queue_id].umem;
+       if (queue_id < dev->real_num_tx_queues)
+               return dev->_tx[queue_id].umem;
 
-       memset(&bpf, 0, sizeof(bpf));
-       bpf.command = XDP_QUERY_XSK_UMEM;
-       bpf.xsk.queue_id = queue_id;
+       return NULL;
+}
 
-       if (!dev->netdev_ops->ndo_bpf)
-               return 0;
-       return dev->netdev_ops->ndo_bpf(dev, &bpf) ?: !!bpf.xsk.umem;
+static void xdp_clear_umem_at_qid(struct net_device *dev, u16 queue_id)
+{
+       /* Zero out the entry independent on how many queues are configured
+        * at this point in time, as it might be used in the future.
+        */
+       if (queue_id < dev->num_rx_queues)
+               dev->_rx[queue_id].umem = NULL;
+       if (queue_id < dev->num_tx_queues)
+               dev->_tx[queue_id].umem = NULL;
 }
 
 int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
-                       u32 queue_id, u16 flags)
+                       u16 queue_id, u16 flags)
 {
        bool force_zc, force_copy;
        struct netdev_bpf bpf;
-       int err;
+       int err = 0;
 
        force_zc = flags & XDP_ZEROCOPY;
        force_copy = flags & XDP_COPY;
        if (force_zc && force_copy)
                return -EINVAL;
 
-       if (force_copy)
-               return 0;
+       rtnl_lock();
+       if (xdp_get_umem_from_qid(dev, queue_id)) {
+               err = -EBUSY;
+               goto out_rtnl_unlock;
+       }
 
-       if (!dev->netdev_ops->ndo_bpf || !dev->netdev_ops->ndo_xsk_async_xmit)
-               return force_zc ? -EOPNOTSUPP : 0; /* fail or fallback */
+       xdp_reg_umem_at_qid(dev, umem, queue_id);
+       umem->dev = dev;
+       umem->queue_id = queue_id;
+       if (force_copy)
+               /* For copy-mode, we are done. */
+               goto out_rtnl_unlock;
 
-       rtnl_lock();
-       err = xdp_umem_query(dev, queue_id);
-       if (err) {
-               err = err < 0 ? -EOPNOTSUPP : -EBUSY;
-               goto err_rtnl_unlock;
+       if (!dev->netdev_ops->ndo_bpf ||
+           !dev->netdev_ops->ndo_xsk_async_xmit) {
+               err = -EOPNOTSUPP;
+               goto err_unreg_umem;
        }
 
        bpf.command = XDP_SETUP_XSK_UMEM;
 
        err = dev->netdev_ops->ndo_bpf(dev, &bpf);
        if (err)
-               goto err_rtnl_unlock;
+               goto err_unreg_umem;
        rtnl_unlock();
 
        dev_hold(dev);
-       umem->dev = dev;
-       umem->queue_id = queue_id;
        umem->zc = true;
        return 0;
 
-err_rtnl_unlock:
+err_unreg_umem:
+       xdp_clear_umem_at_qid(dev, queue_id);
+       if (!force_zc)
+               err = 0; /* fallback to copy mode */
+out_rtnl_unlock:
        rtnl_unlock();
-       return force_zc ? err : 0; /* fail or fallback */
+       return err;
 }
 
 static void xdp_umem_clear_dev(struct xdp_umem *umem)
        struct netdev_bpf bpf;
        int err;
 
-       if (umem->dev) {
+       if (umem->zc) {
                bpf.command = XDP_SETUP_XSK_UMEM;
                bpf.xsk.umem = NULL;
                bpf.xsk.queue_id = umem->queue_id;
 
                if (err)
                        WARN(1, "failed to disable umem!\n");
+       }
+
+       if (umem->dev) {
+               rtnl_lock();
+               xdp_clear_umem_at_qid(umem->dev, umem->queue_id);
+               rtnl_unlock();
+       }
 
+       if (umem->zc) {
                dev_put(umem->dev);
-               umem->dev = NULL;
+               umem->zc = false;
        }
 }