VISIBLE_IF_KUNIT
 void arm_smmu_make_cdtable_ste(struct arm_smmu_ste *target,
-                              struct arm_smmu_master *master)
+                              struct arm_smmu_master *master, bool ats_enabled)
 {
        struct arm_smmu_ctx_desc_cfg *cd_table = &master->cd_table;
        struct arm_smmu_device *smmu = master->smmu;
                         STRTAB_STE_1_S1STALLD :
                         0) |
                FIELD_PREP(STRTAB_STE_1_EATS,
-                          master->ats_enabled ? STRTAB_STE_1_EATS_TRANS : 0));
+                          ats_enabled ? STRTAB_STE_1_EATS_TRANS : 0));
 
        if (smmu->features & ARM_SMMU_FEAT_E2H) {
                /*
 VISIBLE_IF_KUNIT
 void arm_smmu_make_s2_domain_ste(struct arm_smmu_ste *target,
                                 struct arm_smmu_master *master,
-                                struct arm_smmu_domain *smmu_domain)
+                                struct arm_smmu_domain *smmu_domain,
+                                bool ats_enabled)
 {
        struct arm_smmu_s2_cfg *s2_cfg = &smmu_domain->s2_cfg;
        const struct io_pgtable_cfg *pgtbl_cfg =
 
        target->data[1] = cpu_to_le64(
                FIELD_PREP(STRTAB_STE_1_EATS,
-                          master->ats_enabled ? STRTAB_STE_1_EATS_TRANS : 0));
+                          ats_enabled ? STRTAB_STE_1_EATS_TRANS : 0));
 
        if (smmu->features & ARM_SMMU_FEAT_ATTR_TYPES_OVR)
                target->data[1] |= cpu_to_le64(FIELD_PREP(STRTAB_STE_1_SHCFG,
        return dev_is_pci(dev) && pci_ats_supported(to_pci_dev(dev));
 }
 
-static void arm_smmu_enable_ats(struct arm_smmu_master *master,
-                               struct arm_smmu_domain *smmu_domain)
+static void arm_smmu_enable_ats(struct arm_smmu_master *master)
 {
        size_t stu;
        struct pci_dev *pdev;
        struct arm_smmu_device *smmu = master->smmu;
 
-       /* Don't enable ATS at the endpoint if it's not enabled in the STE */
-       if (!master->ats_enabled)
-               return;
-
        /* Smallest Translation Unit: log2 of the smallest supported granule */
        stu = __ffs(smmu->pgsize_bitmap);
        pdev = to_pci_dev(master->dev);
 
-       atomic_inc(&smmu_domain->nr_ats_masters);
        /*
         * ATC invalidation of PASID 0 causes the entire ATC to be flushed.
         */
                dev_err(master->dev, "Failed to enable ATS (STU %zu)\n", stu);
 }
 
-static void arm_smmu_disable_ats(struct arm_smmu_master *master,
-                                struct arm_smmu_domain *smmu_domain)
-{
-       if (!master->ats_enabled)
-               return;
-
-       pci_disable_ats(to_pci_dev(master->dev));
-       /*
-        * Ensure ATS is disabled at the endpoint before we issue the
-        * ATC invalidation via the SMMU.
-        */
-       wmb();
-       arm_smmu_atc_inv_master(master);
-       atomic_dec(&smmu_domain->nr_ats_masters);
-}
-
 static int arm_smmu_enable_pasid(struct arm_smmu_master *master)
 {
        int ret;
        return NULL;
 }
 
-static void arm_smmu_detach_dev(struct arm_smmu_master *master)
+/*
+ * If the domain uses the smmu_domain->devices list, return the arm_smmu_domain
+ * structure, otherwise NULL. These domains track attached devices so they can
+ * issue invalidations.
+ */
+static struct arm_smmu_domain *
+to_smmu_domain_devices(struct iommu_domain *domain)
+{
+       /* The domain can be NULL only when processing the first attach */
+       if (!domain)
+               return NULL;
+       if (domain->type & __IOMMU_DOMAIN_PAGING)
+               return to_smmu_domain(domain);
+       return NULL;
+}
+
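+/*
+ * Remove the master from the old domain's devices list, if that domain tracks
+ * attached devices, and drop its contribution to nr_ats_masters. Called from
+ * arm_smmu_attach_commit() once the new configuration is in place.
+ */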
+static void arm_smmu_remove_master_domain(struct arm_smmu_master *master,
+                                         struct iommu_domain *domain)
 {
-       struct iommu_domain *domain = iommu_get_domain_for_dev(master->dev);
+       struct arm_smmu_domain *smmu_domain = to_smmu_domain_devices(domain);
        struct arm_smmu_master_domain *master_domain;
-       struct arm_smmu_domain *smmu_domain;
        unsigned long flags;
 
-       if (!domain || !(domain->type & __IOMMU_DOMAIN_PAGING))
+       if (!smmu_domain)
                return;
 
-       smmu_domain = to_smmu_domain(domain);
-       arm_smmu_disable_ats(master, smmu_domain);
-
        spin_lock_irqsave(&smmu_domain->devices_lock, flags);
        master_domain = arm_smmu_find_master_domain(smmu_domain, master);
        if (master_domain) {
                list_del(&master_domain->devices_elm);
                kfree(master_domain);
+               if (master->ats_enabled)
+                       atomic_dec(&smmu_domain->nr_ats_masters);
        }
        spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
+}
+
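+/*
+ * State shared between arm_smmu_attach_prepare() and arm_smmu_attach_commit().
+ * ats_enabled is the EATS setting the caller must program into the STE
+ * between the two calls.
+ */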
+struct arm_smmu_attach_state {
+       /* Inputs */
+       struct iommu_domain *old_domain;
+       struct arm_smmu_master *master;
+       /* Resulting state */
+       bool ats_enabled;
+};
+
+/*
+ * Start the sequence to attach a domain to a master. The sequence contains three
+ * steps:
+ *  arm_smmu_attach_prepare()
+ *  arm_smmu_install_ste_for_dev()
+ *  arm_smmu_attach_commit()
+ *
+ * If prepare succeeds then the sequence must be completed. The STE installed
+ * must set the STE.EATS field according to state.ats_enabled.
+ *
+ * If the device supports ATS then prepare determines whether EATS should be
+ * enabled in the STE and starts sequencing an EATS disable if required.
+ *
+ * The changes to STE.EATS and to the PCI ATS config space are sequenced here
+ * in the right order so that PCI ATS is only ever enabled while STE.EATS is
+ * enabled.
+ *
+ * new_domain can be a non-paging domain. In this case ATS will not be enabled,
+ * and invalidations won't be tracked.
+ */
+static int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
+                                  struct iommu_domain *new_domain)
+{
+       struct arm_smmu_master *master = state->master;
+       struct arm_smmu_master_domain *master_domain;
+       struct arm_smmu_domain *smmu_domain =
+               to_smmu_domain_devices(new_domain);
+       unsigned long flags;
+
+       /*
+        * arm_smmu_share_asid() must not see two domains pointing to the same
+        * arm_smmu_master_domain contents, otherwise it could randomly write
+        * one or the other to the CD.
+        */
+       lockdep_assert_held(&arm_smmu_asid_lock);
+
+       if (smmu_domain) {
+               /*
+                * The SMMU does not support enabling ATS with bypass/abort.
+                * When the STE is in bypass (STE.Config[2:0] == 0b100), ATS
+                * Translation Requests and Translated transactions are denied
+                * as though ATS is disabled for the stream (STE.EATS == 0b00),
+                * causing F_BAD_ATS_TREQ and F_TRANSL_FORBIDDEN events
+                * (IHI0070Ea 5.2 Stream Table Entry). Thus ATS can only be
+                * enabled if we have an arm_smmu_domain, which always has page
+                * tables.
+                */
+               state->ats_enabled = arm_smmu_ats_supported(master);
+
+               master_domain = kzalloc(sizeof(*master_domain), GFP_KERNEL);
+               if (!master_domain)
+                       return -ENOMEM;
+               master_domain->master = master;
 
-       master->ats_enabled = false;
+               /*
+                * During prepare we want the master to be on the devices
+                * lists of both the current and the new smmu_domain before we
+                * change any HW. This ensures that both domains will send ATS
+                * invalidations to the master until we are done.
+                *
+                * It is tempting to make this list only track masters that are
+                * using ATS, but arm_smmu_share_asid() also uses this to change
+                * the ASID of a domain, unrelated to ATS.
+                *
+                * Note that if we are re-attaching the same domain then the
+                * list will briefly contain two identical entries and commit
+                * will remove only one of them.
+                */
+               spin_lock_irqsave(&smmu_domain->devices_lock, flags);
+               if (state->ats_enabled)
+                       atomic_inc(&smmu_domain->nr_ats_masters);
+               list_add(&master_domain->devices_elm, &smmu_domain->devices);
+               spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
+       }
+
+       if (!state->ats_enabled && master->ats_enabled) {
+               pci_disable_ats(to_pci_dev(master->dev));
+               /*
+                * This is probably overkill, but the config write for disabling
+                * ATS should complete before the STE is configured to generate
+                * UR to avoid AER noise.
+                */
+               wmb();
+       }
+       return 0;
+}
+
+/*
+ * Commit is done after the STE/CD are configured with the EATS setting. It
+ * completes synchronizing the PCI device's ATC and finishes manipulating the
+ * smmu_domain->devices list.
+ */
+static void arm_smmu_attach_commit(struct arm_smmu_attach_state *state)
+{
+       struct arm_smmu_master *master = state->master;
+
+       lockdep_assert_held(&arm_smmu_asid_lock);
+
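+       /*
+        * If EATS is being turned off, the endpoint's ATS was already disabled
+        * in prepare; commit only has to enable ATS or flush the ATC.
+        */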
+       if (state->ats_enabled && !master->ats_enabled) {
+               arm_smmu_enable_ats(master);
+       } else if (master->ats_enabled) {
+               /*
+                * The translation has changed, flush the ATC. At this point the
+                * SMMU is translating for the new domain and both the old and
+                * new domains will issue invalidations.
+                */
+               arm_smmu_atc_inv_master(master);
+       }
+       master->ats_enabled = state->ats_enabled;
+
+       arm_smmu_remove_master_domain(master, state->old_domain);
 }
 
 static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev)
 {
        int ret = 0;
-       unsigned long flags;
        struct arm_smmu_ste target;
        struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
        struct arm_smmu_device *smmu;
        struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
-       struct arm_smmu_master_domain *master_domain;
+       struct arm_smmu_attach_state state = {
+               .old_domain = iommu_get_domain_for_dev(dev),
+       };
        struct arm_smmu_master *master;
        struct arm_smmu_cd *cdptr;
 
        if (!fwspec)
                return -ENOENT;
 
-       master = dev_iommu_priv_get(dev);
+       state.master = master = dev_iommu_priv_get(dev);
        smmu = master->smmu;
 
        /*
                        return -ENOMEM;
        }
 
-       master_domain = kzalloc(sizeof(*master_domain), GFP_KERNEL);
-       if (!master_domain)
-               return -ENOMEM;
-       master_domain->master = master;
-
        /*
         * Prevent arm_smmu_share_asid() from trying to change the ASID
         * of either the old or new domain while we are working on it.
         */
        mutex_lock(&arm_smmu_asid_lock);
 
-       arm_smmu_detach_dev(master);
-
-       master->ats_enabled = arm_smmu_ats_supported(master);
-
-       spin_lock_irqsave(&smmu_domain->devices_lock, flags);
-       list_add(&master_domain->devices_elm, &smmu_domain->devices);
-       spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
+       ret = arm_smmu_attach_prepare(&state, domain);
+       if (ret) {
+               mutex_unlock(&arm_smmu_asid_lock);
+               return ret;
+       }
 
        switch (smmu_domain->stage) {
        case ARM_SMMU_DOMAIN_S1: {
                arm_smmu_make_s1_cd(&target_cd, master, smmu_domain);
                arm_smmu_write_cd_entry(master, IOMMU_NO_PASID, cdptr,
                                        &target_cd);
-               arm_smmu_make_cdtable_ste(&target, master);
+               arm_smmu_make_cdtable_ste(&target, master, state.ats_enabled);
                arm_smmu_install_ste_for_dev(master, &target);
                break;
        }
        case ARM_SMMU_DOMAIN_S2:
-               arm_smmu_make_s2_domain_ste(&target, master, smmu_domain);
+               arm_smmu_make_s2_domain_ste(&target, master, smmu_domain,
+                                           state.ats_enabled);
                arm_smmu_install_ste_for_dev(master, &target);
                arm_smmu_clear_cd(master, IOMMU_NO_PASID);
                break;
        }
 
-       arm_smmu_enable_ats(master, smmu_domain);
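+       /* Finish the ATS sequencing and drop the master from the old domain */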
+       arm_smmu_attach_commit(&state);
        mutex_unlock(&arm_smmu_asid_lock);
        return 0;
 }
        arm_smmu_clear_cd(master, pasid);
 }
 
-static int arm_smmu_attach_dev_ste(struct device *dev,
-                                  struct arm_smmu_ste *ste)
+static int arm_smmu_attach_dev_ste(struct iommu_domain *domain,
+                                  struct device *dev, struct arm_smmu_ste *ste)
 {
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+       struct arm_smmu_attach_state state = {
+               .master = master,
+               .old_domain = iommu_get_domain_for_dev(dev),
+       };
 
        if (arm_smmu_master_sva_enabled(master))
                return -EBUSY;
         */
        mutex_lock(&arm_smmu_asid_lock);
 
-       /*
-        * The SMMU does not support enabling ATS with bypass/abort. When the
-        * STE is in bypass (STE.Config[2:0] == 0b100), ATS Translation Requests
-        * and Translated transactions are denied as though ATS is disabled for
-        * the stream (STE.EATS == 0b00), causing F_BAD_ATS_TREQ and
-        * F_TRANSL_FORBIDDEN events (IHI0070Ea 5.2 Stream Table Entry).
-        */
-       arm_smmu_detach_dev(master);
-
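+       /*
+        * arm_smmu_attach_prepare() cannot fail for the bypass/abort domains
+        * used here: they are not paging domains, so nothing is allocated and
+        * the return value can be ignored.
+        */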
+       arm_smmu_attach_prepare(&state, domain);
        arm_smmu_install_ste_for_dev(master, ste);
+       arm_smmu_attach_commit(&state);
        mutex_unlock(&arm_smmu_asid_lock);
 
        /*
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
 
        arm_smmu_make_bypass_ste(master->smmu, &ste);
-       return arm_smmu_attach_dev_ste(dev, &ste);
+       return arm_smmu_attach_dev_ste(domain, dev, &ste);
 }
 
 static const struct iommu_domain_ops arm_smmu_identity_ops = {
        struct arm_smmu_ste ste;
 
        arm_smmu_make_abort_ste(&ste);
-       return arm_smmu_attach_dev_ste(dev, &ste);
+       return arm_smmu_attach_dev_ste(domain, dev, &ste);
 }
 
 static const struct iommu_domain_ops arm_smmu_blocked_ops = {