}
 EXPORT_SYMBOL(pci_scan_single_device);
 
+static unsigned next_ari_fn(struct pci_dev *dev, unsigned fn)
+{
+       u16 cap;
+       unsigned pos = pci_find_ext_capability(dev, PCI_EXT_CAP_ID_ARI);
+       if (!pos)
+               return 0;
+       pci_read_config_word(dev, pos + 4, &cap);
+       return cap >> 8;
+}
+
+static unsigned next_trad_fn(struct pci_dev *dev, unsigned fn)
+{
+       return (fn + 1) % 8;
+}
+
+static unsigned no_next_fn(struct pci_dev *dev, unsigned fn)
+{
+       return 0;
+}
+
+static int only_one_child(struct pci_bus *bus)
+{
+       struct pci_dev *parent = bus->self;
+       if (!parent || !pci_is_pcie(parent))
+               return 0;
+       if (parent->pcie_type == PCI_EXP_TYPE_ROOT_PORT ||
+           parent->pcie_type == PCI_EXP_TYPE_DOWNSTREAM)
+               return 1;
+       return 0;
+}
+
 /**
  * pci_scan_slot - scan a PCI slot on a bus for devices.
  * @bus: PCI bus to scan
  */
 int pci_scan_slot(struct pci_bus *bus, int devfn)
 {
-       int fn, nr = 0;
+       unsigned fn, nr = 0;
        struct pci_dev *dev;
+       unsigned (*next_fn)(struct pci_dev *, unsigned) = no_next_fn;
+
+       if (only_one_child(bus) && (devfn > 0))
+               return 0; /* Already scanned the entire slot */
 
        dev = pci_scan_single_device(bus, devfn);
        if (dev && !dev->is_added)      /* new device? */
                nr++;
 
-       if (dev && dev->multifunction) {
-               for (fn = 1; fn < 8; fn++) {
-                       dev = pci_scan_single_device(bus, devfn + fn);
-                       if (dev) {
-                               if (!dev->is_added)
-                                       nr++;
-                               dev->multifunction = 1;
-                       }
+       if (pci_ari_enabled(bus))
+               next_fn = next_ari_fn;
+       else if (dev && dev->multifunction)
+               next_fn = next_trad_fn;
+
+       for (fn = next_fn(dev, 0); fn > 0; fn = next_fn(dev, fn)) {
+               dev = pci_scan_single_device(bus, devfn + fn);
+               if (dev) {
+                       if (!dev->is_added)
+                               nr++;
+                       dev->multifunction = 1;
                }
        }