dev->iommu->iommu_dev = iommu_dev;
        dev->iommu->max_pasids = dev_iommu_get_max_pasids(dev);
+       if (ops->is_attach_deferred)
+               dev->iommu->attach_deferred = ops->is_attach_deferred(dev);
 
        group = iommu_group_get_for_dev(dev);
        if (IS_ERR(group)) {
        return ret;
 }
 
-static bool iommu_is_attach_deferred(struct device *dev)
-{
-       const struct iommu_ops *ops = dev_iommu_ops(dev);
-
-       if (ops->is_attach_deferred)
-               return ops->is_attach_deferred(dev);
-
-       return false;
-}
-
 static int iommu_group_do_dma_first_attach(struct device *dev, void *data)
 {
        struct iommu_domain *domain = data;
 
        lockdep_assert_held(&dev->iommu_group->mutex);
 
-       if (iommu_is_attach_deferred(dev)) {
-               dev->iommu->attach_deferred = 1;
+       if (dev->iommu->attach_deferred)
                return 0;
-       }
-
        return __iommu_attach_device(domain, dev);
 }
 
 
 }
 
-static int __iommu_group_dma_first_attach(struct iommu_group *group)
-{
-       return __iommu_group_for_each_dev(group, group->default_domain,
-                                         iommu_group_do_dma_first_attach);
-}
-
 static int iommu_group_do_probe_finalize(struct device *dev, void *data)
 {
        const struct iommu_ops *ops = dev_iommu_ops(dev);
 
                iommu_group_create_direct_mappings(group);
 
-               ret = __iommu_group_dma_first_attach(group);
+               group->domain = NULL;
+               ret = __iommu_group_set_domain(group, group->default_domain);
 
                mutex_unlock(&group->mutex);
 
 {
        int ret;
 
+       if (dev->iommu->attach_deferred) {
+               if (new_domain == group->default_domain)
+                       return 0;
+               dev->iommu->attach_deferred = 0;
+       }
+
        ret = __iommu_attach_device(new_domain, dev);
        if (ret) {
                /*