bool amd_iommu_pasid_supported(void);
 
-/* Device capabilities */
-int amd_iommu_pdev_enable_cap_pri(struct pci_dev *pdev);
-void amd_iommu_pdev_disable_cap_pri(struct pci_dev *pdev);
-
 /* GCR3 setup */
 int amd_iommu_set_gcr3(struct iommu_dev_data *dev_data,
                       ioasid_t pasid, unsigned long gcr3);
 
        }
 }
 
-int amd_iommu_pdev_enable_cap_pri(struct pci_dev *pdev)
+static inline int pdev_enable_cap_pri(struct pci_dev *pdev)
 {
        struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
        int ret = -EINVAL;
        if (dev_data->pri_enabled)
                return 0;
 
+       if (!dev_data->ats_enabled)
+               return 0;
+
        if (dev_data->flags & AMD_IOMMU_DEVICE_FLAG_PRI_SUP) {
                /*
                 * First reset the PRI state of the device.
        return ret;
 }
 
-void amd_iommu_pdev_disable_cap_pri(struct pci_dev *pdev)
+static inline void pdev_disable_cap_pri(struct pci_dev *pdev)
 {
        struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
 
 {
        pdev_enable_cap_ats(pdev);
        pdev_enable_cap_pasid(pdev);
-       amd_iommu_pdev_enable_cap_pri(pdev);
-
+       pdev_enable_cap_pri(pdev);
 }
 
 static void pdev_disable_caps(struct pci_dev *pdev)
 {
        pdev_disable_cap_ats(pdev);
        pdev_disable_cap_pasid(pdev);
-       amd_iommu_pdev_disable_cap_pri(pdev);
+       pdev_disable_cap_pri(pdev);
 }
 
 /*
                     struct protection_domain *domain)
 {
        struct amd_iommu *iommu = get_amd_iommu_from_dev_data(dev_data);
+       struct pci_dev *pdev;
        int ret = 0;
 
        /* Update data structures */
        domain->dev_iommu[iommu->index] += 1;
        domain->dev_cnt                 += 1;
 
+       pdev = dev_is_pci(dev_data->dev) ? to_pci_dev(dev_data->dev) : NULL;
        if (pdom_is_sva_capable(domain)) {
                ret = init_gcr3_table(dev_data, domain);
                if (ret)
                        return ret;
+
+               if (pdev)
+                       pdev_enable_caps(pdev);
+       } else if (pdev) {
+               pdev_enable_cap_ats(pdev);
        }
 
        /* Update device table */
                goto out;
        }
 
-       if (dev_is_pci(dev))
-               pdev_enable_caps(to_pci_dev(dev));
-
        ret = do_attach(dev_data, domain);
 
 out: