static int felix_tag_8021q_setup(struct dsa_switch *ds)
 {
        struct ocelot *ocelot = ds->priv;
-       struct dsa_port *dp, *cpu_dp;
+       struct dsa_port *dp;
        int err;
 
        err = dsa_tag_8021q_register(ds, htons(ETH_P_8021AD));
        if (err)
                return err;
 
-       dsa_switch_for_each_cpu_port(cpu_dp, ds) {
-               ocelot_port_set_dsa_8021q_cpu(ocelot, cpu_dp->index);
-
-               /* TODO we could support multiple CPU ports in tag_8021q mode */
-               break;
-       }
+       dsa_switch_for_each_user_port(dp, ds)
+               ocelot_port_assign_dsa_8021q_cpu(ocelot, dp->index,
+                                                dp->cpu_dp->index);
 
-       dsa_switch_for_each_available_port(dp, ds) {
+       dsa_switch_for_each_available_port(dp, ds)
                /* This overwrites ocelot_init():
                 * Do not forward BPDU frames to the CPU port module,
                 * for 2 reasons:
                ocelot_write_gix(ocelot,
                                 ANA_PORT_CPU_FWD_BPDU_CFG_BPDU_REDIR_ENA(0),
                                 ANA_PORT_CPU_FWD_BPDU_CFG, dp->index);
-       }
 
        /* The ownership of the CPU port module's queues might have just been
         * transferred to the tag_8021q tagger from the NPI-based tagger.
 static void felix_tag_8021q_teardown(struct dsa_switch *ds)
 {
        struct ocelot *ocelot = ds->priv;
-       struct dsa_port *dp, *cpu_dp;
+       struct dsa_port *dp;
 
-       dsa_switch_for_each_available_port(dp, ds) {
+       dsa_switch_for_each_available_port(dp, ds)
                /* Restore the logic from ocelot_init:
                 * do not forward BPDU frames to the front ports.
                 */
                                 ANA_PORT_CPU_FWD_BPDU_CFG_BPDU_REDIR_ENA(0xffff),
                                 ANA_PORT_CPU_FWD_BPDU_CFG,
                                 dp->index);
-       }
 
-       dsa_switch_for_each_cpu_port(cpu_dp, ds) {
-               ocelot_port_unset_dsa_8021q_cpu(ocelot, cpu_dp->index);
-
-               /* TODO we could support multiple CPU ports in tag_8021q mode */
-               break;
-       }
+       dsa_switch_for_each_user_port(dp, ds)
+               ocelot_port_unassign_dsa_8021q_cpu(ocelot, dp->index);
 
        dsa_tag_8021q_unregister(ds);
 }
 
        return __ffs(bond_mask);
 }
 
-u32 ocelot_get_bridge_fwd_mask(struct ocelot *ocelot, int src_port)
+static u32 ocelot_dsa_8021q_cpu_assigned_ports(struct ocelot *ocelot,
+                                              struct ocelot_port *cpu)
 {
-       struct ocelot_port *ocelot_port = ocelot->ports[src_port];
-       const struct net_device *bridge;
        u32 mask = 0;
        int port;
 
-       if (!ocelot_port || ocelot_port->stp_state != BR_STATE_FORWARDING)
-               return 0;
-
-       bridge = ocelot_port->bridge;
-       if (!bridge)
-               return 0;
-
        for (port = 0; port < ocelot->num_phys_ports; port++) {
-               ocelot_port = ocelot->ports[port];
+               struct ocelot_port *ocelot_port = ocelot->ports[port];
 
                if (!ocelot_port)
                        continue;
 
-               if (ocelot_port->stp_state == BR_STATE_FORWARDING &&
-                   ocelot_port->bridge == bridge)
+               if (ocelot_port->dsa_8021q_cpu == cpu)
                        mask |= BIT(port);
        }
 
        return mask;
 }
-EXPORT_SYMBOL_GPL(ocelot_get_bridge_fwd_mask);
 
-u32 ocelot_get_dsa_8021q_cpu_mask(struct ocelot *ocelot)
+u32 ocelot_port_assigned_dsa_8021q_cpu_mask(struct ocelot *ocelot, int port)
 {
+       struct ocelot_port *ocelot_port = ocelot->ports[port];
+       struct ocelot_port *cpu_port = ocelot_port->dsa_8021q_cpu;
+
+       if (!cpu_port)
+               return 0;
+
+       return BIT(cpu_port->index);
+}
+EXPORT_SYMBOL_GPL(ocelot_port_assigned_dsa_8021q_cpu_mask);
+
+u32 ocelot_get_bridge_fwd_mask(struct ocelot *ocelot, int src_port)
+{
+       struct ocelot_port *ocelot_port = ocelot->ports[src_port];
+       const struct net_device *bridge;
        u32 mask = 0;
        int port;
 
+       if (!ocelot_port || ocelot_port->stp_state != BR_STATE_FORWARDING)
+               return 0;
+
+       bridge = ocelot_port->bridge;
+       if (!bridge)
+               return 0;
+
        for (port = 0; port < ocelot->num_phys_ports; port++) {
-               struct ocelot_port *ocelot_port = ocelot->ports[port];
+               ocelot_port = ocelot->ports[port];
 
                if (!ocelot_port)
                        continue;
 
-               if (ocelot_port->is_dsa_8021q_cpu)
+               if (ocelot_port->stp_state == BR_STATE_FORWARDING &&
+                   ocelot_port->bridge == bridge)
                        mask |= BIT(port);
        }
 
        return mask;
 }
-EXPORT_SYMBOL_GPL(ocelot_get_dsa_8021q_cpu_mask);
+EXPORT_SYMBOL_GPL(ocelot_get_bridge_fwd_mask);
 
 static void ocelot_apply_bridge_fwd_mask(struct ocelot *ocelot, bool joining)
 {
-       unsigned long cpu_fwd_mask;
        int port;
 
        lockdep_assert_held(&ocelot->fwd_domain_lock);
        if (joining && ocelot->ops->cut_through_fwd)
                ocelot->ops->cut_through_fwd(ocelot);
 
-       /* If a DSA tag_8021q CPU exists, it needs to be included in the
-        * regular forwarding path of the front ports regardless of whether
-        * those are bridged or standalone.
-        * If DSA tag_8021q is not used, this returns 0, which is fine because
-        * the hardware-based CPU port module can be a destination for packets
-        * even if it isn't part of PGID_SRC.
-        */
-       cpu_fwd_mask = ocelot_get_dsa_8021q_cpu_mask(ocelot);
-
        /* Apply FWD mask. The loop is needed to add/remove the current port as
         * a source for the other ports.
         */
                        mask = 0;
                } else if (ocelot_port->is_dsa_8021q_cpu) {
                        /* The DSA tag_8021q CPU ports need to be able to
-                        * forward packets to all other ports except for
-                        * themselves
+                        * forward packets to all ports assigned to them.
                         */
-                       mask = GENMASK(ocelot->num_phys_ports - 1, 0);
-                       mask &= ~cpu_fwd_mask;
+                       mask = ocelot_dsa_8021q_cpu_assigned_ports(ocelot,
+                                                                  ocelot_port);
                } else if (ocelot_port->bridge) {
                        struct net_device *bond = ocelot_port->bond;
 
                        mask = ocelot_get_bridge_fwd_mask(ocelot, port);
-                       mask |= cpu_fwd_mask;
                        mask &= ~BIT(port);
+
+                       mask |= ocelot_port_assigned_dsa_8021q_cpu_mask(ocelot,
+                                                                       port);
+
                        if (bond)
                                mask &= ~ocelot_get_bond_mask(ocelot, bond);
                } else {
                         * ports (if those exist), or to the hardware CPU port
                         * module otherwise.
                         */
-                       mask = cpu_fwd_mask;
+                       mask = ocelot_port_assigned_dsa_8021q_cpu_mask(ocelot,
+                                                                      port);
                }
 
                ocelot_write_rix(ocelot, mask, ANA_PGID_PGID, PGID_SRC + port);
        ocelot_write_rix(ocelot, pgid_cpu, ANA_PGID_PGID, PGID_CPU);
 }
 
-void ocelot_port_set_dsa_8021q_cpu(struct ocelot *ocelot, int port)
+void ocelot_port_assign_dsa_8021q_cpu(struct ocelot *ocelot, int port,
+                                     int cpu)
 {
+       struct ocelot_port *cpu_port = ocelot->ports[cpu];
        u16 vid;
 
        mutex_lock(&ocelot->fwd_domain_lock);
 
-       ocelot->ports[port]->is_dsa_8021q_cpu = true;
+       ocelot->ports[port]->dsa_8021q_cpu = cpu_port;
+
+       if (!cpu_port->is_dsa_8021q_cpu) {
+               cpu_port->is_dsa_8021q_cpu = true;
 
-       for (vid = OCELOT_RSV_VLAN_RANGE_START; vid < VLAN_N_VID; vid++)
-               ocelot_vlan_member_add(ocelot, port, vid, true);
+               for (vid = OCELOT_RSV_VLAN_RANGE_START; vid < VLAN_N_VID; vid++)
+                       ocelot_vlan_member_add(ocelot, cpu, vid, true);
 
-       ocelot_update_pgid_cpu(ocelot);
+               ocelot_update_pgid_cpu(ocelot);
+       }
 
        ocelot_apply_bridge_fwd_mask(ocelot, true);
 
        mutex_unlock(&ocelot->fwd_domain_lock);
 }
-EXPORT_SYMBOL_GPL(ocelot_port_set_dsa_8021q_cpu);
+EXPORT_SYMBOL_GPL(ocelot_port_assign_dsa_8021q_cpu);
 
-void ocelot_port_unset_dsa_8021q_cpu(struct ocelot *ocelot, int port)
+void ocelot_port_unassign_dsa_8021q_cpu(struct ocelot *ocelot, int port)
 {
+       struct ocelot_port *cpu_port = ocelot->ports[port]->dsa_8021q_cpu;
+       bool keep = false;
        u16 vid;
+       int p;
 
        mutex_lock(&ocelot->fwd_domain_lock);
 
-       ocelot->ports[port]->is_dsa_8021q_cpu = false;
+       ocelot->ports[port]->dsa_8021q_cpu = NULL;
+
+       for (p = 0; p < ocelot->num_phys_ports; p++) {
+               if (!ocelot->ports[p])
+                       continue;
+
+               if (ocelot->ports[p]->dsa_8021q_cpu == cpu_port) {
+                       keep = true;
+                       break;
+               }
+       }
+
+       if (!keep) {
+               cpu_port->is_dsa_8021q_cpu = false;
 
-       for (vid = OCELOT_RSV_VLAN_RANGE_START; vid < VLAN_N_VID; vid++)
-               ocelot_vlan_member_del(ocelot, port, vid);
+               for (vid = OCELOT_RSV_VLAN_RANGE_START; vid < VLAN_N_VID; vid++)
+                       ocelot_vlan_member_del(ocelot, cpu_port->index, vid);
 
-       ocelot_update_pgid_cpu(ocelot);
+               ocelot_update_pgid_cpu(ocelot);
+       }
 
        ocelot_apply_bridge_fwd_mask(ocelot, true);
 
        mutex_unlock(&ocelot->fwd_domain_lock);
 }
-EXPORT_SYMBOL_GPL(ocelot_port_unset_dsa_8021q_cpu);
+EXPORT_SYMBOL_GPL(ocelot_port_unassign_dsa_8021q_cpu);
 
 void ocelot_bridge_stp_state_set(struct ocelot *ocelot, int port, u8 state)
 {
 
        int to;
 };
 
+struct ocelot_port;
+
 struct ocelot_port {
        struct ocelot                   *ocelot;
 
        struct net_device               *bond;
        struct net_device               *bridge;
 
+       struct ocelot_port              *dsa_8021q_cpu;
+
        /* VLAN that untagged frames are classified to, on ingress */
        const struct ocelot_bridge_vlan *pvid_vlan;
 
 void ocelot_init_port(struct ocelot *ocelot, int port);
 void ocelot_deinit_port(struct ocelot *ocelot, int port);
 
-void ocelot_port_set_dsa_8021q_cpu(struct ocelot *ocelot, int port);
-void ocelot_port_unset_dsa_8021q_cpu(struct ocelot *ocelot, int port);
+void ocelot_port_assign_dsa_8021q_cpu(struct ocelot *ocelot, int port, int cpu);
+void ocelot_port_unassign_dsa_8021q_cpu(struct ocelot *ocelot, int port);
+u32 ocelot_port_assigned_dsa_8021q_cpu_mask(struct ocelot *ocelot, int port);
 
 /* DSA callbacks */
 void ocelot_get_strings(struct ocelot *ocelot, int port, u32 sset, u8 *data);
 int ocelot_port_vlan_filtering(struct ocelot *ocelot, int port, bool enabled,
                               struct netlink_ext_ack *extack);
 void ocelot_bridge_stp_state_set(struct ocelot *ocelot, int port, u8 state);
-u32 ocelot_get_dsa_8021q_cpu_mask(struct ocelot *ocelot);
 u32 ocelot_get_bridge_fwd_mask(struct ocelot *ocelot, int src_port);
 int ocelot_port_pre_bridge_flags(struct ocelot *ocelot, int port,
                                 struct switchdev_brport_flags val);