};
 
 struct arm_smmu_s2cr {
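+       /*
+        * Bookkeeping for sharing entries: the group attached through this
+        * S2CR, and the number of masters currently referencing it.
+        */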
+       struct iommu_group              *group;
+       int                             count;
        enum arm_smmu_s2cr_type         type;
        enum arm_smmu_s2cr_privcfg      privcfg;
        u8                              cbndx;
        u16                             smr_mask_mask;
        struct arm_smmu_smr             *smrs;
        struct arm_smmu_s2cr            *s2crs;
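+       /* Serialises stream map entry allocation, freeing and programming */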
+       struct mutex                    stream_map_mutex;
 
        unsigned long                   va_size;
        unsigned long                   ipa_size;
        kfree(smmu_domain);
 }
 
-static int arm_smmu_alloc_smr(struct arm_smmu_device *smmu)
-{
-       int i;
-
-       for (i = 0; i < smmu->num_mapping_groups; i++)
-               if (!cmpxchg(&smmu->smrs[i].valid, false, true))
-                       return i;
-
-       return INVALID_SMENDX;
-}
-
-static void arm_smmu_free_smr(struct arm_smmu_device *smmu, int idx)
-{
-       writel_relaxed(~SMR_VALID, ARM_SMMU_GR0(smmu) + ARM_SMMU_GR0_SMR(idx));
-       WRITE_ONCE(smmu->smrs[idx].valid, false);
-}
-
 static void arm_smmu_write_smr(struct arm_smmu_device *smmu, int idx)
 {
        struct arm_smmu_smr *smr = smmu->smrs + idx;
                arm_smmu_write_smr(smmu, idx);
 }
 
-static int arm_smmu_master_alloc_smes(struct arm_smmu_device *smmu,
-                                     struct arm_smmu_master_cfg *cfg)
+static int arm_smmu_find_sme(struct arm_smmu_device *smmu, u16 id, u16 mask)
 {
        struct arm_smmu_smr *smrs = smmu->smrs;
-       int i, idx;
-
-       /* Allocate the SMRs on the SMMU */
-       for_each_cfg_sme(cfg, i, idx) {
-               if (idx != INVALID_SMENDX)
-                       return -EEXIST;
+       int i, free_idx = -ENOSPC;
 
-               /* ...except on stream indexing hardware, of course */
-               if (!smrs) {
-                       cfg->smendx[i] = cfg->streamids[i];
+       /* Stream indexing is blissfully easy */
+       if (!smrs)
+               return id;
+
+       /* Validating SMRs is... less so */
+       for (i = 0; i < smmu->num_mapping_groups; ++i) {
+               if (!smrs[i].valid) {
+                       /*
+                        * Note the first free entry we come across, which
+                        * we'll claim in the end if nothing else matches.
+                        */
+                       if (free_idx < 0)
+                               free_idx = i;
                        continue;
                }
+               /*
+                * If the new entry is _entirely_ matched by an existing entry,
+                * then reuse that, with the guarantee that there also cannot
+                * be any subsequent conflicting entries. In normal use we'd
+                * expect simply identical entries for this case, but there's
+                * no harm in accommodating the generalisation.
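+                * For instance, an existing entry of id 0x400 and mask 0x0f
+                * covers stream IDs 0x400-0x40f, so a hypothetical new entry
+                * of id 0x402 and mask 0x03 (covering only 0x400-0x403)
+                * could safely reuse it.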
+                */
+               if ((mask & smrs[i].mask) == mask &&
+                   !((id ^ smrs[i].id) & ~smrs[i].mask))
+                       return i;
+               /*
+                * If the new entry has any other overlap with an existing one,
+                * though, then there always exists at least one stream ID
+                * which would cause a conflict, and we can't allow that risk.
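+                * For instance, an existing id 0x400 with mask 0x03 and a
+                * hypothetical new id 0x402 with mask 0x06 would both match
+                * stream ID 0x400, yet neither entry entirely contains the
+                * other.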
+                */
+               if (!((id ^ smrs[i].id) & ~(smrs[i].mask | mask)))
+                       return -EINVAL;
+       }
 
-               idx = arm_smmu_alloc_smr(smmu);
-               if (idx < 0) {
-                       dev_err(smmu->dev, "failed to allocate free SMR\n");
-                       goto err_free_smrs;
+       return free_idx;
+}
+
+static bool arm_smmu_free_sme(struct arm_smmu_device *smmu, int idx)
+{
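+       /* Drop one reference; the entry stays live while other masters use it */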
+       if (--smmu->s2crs[idx].count)
+               return false;
+
+       smmu->s2crs[idx] = s2cr_init_val;
+       if (smmu->smrs)
+               smmu->smrs[idx].valid = false;
+
+       return true;
+}
+
+static int arm_smmu_master_alloc_smes(struct device *dev)
+{
+       struct arm_smmu_master_cfg *cfg = dev->archdata.iommu;
+       struct arm_smmu_device *smmu = cfg->smmu;
+       struct arm_smmu_smr *smrs = smmu->smrs;
+       struct iommu_group *group;
+       int i, idx, ret;
+
+       mutex_lock(&smmu->stream_map_mutex);
+       /* Figure out a viable stream map entry allocation */
+       for_each_cfg_sme(cfg, i, idx) {
+               if (idx != INVALID_SMENDX) {
+                       ret = -EEXIST;
+                       goto out_err;
                }
-               cfg->smendx[i] = idx;
 
-               smrs[idx].id = cfg->streamids[i];
-               smrs[idx].mask = 0; /* We don't currently share SMRs */
+               ret = arm_smmu_find_sme(smmu, cfg->streamids[i], 0);
+               if (ret < 0)
+                       goto out_err;
+
+               idx = ret;
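+               /* Only program the SMR itself on the entry's first use */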
+               if (smrs && smmu->s2crs[idx].count == 0) {
+                       smrs[idx].id = cfg->streamids[i];
+                       smrs[idx].mask = 0; /* We don't currently share SMRs */
+                       smrs[idx].valid = true;
+               }
+               smmu->s2crs[idx].count++;
+               cfg->smendx[i] = (s16)idx;
        }
 
-       if (!smrs)
-               return 0;
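+       /*
+        * Now that a viable allocation exists, resolve the IOMMU group so
+        * that it can be recorded against each claimed entry below.
+        */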
+       group = iommu_group_get_for_dev(dev);
+       if (!group)
+               group = ERR_PTR(-ENOMEM);
+       if (IS_ERR(group)) {
+               ret = PTR_ERR(group);
+               goto out_err;
+       }
+       iommu_group_put(group);
 
        /* It worked! Now, poke the actual hardware */
-       for_each_cfg_sme(cfg, i, idx)
-               arm_smmu_write_smr(smmu, idx);
+       for_each_cfg_sme(cfg, i, idx) {
+               arm_smmu_write_sme(smmu, idx);
+               smmu->s2crs[idx].group = group;
+       }
 
+       mutex_unlock(&smmu->stream_map_mutex);
        return 0;
 
-err_free_smrs:
+out_err:
        while (i--) {
-               arm_smmu_free_smr(smmu, cfg->smendx[i]);
+               arm_smmu_free_sme(smmu, cfg->smendx[i]);
                cfg->smendx[i] = INVALID_SMENDX;
        }
-       return -ENOSPC;
+       mutex_unlock(&smmu->stream_map_mutex);
+       return ret;
 }
 
 static void arm_smmu_master_free_smes(struct arm_smmu_master_cfg *cfg)
        struct arm_smmu_device *smmu = cfg->smmu;
        int i, idx;
 
-       /*
-        * We *must* clear the S2CR first, because freeing the SMR means
-        * that it can be re-allocated immediately.
-        */
+       mutex_lock(&smmu->stream_map_mutex);
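+       /* Drop our references, and rewrite any entries left entirely unused */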
        for_each_cfg_sme(cfg, i, idx) {
-               /* An IOMMU group is torn down by the first device to be removed */
-               if (idx == INVALID_SMENDX)
-                       return;
-
-               smmu->s2crs[idx] = s2cr_init_val;
-               arm_smmu_write_s2cr(smmu, idx);
-       }
-       /* Sync S2CR updates before touching anything else */
-       __iowmb();
-
-       /* Invalidate the SMRs before freeing back to the allocator */
-       for_each_cfg_sme(cfg, i, idx) {
-               if (smmu->smrs)
-                       arm_smmu_free_smr(smmu, idx);
-
+               if (arm_smmu_free_sme(smmu, idx))
+                       arm_smmu_write_sme(smmu, idx);
                cfg->smendx[i] = INVALID_SMENDX;
        }
+       mutex_unlock(&smmu->stream_map_mutex);
 }
 
 static int arm_smmu_domain_add_master(struct arm_smmu_domain *smmu_domain,
                                      struct arm_smmu_master_cfg *cfg)
 {
-       int i, idx, ret = 0;
        struct arm_smmu_device *smmu = smmu_domain->smmu;
        struct arm_smmu_s2cr *s2cr = smmu->s2crs;
        enum arm_smmu_s2cr_type type = S2CR_TYPE_TRANS;
        u8 cbndx = smmu_domain->cfg.cbndx;
-
-       if (cfg->smendx[0] == INVALID_SMENDX)
-               ret = arm_smmu_master_alloc_smes(smmu, cfg);
-       if (ret)
-               return ret;
+       int i, idx;
 
        /*
         * FIXME: This won't be needed once we have IOMMU-backed DMA ops
                type = S2CR_TYPE_BYPASS;
 
        for_each_cfg_sme(cfg, i, idx) {
-               /* Devices in an IOMMU group may already be configured */
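+               /* A sibling device may have already configured this entry */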
                if (type == s2cr[idx].type && cbndx == s2cr[idx].cbndx)
-                       break;
+                       continue;
 
                s2cr[idx].type = type;
                s2cr[idx].privcfg = S2CR_PRIVCFG_UNPRIV;
 static int arm_smmu_add_device(struct device *dev)
 {
        struct arm_smmu_master_cfg *cfg;
-       struct iommu_group *group;
        int i, ret;
 
        ret = arm_smmu_register_legacy_master(dev);
                cfg->smendx[i] = INVALID_SMENDX;
        }
 
-       group = iommu_group_get_for_dev(dev);
-       if (IS_ERR(group)) {
-               ret = PTR_ERR(group);
-               goto out_free;
-       }
-       iommu_group_put(group);
-       return 0;
+       ret = arm_smmu_master_alloc_smes(dev);
+       if (!ret)
+               return ret;
 
 out_free:
        kfree(cfg);
 
 static struct iommu_group *arm_smmu_device_group(struct device *dev)
 {
-       struct iommu_group *group;
+       struct arm_smmu_master_cfg *cfg = dev->archdata.iommu;
+       struct arm_smmu_device *smmu = cfg->smmu;
+       struct iommu_group *group = NULL;
+       int i, idx;
+
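+       /*
+        * Any stream map entries shared with another master may already have
+        * a group; if so, we must use it, and all our entries must agree.
+        */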
+       for_each_cfg_sme(cfg, i, idx) {
+               if (group && smmu->s2crs[idx].group &&
+                   group != smmu->s2crs[idx].group)
+                       return ERR_PTR(-EINVAL);
+
+               group = smmu->s2crs[idx].group;
+       }
+
+       if (group)
+               return group;
 
        if (dev_is_pci(dev))
                group = pci_device_group(dev);
                smmu->s2crs[i] = s2cr_init_val;
 
        smmu->num_mapping_groups = size;
+       mutex_init(&smmu->stream_map_mutex);
 
        if (smmu->version < ARM_SMMU_V2 || !(id & ID0_PTFS_NO_AARCH32)) {
                smmu->features |= ARM_SMMU_FEAT_FMT_AARCH32_L;