return flags;
 }
 
+static inline int pdev_enable_cap_ats(struct pci_dev *pdev)
+{
+       struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
+       int ret = -EINVAL;
+
+       if (dev_data->ats_enabled)
+               return 0;
+
+       if (amd_iommu_iotlb_sup &&
+           (dev_data->flags & AMD_IOMMU_DEVICE_FLAG_ATS_SUP)) {
+               ret = pci_enable_ats(pdev, PAGE_SHIFT);
+               if (!ret) {
+                       dev_data->ats_enabled = 1;
+                       dev_data->ats_qdep    = pci_ats_queue_depth(pdev);
+               }
+       }
+
+       return ret;
+}
+
+static inline void pdev_disable_cap_ats(struct pci_dev *pdev)
+{
+       struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
+
+       if (dev_data->ats_enabled) {
+               pci_disable_ats(pdev);
+               dev_data->ats_enabled = 0;
+       }
+}
+
+int amd_iommu_pdev_enable_cap_pri(struct pci_dev *pdev)
+{
+       struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
+       int ret = -EINVAL;
+
+       if (dev_data->pri_enabled)
+               return 0;
+
+       if (dev_data->flags & AMD_IOMMU_DEVICE_FLAG_PRI_SUP) {
+               /*
+                * First reset the PRI state of the device.
+                * FIXME: Hardcode number of outstanding requests for now
+                */
+               if (!pci_reset_pri(pdev) && !pci_enable_pri(pdev, 32)) {
+                       dev_data->pri_enabled = 1;
+                       dev_data->pri_tlp     = pci_prg_resp_pasid_required(pdev);
+
+                       ret = 0;
+               }
+       }
+
+       return ret;
+}
+
+void amd_iommu_pdev_disable_cap_pri(struct pci_dev *pdev)
+{
+       struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
+
+       if (dev_data->pri_enabled) {
+               pci_disable_pri(pdev);
+               dev_data->pri_enabled = 0;
+       }
+}
+
+static inline int pdev_enable_cap_pasid(struct pci_dev *pdev)
+{
+       struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
+       int ret = -EINVAL;
+
+       if (dev_data->pasid_enabled)
+               return 0;
+
+       if (dev_data->flags & AMD_IOMMU_DEVICE_FLAG_PASID_SUP) {
+               /* Only allow access to user-accessible pages */
+               ret = pci_enable_pasid(pdev, 0);
+               if (!ret)
+                       dev_data->pasid_enabled = 1;
+       }
+
+       return ret;
+}
+
+static inline void pdev_disable_cap_pasid(struct pci_dev *pdev)
+{
+       struct iommu_dev_data *dev_data = dev_iommu_priv_get(&pdev->dev);
+
+       if (dev_data->pasid_enabled) {
+               pci_disable_pasid(pdev);
+               dev_data->pasid_enabled = 0;
+       }
+}
+
+static void pdev_enable_caps(struct pci_dev *pdev)
+{
+       pdev_enable_cap_ats(pdev);
+       pdev_enable_cap_pasid(pdev);
+       amd_iommu_pdev_enable_cap_pri(pdev);
+
+}
+
+static void pdev_disable_caps(struct pci_dev *pdev)
+{
+       pdev_disable_cap_ats(pdev);
+       pdev_disable_cap_pasid(pdev);
+       amd_iommu_pdev_disable_cap_pri(pdev);
+}
+
 /*
  * This function checks if the driver got a valid device from the caller to
  * avoid dereferencing invalid pointers.
        domain->dev_cnt                 -= 1;
 }
 
-static void pdev_iommuv2_disable(struct pci_dev *pdev)
-{
-       pci_disable_ats(pdev);
-       pci_disable_pri(pdev);
-       pci_disable_pasid(pdev);
-}
-
-static int pdev_pri_ats_enable(struct pci_dev *pdev)
-{
-       int ret;
-
-       /* Only allow access to user-accessible pages */
-       ret = pci_enable_pasid(pdev, 0);
-       if (ret)
-               return ret;
-
-       /* First reset the PRI state of the device */
-       ret = pci_reset_pri(pdev);
-       if (ret)
-               goto out_err_pasid;
-
-       /* Enable PRI */
-       /* FIXME: Hardcode number of outstanding requests for now */
-       ret = pci_enable_pri(pdev, 32);
-       if (ret)
-               goto out_err_pasid;
-
-       ret = pci_enable_ats(pdev, PAGE_SHIFT);
-       if (ret)
-               goto out_err_pri;
-
-       return 0;
-
-out_err_pri:
-       pci_disable_pri(pdev);
-
-out_err_pasid:
-       pci_disable_pasid(pdev);
-
-       return ret;
-}
-
 /*
  * If a device is not yet associated with a domain, this function makes the
  * device visible in the domain
                         struct protection_domain *domain)
 {
        struct iommu_dev_data *dev_data;
-       struct pci_dev *pdev;
        unsigned long flags;
-       int ret;
+       int ret = 0;
 
        spin_lock_irqsave(&domain->lock, flags);
 
 
        spin_lock(&dev_data->lock);
 
-       ret = -EBUSY;
-       if (dev_data->domain != NULL)
+       if (dev_data->domain != NULL) {
+               ret = -EBUSY;
                goto out;
-
-       if (!dev_is_pci(dev))
-               goto skip_ats_check;
-
-       pdev = to_pci_dev(dev);
-       if (domain->flags & PD_IOMMUV2_MASK) {
-               struct iommu_domain *def_domain = iommu_get_dma_domain(dev);
-
-               ret = -EINVAL;
-
-               /*
-                * In case of using AMD_IOMMU_V1 page table mode and the device
-                * is enabling for PPR/ATS support (using v2 table),
-                * we need to make sure that the domain type is identity map.
-                */
-               if ((amd_iommu_pgtable == AMD_IOMMU_V1) &&
-                   def_domain->type != IOMMU_DOMAIN_IDENTITY) {
-                       goto out;
-               }
-
-               if (pdev_pasid_supported(dev_data)) {
-                       if (pdev_pri_ats_enable(pdev) != 0)
-                               goto out;
-
-                       dev_data->ats_enabled = 1;
-                       dev_data->ats_qdep    = pci_ats_queue_depth(pdev);
-                       dev_data->pri_tlp     = pci_prg_resp_pasid_required(pdev);
-               }
-       } else if (amd_iommu_iotlb_sup &&
-                  pci_enable_ats(pdev, PAGE_SHIFT) == 0) {
-               dev_data->ats_enabled = 1;
-               dev_data->ats_qdep    = pci_ats_queue_depth(pdev);
        }
 
-skip_ats_check:
-       ret = 0;
+       if (dev_is_pci(dev))
+               pdev_enable_caps(to_pci_dev(dev));
 
        do_attach(dev_data, domain);
 
 
        do_detach(dev_data);
 
-       if (!dev_is_pci(dev))
-               goto out;
-
-       if (domain->flags & PD_IOMMUV2_MASK && pdev_pasid_supported(dev_data))
-               pdev_iommuv2_disable(to_pci_dev(dev));
-       else if (dev_data->ats_enabled)
-               pci_disable_ats(to_pci_dev(dev));
-
-       dev_data->ats_enabled = 0;
+       if (dev_is_pci(dev))
+               pdev_disable_caps(to_pci_dev(dev));
 
 out:
        spin_unlock(&dev_data->lock);