}
 }
 
+struct vfio_pci_group_info;
 static bool vfio_pci_dev_set_try_reset(struct vfio_device_set *dev_set);
 static void vfio_pci_disable(struct vfio_pci_device *vdev);
-static int vfio_pci_try_zap_and_vma_lock_cb(struct pci_dev *pdev, void *data);
+static int vfio_pci_dev_set_hot_reset(struct vfio_device_set *dev_set,
+                                     struct vfio_pci_group_info *groups);
 
 /*
  * INTx masking requires the ability to disable INTx signaling via PCI_COMMAND
        return 0;
 }
 
-struct vfio_pci_group_entry {
-       struct vfio_group *group;
-       int id;
-};
-
 struct vfio_pci_group_info {
        int count;
-       struct vfio_pci_group_entry *groups;
+       struct vfio_group **groups;
 };
 
-static int vfio_pci_validate_devs(struct pci_dev *pdev, void *data)
-{
-       struct vfio_pci_group_info *info = data;
-       struct iommu_group *group;
-       int id, i;
-
-       group = iommu_group_get(&pdev->dev);
-       if (!group)
-               return -EPERM;
-
-       id = iommu_group_id(group);
-
-       for (i = 0; i < info->count; i++)
-               if (info->groups[i].id == id)
-                       break;
-
-       iommu_group_put(group);
-
-       return (i == info->count) ? -EINVAL : 0;
-}
-
 static bool vfio_pci_dev_below_slot(struct pci_dev *pdev, struct pci_slot *slot)
 {
        for (; pdev; pdev = pdev->bus->self)
        return 0;
 }
 
-struct vfio_devices {
-       struct vfio_pci_device **devices;
-       int cur_index;
-       int max_index;
-};
-
 static long vfio_pci_ioctl(struct vfio_device *core_vdev,
                           unsigned int cmd, unsigned long arg)
 {
        } else if (cmd == VFIO_DEVICE_PCI_HOT_RESET) {
                struct vfio_pci_hot_reset hdr;
                int32_t *group_fds;
-               struct vfio_pci_group_entry *groups;
+               struct vfio_group **groups;
                struct vfio_pci_group_info info;
-               struct vfio_devices devs = { .cur_index = 0 };
                bool slot = false;
-               int i, group_idx, mem_idx = 0, count = 0, ret = 0;
+               int group_idx, count = 0, ret = 0;
 
                minsz = offsetofend(struct vfio_pci_hot_reset, count);
 
                                break;
                        }
 
-                       groups[group_idx].group = group;
-                       groups[group_idx].id =
-                                       vfio_external_user_iommu_id(group);
+                       groups[group_idx] = group;
                }
 
                kfree(group_fds);
                info.count = hdr.count;
                info.groups = groups;
 
-               /*
-                * Test whether all the affected devices are contained
-                * by the set of groups provided by the user.
-                */
-               ret = vfio_pci_for_each_slot_or_bus(vdev->pdev,
-                                                   vfio_pci_validate_devs,
-                                                   &info, slot);
-               if (ret)
-                       goto hot_reset_release;
-
-               devs.max_index = count;
-               devs.devices = kcalloc(count, sizeof(struct vfio_device *),
-                                      GFP_KERNEL);
-               if (!devs.devices) {
-                       ret = -ENOMEM;
-                       goto hot_reset_release;
-               }
-
-               /*
-                * We need to get memory_lock for each device, but devices
-                * can share mmap_lock, therefore we need to zap and hold
-                * the vma_lock for each device, and only then get each
-                * memory_lock.
-                */
-               ret = vfio_pci_for_each_slot_or_bus(vdev->pdev,
-                                           vfio_pci_try_zap_and_vma_lock_cb,
-                                           &devs, slot);
-               if (ret)
-                       goto hot_reset_release;
-
-               for (; mem_idx < devs.cur_index; mem_idx++) {
-                       struct vfio_pci_device *tmp = devs.devices[mem_idx];
-
-                       ret = down_write_trylock(&tmp->memory_lock);
-                       if (!ret) {
-                               ret = -EBUSY;
-                               goto hot_reset_release;
-                       }
-                       mutex_unlock(&tmp->vma_lock);
-               }
-
-               /* User has access, do the reset */
-               ret = pci_reset_bus(vdev->pdev);
+               ret = vfio_pci_dev_set_hot_reset(vdev->vdev.dev_set, &info);
 
 hot_reset_release:
-               for (i = 0; i < devs.cur_index; i++) {
-                       struct vfio_pci_device *tmp = devs.devices[i];
-
-                       if (i < mem_idx)
-                               up_write(&tmp->memory_lock);
-                       else
-                               mutex_unlock(&tmp->vma_lock);
-                       vfio_device_put(&tmp->vdev);
-               }
-               kfree(devs.devices);
-
                for (group_idx--; group_idx >= 0; group_idx--)
-                       vfio_group_put_external_user(groups[group_idx].group);
+                       vfio_group_put_external_user(groups[group_idx]);
 
                kfree(groups);
                return ret;
        .err_handler            = &vfio_err_handlers,
 };
 
-static int vfio_pci_try_zap_and_vma_lock_cb(struct pci_dev *pdev, void *data)
+static bool vfio_dev_in_groups(struct vfio_pci_device *vdev,
+                              struct vfio_pci_group_info *groups)
 {
-       struct vfio_devices *devs = data;
-       struct vfio_device *device;
-       struct vfio_pci_device *vdev;
-
-       if (devs->cur_index == devs->max_index)
-               return -ENOSPC;
-
-       device = vfio_device_get_from_dev(&pdev->dev);
-       if (!device)
-               return -EINVAL;
-
-       if (pci_dev_driver(pdev) != &vfio_pci_driver) {
-               vfio_device_put(device);
-               return -EBUSY;
-       }
-
-       vdev = container_of(device, struct vfio_pci_device, vdev);
+       unsigned int i;
 
-       /*
-        * Locking multiple devices is prone to deadlock, runaway and
-        * unwind if we hit contention.
-        */
-       if (!vfio_pci_zap_and_vma_lock(vdev, true)) {
-               vfio_device_put(device);
-               return -EBUSY;
-       }
-
-       devs->devices[devs->cur_index++] = vdev;
-       return 0;
+       for (i = 0; i < groups->count; i++)
+               if (groups->groups[i] == vdev->vdev.group)
+                       return true;
+       return false;
 }
 
 static int vfio_pci_is_device_in_set(struct pci_dev *pdev, void *data)
        return pdev;
 }
 
+/*
+ * We need to get memory_lock for each device, but devices can share mmap_lock,
+ * therefore we need to zap and hold the vma_lock for each device, and only then
+ * get each memory_lock.
+ */
+static int vfio_pci_dev_set_hot_reset(struct vfio_device_set *dev_set,
+                                     struct vfio_pci_group_info *groups)
+{
+       struct vfio_pci_device *cur_mem;
+       struct vfio_pci_device *cur_vma;
+       struct vfio_pci_device *cur;
+       struct pci_dev *pdev;
+       bool is_mem = true;
+       int ret;
+
+       mutex_lock(&dev_set->lock);
+       cur_mem = list_first_entry(&dev_set->device_list,
+                                  struct vfio_pci_device, vdev.dev_set_list);
+
+       pdev = vfio_pci_dev_set_resettable(dev_set);
+       if (!pdev) {
+               ret = -EINVAL;
+               goto err_unlock;
+       }
+
+       list_for_each_entry(cur_vma, &dev_set->device_list, vdev.dev_set_list) {
+               /*
+                * Test whether all the affected devices are contained by the
+                * set of groups provided by the user.
+                */
+               if (!vfio_dev_in_groups(cur_vma, groups)) {
+                       ret = -EINVAL;
+                       goto err_undo;
+               }
+
+               /*
+                * Locking multiple devices is prone to deadlock, runaway and
+                * unwind if we hit contention.
+                */
+               if (!vfio_pci_zap_and_vma_lock(cur_vma, true)) {
+                       ret = -EBUSY;
+                       goto err_undo;
+               }
+       }
+       cur_vma = NULL;
+
+       list_for_each_entry(cur_mem, &dev_set->device_list, vdev.dev_set_list) {
+               if (!down_write_trylock(&cur_mem->memory_lock)) {
+                       ret = -EBUSY;
+                       goto err_undo;
+               }
+               mutex_unlock(&cur_mem->vma_lock);
+       }
+       cur_mem = NULL;
+
+       ret = pci_reset_bus(pdev);
+
+err_undo:
+       list_for_each_entry(cur, &dev_set->device_list, vdev.dev_set_list) {
+               if (cur == cur_mem)
+                       is_mem = false;
+               if (cur == cur_vma)
+                       break;
+               if (is_mem)
+                       up_write(&cur->memory_lock);
+               else
+                       mutex_unlock(&cur->vma_lock);
+       }
+err_unlock:
+       mutex_unlock(&dev_set->lock);
+       return ret;
+}
+
 static bool vfio_pci_dev_set_needs_reset(struct vfio_device_set *dev_set)
 {
        struct vfio_pci_device *cur;