if (IS_ERR(smmu_domain))
                return ERR_CAST(smmu_domain);
 
+       if (master->smmu->features & ARM_SMMU_FEAT_TRANS_S1)
+               smmu_domain->stage = ARM_SMMU_DOMAIN_S1;
+       else
+               smmu_domain->stage = ARM_SMMU_DOMAIN_S2;
+
        ret = arm_smmu_domain_finalise(smmu_domain, master->smmu, 0);
        if (ret) {
                kfree(smmu_domain);
                                 struct arm_smmu_domain *smmu_domain);
        bool enable_dirty = flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
 
-       /* Restrict the stage to what we can actually support */
-       if (!(smmu->features & ARM_SMMU_FEAT_TRANS_S1))
-               smmu_domain->stage = ARM_SMMU_DOMAIN_S2;
-       if (!(smmu->features & ARM_SMMU_FEAT_TRANS_S2))
-               smmu_domain->stage = ARM_SMMU_DOMAIN_S1;
-
        pgtbl_cfg = (struct io_pgtable_cfg) {
                .pgsize_bitmap  = smmu->pgsize_bitmap,
                .coherent_walk  = smmu->features & ARM_SMMU_FEAT_COHERENCY,
                                   const struct iommu_user_data *user_data)
 {
        struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+       struct arm_smmu_device *smmu = master->smmu;
        const u32 PAGING_FLAGS = IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
                                 IOMMU_HWPT_ALLOC_PASID |
                                 IOMMU_HWPT_ALLOC_NEST_PARENT;
        if (user_data)
                return ERR_PTR(-EOPNOTSUPP);
 
-       if (flags & IOMMU_HWPT_ALLOC_PASID)
-               return arm_smmu_domain_alloc_paging(dev);
-
        smmu_domain = arm_smmu_domain_alloc();
        if (IS_ERR(smmu_domain))
                return ERR_CAST(smmu_domain);
 
-       if (flags & IOMMU_HWPT_ALLOC_NEST_PARENT) {
-               if (!(master->smmu->features & ARM_SMMU_FEAT_NESTING)) {
+       switch (flags) {
+       case 0:
+               /* Prefer S1 if available */
+               if (smmu->features & ARM_SMMU_FEAT_TRANS_S1)
+                       smmu_domain->stage = ARM_SMMU_DOMAIN_S1;
+               else
+                       smmu_domain->stage = ARM_SMMU_DOMAIN_S2;
+               break;
+       case IOMMU_HWPT_ALLOC_NEST_PARENT:
+               if (!(smmu->features & ARM_SMMU_FEAT_NESTING)) {
                        ret = -EOPNOTSUPP;
                        goto err_free;
                }
                smmu_domain->stage = ARM_SMMU_DOMAIN_S2;
                smmu_domain->nest_parent = true;
+               break;
+       case IOMMU_HWPT_ALLOC_DIRTY_TRACKING:
+       case IOMMU_HWPT_ALLOC_DIRTY_TRACKING | IOMMU_HWPT_ALLOC_PASID:
+       case IOMMU_HWPT_ALLOC_PASID:
+               if (!(smmu->features & ARM_SMMU_FEAT_TRANS_S1)) {
+                       ret = -EOPNOTSUPP;
+                       goto err_free;
+               }
+               smmu_domain->stage = ARM_SMMU_DOMAIN_S1;
+               break;
+       default:
+               ret = -EOPNOTSUPP;
+               goto err_free;
        }
 
        smmu_domain->domain.type = IOMMU_DOMAIN_UNMANAGED;
        smmu_domain->domain.ops = arm_smmu_ops.default_domain_ops;
-       ret = arm_smmu_domain_finalise(smmu_domain, master->smmu, flags);
+       ret = arm_smmu_domain_finalise(smmu_domain, smmu, flags);
        if (ret)
                goto err_free;
        return &smmu_domain->domain;