compound_group->pgsizes = pe->table_group.pgsizes;
        }
 
+       /*
+        * The GPU would have been added to the IOMMU group that's created
+        * for the PE. Pull it out now.
+        */
+       iommu_del_device(&gpdev->dev);
+
        /*
        * I'm not sure this is strictly required, but it's probably a good idea
        * since the table_group for the PE is going to be attached to the
        */
        iommu_group_put(pe->table_group.group);
 
+       /* now put the GPU into the compound group */
        pnv_comp_attach_table_group(npucomp, pe);
+       iommu_add_device(compound_group, &gpdev->dev);
 
        return compound_group;
 }
 
        WARN_ON(get_dma_ops(&pdev->dev) != &dma_iommu_ops);
        pdev->dev.archdata.dma_offset = pe->tce_bypass_base;
        set_iommu_table_base(&pdev->dev, pe->table_group.tables[0]);
-       /*
-        * Note: iommu_add_device() will fail here as
-        * for physical PE: the device is already added by now;
-        * for virtual PE: sysfs entries are not ready yet and
-        * tce_iommu_bus_notifier will add the device to a group later.
-        */
+
+       /* PEs with a DMA weight of zero won't have a group */
+       if (pe->table_group.group)
+               iommu_add_device(&pe->table_group, &pdev->dev);
 }
 
 /*
        struct pnv_ioda_pe *pe;
 
        /*
-        * There are 4 types of PEs:
-        * - PNV_IODA_PE_BUS: a downstream port with an adapter,
-        *   created from pnv_pci_setup_bridge();
-        * - PNV_IODA_PE_BUS_ALL: a PCI-PCIX bridge with devices behind it,
-        *   created from pnv_pci_setup_bridge();
-        * - PNV_IODA_PE_VF: a SRIOV virtual function,
-        *   created from pnv_pcibios_sriov_enable();
-        * - PNV_IODA_PE_DEV: an NPU or OCAPI device,
-        *   created from pnv_pci_ioda_fixup().
+        * For non-nvlink devices the IOMMU group is registered when the PE is
+        * configured and devices are added to the group when the per-device
+        * DMA setup is run. That's done in hose->ops.dma_dev_setup() which is
+        * only initialised for "normal" IODA PHBs.
         *
-        * Normally a PE is represented by an IOMMU group, however for
-        * devices with side channels the groups need to be more strict.
+        * For NVLink devices we need to ensure the NVLinks and the GPU end up
+        * in the same IOMMU group, so that's handled here.
         */
        list_for_each_entry(hose, &hose_list, list_node) {
                phb = hose->private_data;
 
-               if (phb->type == PNV_PHB_NPU_NVLINK ||
-                   phb->type == PNV_PHB_NPU_OCAPI)
-                       continue;
-
-               list_for_each_entry(pe, &phb->ioda.pe_list, list) {
-                       struct iommu_table_group *table_group;
-
-                       table_group = pnv_try_setup_npu_table_group(pe);
-                       if (!table_group) {
-                               if (!pnv_pci_ioda_pe_dma_weight(pe))
-                                       continue;
-
-                               table_group = &pe->table_group;
-                       }
-                       pnv_ioda_setup_bus_iommu_group(pe, table_group,
-                                       pe->pbus);
-               }
+               if (phb->type == PNV_PHB_IODA2)
+                       list_for_each_entry(pe, &phb->ioda.pe_list, list)
+                               pnv_try_setup_npu_table_group(pe);
        }
 
        /*
 
                unsigned long action, void *data)
 {
        struct device *dev = data;
-       struct pci_dev *pdev;
-       struct pci_dn *pdn;
-       struct pnv_ioda_pe *pe;
-       struct pci_controller *hose;
-       struct pnv_phb *phb;
 
        switch (action) {
-       case BUS_NOTIFY_ADD_DEVICE:
-               pdev = to_pci_dev(dev);
-               pdn = pci_get_pdn(pdev);
-               hose = pci_bus_to_host(pdev->bus);
-               phb = hose->private_data;
-
-               WARN_ON_ONCE(!phb);
-               if (!pdn || pdn->pe_number == IODA_INVALID_PE || !phb)
-                       return 0;
-
-               pe = &phb->ioda.pe_array[pdn->pe_number];
-               if (!pe->table_group.group)
-                       return 0;
-               iommu_add_device(&pe->table_group, dev);
-               return 0;
        case BUS_NOTIFY_DEL_DEVICE:
                iommu_del_device(dev);
                return 0;