}
 EXPORT_SYMBOL(b53_mdb_del);
 
-int b53_br_join(struct dsa_switch *ds, int port, struct net_device *br)
+int b53_br_join(struct dsa_switch *ds, int port, struct dsa_bridge bridge)
 {
        struct b53_device *dev = ds->priv;
        s8 cpu_port = dsa_to_port(ds, port)->cpu_dp->index;
        b53_read16(dev, B53_PVLAN_PAGE, B53_PVLAN_PORT_MASK(port), &pvlan);
 
        b53_for_each_port(dev, i) {
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != br)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
 
                /* Add this local port to the remote port VLAN control
 }
 EXPORT_SYMBOL(b53_br_join);
 
-void b53_br_leave(struct dsa_switch *ds, int port, struct net_device *br)
+void b53_br_leave(struct dsa_switch *ds, int port, struct dsa_bridge bridge)
 {
        struct b53_device *dev = ds->priv;
        struct b53_vlan *vl = &dev->vlans[0];
 
        b53_for_each_port(dev, i) {
                /* Don't touch the remaining ports */
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != br)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
 
                b53_read16(dev, B53_PVLAN_PAGE, B53_PVLAN_PORT_MASK(i), ®);
 
 void b53_get_ethtool_stats(struct dsa_switch *ds, int port, uint64_t *data);
 int b53_get_sset_count(struct dsa_switch *ds, int port, int sset);
 void b53_get_ethtool_phy_stats(struct dsa_switch *ds, int port, uint64_t *data);
-int b53_br_join(struct dsa_switch *ds, int port, struct net_device *bridge);
-void b53_br_leave(struct dsa_switch *ds, int port, struct net_device *bridge);
+int b53_br_join(struct dsa_switch *ds, int port, struct dsa_bridge bridge);
+void b53_br_leave(struct dsa_switch *ds, int port, struct dsa_bridge bridge);
 void b53_br_set_stp_state(struct dsa_switch *ds, int port, u8 state);
 void b53_br_fast_age(struct dsa_switch *ds, int port);
 int b53_br_flags_pre(struct dsa_switch *ds, int port,
 
 }
 
 static int dsa_loop_port_bridge_join(struct dsa_switch *ds, int port,
-                                    struct net_device *bridge)
+                                    struct dsa_bridge bridge)
 {
        dev_dbg(ds->dev, "%s: port: %d, bridge: %s\n",
-               __func__, port, bridge->name);
+               __func__, port, bridge.dev->name);
 
        return 0;
 }
 
 static void dsa_loop_port_bridge_leave(struct dsa_switch *ds, int port,
-                                      struct net_device *bridge)
+                                      struct dsa_bridge bridge)
 {
        dev_dbg(ds->dev, "%s: port: %d, bridge: %s\n",
-               __func__, port, bridge->name);
+               __func__, port, bridge.dev->name);
 }
 
 static void dsa_loop_port_stp_state_set(struct dsa_switch *ds, int port,
 
 }
 
 static int hellcreek_port_bridge_join(struct dsa_switch *ds, int port,
-                                     struct net_device *br)
+                                     struct dsa_bridge bridge)
 {
        struct hellcreek *hellcreek = ds->priv;
 
 }
 
 static void hellcreek_port_bridge_leave(struct dsa_switch *ds, int port,
-                                       struct net_device *br)
+                                       struct dsa_bridge bridge)
 {
        struct hellcreek *hellcreek = ds->priv;
 
 
 }
 
 static int lan9303_port_bridge_join(struct dsa_switch *ds, int port,
-                                   struct net_device *br)
+                                   struct dsa_bridge bridge)
 {
        struct lan9303 *chip = ds->priv;
 
 }
 
 static void lan9303_port_bridge_leave(struct dsa_switch *ds, int port,
-                                     struct net_device *br)
+                                     struct dsa_bridge bridge)
 {
        struct lan9303 *chip = ds->priv;
 
 
 }
 
 static int gswip_port_bridge_join(struct dsa_switch *ds, int port,
-                                 struct net_device *bridge)
+                                 struct dsa_bridge bridge)
 {
+       struct net_device *br = bridge.dev;
        struct gswip_priv *priv = ds->priv;
        int err;
 
        /* When the bridge uses VLAN filtering we have to configure VLAN
         * specific bridges. No bridge is configured here.
         */
-       if (!br_vlan_enabled(bridge)) {
-               err = gswip_vlan_add_unaware(priv, bridge, port);
+       if (!br_vlan_enabled(br)) {
+               err = gswip_vlan_add_unaware(priv, br, port);
                if (err)
                        return err;
                priv->port_vlan_filter &= ~BIT(port);
 }
 
 static void gswip_port_bridge_leave(struct dsa_switch *ds, int port,
-                                   struct net_device *bridge)
+                                   struct dsa_bridge bridge)
 {
+       struct net_device *br = bridge.dev;
        struct gswip_priv *priv = ds->priv;
 
        gswip_add_single_port_br(priv, port, true);
        /* When the bridge uses VLAN filtering we have to configure VLAN
         * specific bridges. No bridge is configured here.
         */
-       if (!br_vlan_enabled(bridge))
-               gswip_vlan_remove(priv, bridge, port, 0, true, false);
+       if (!br_vlan_enabled(br))
+               gswip_vlan_remove(priv, br, port, 0, true, false);
 }
 
 static int gswip_port_vlan_prepare(struct dsa_switch *ds, int port,
 
 EXPORT_SYMBOL_GPL(ksz_get_ethtool_stats);
 
 int ksz_port_bridge_join(struct dsa_switch *ds, int port,
-                        struct net_device *br)
+                        struct dsa_bridge bridge)
 {
        /* port_stp_state_set() will be called after to put the port in
         * appropriate state so there is no need to do anything.
 EXPORT_SYMBOL_GPL(ksz_port_bridge_join);
 
 void ksz_port_bridge_leave(struct dsa_switch *ds, int port,
-                          struct net_device *br)
+                          struct dsa_bridge bridge)
 {
        /* port_stp_state_set() will be called after to put the port in
         * forwarding state so there is no need to do anything.
 
 int ksz_sset_count(struct dsa_switch *ds, int port, int sset);
 void ksz_get_ethtool_stats(struct dsa_switch *ds, int port, uint64_t *buf);
 int ksz_port_bridge_join(struct dsa_switch *ds, int port,
-                        struct net_device *br);
+                        struct dsa_bridge bridge);
 void ksz_port_bridge_leave(struct dsa_switch *ds, int port,
-                          struct net_device *br);
+                          struct dsa_bridge bridge);
 void ksz_port_fast_age(struct dsa_switch *ds, int port);
 int ksz_port_fdb_dump(struct dsa_switch *ds, int port, dsa_fdb_dump_cb_t *cb,
                      void *data);
 
 
 static int
 mt7530_port_bridge_join(struct dsa_switch *ds, int port,
-                       struct net_device *bridge)
+                       struct dsa_bridge bridge)
 {
        struct dsa_port *dp = dsa_to_port(ds, port), *other_dp;
        u32 port_bitmap = BIT(MT7530_CPU_PORT);
                 * same bridge. If the port is disabled, port matrix is kept
                 * and not being setup until the port becomes enabled.
                 */
-               if (dsa_port_bridge_dev_get(other_dp) != bridge)
+               if (!dsa_port_offloads_bridge(other_dp, &bridge))
                        continue;
 
                if (priv->ports[other_port].enable)
 
 static void
 mt7530_port_bridge_leave(struct dsa_switch *ds, int port,
-                        struct net_device *bridge)
+                        struct dsa_bridge bridge)
 {
        struct dsa_port *dp = dsa_to_port(ds, port), *other_dp;
        struct mt7530_priv *priv = ds->priv;
                 * in the same bridge. If the port is disabled, port matrix
                 * is kept and not being setup until the port becomes enabled.
                 */
-               if (dsa_port_bridge_dev_get(other_dp) != bridge)
+               if (!dsa_port_offloads_bridge(other_dp, &bridge))
                        continue;
 
                if (priv->ports[other_port].enable)
 
 }
 
 static int mv88e6xxx_bridge_map(struct mv88e6xxx_chip *chip,
-                               struct net_device *br)
+                               struct dsa_bridge bridge)
 {
        struct dsa_switch *ds = chip->ds;
        struct dsa_switch_tree *dst = ds->dst;
        int err;
 
        list_for_each_entry(dp, &dst->ports, list) {
-               if (dsa_port_bridge_dev_get(dp) == br) {
+               if (dsa_port_offloads_bridge(dp, &bridge)) {
                        if (dp->ds == ds) {
                                /* This is a local bridge group member,
                                 * remap its Port VLAN Map.
 }
 
 static int mv88e6xxx_port_bridge_join(struct dsa_switch *ds, int port,
-                                     struct net_device *br)
+                                     struct dsa_bridge bridge)
 {
        struct mv88e6xxx_chip *chip = ds->priv;
        int err;
 
        mv88e6xxx_reg_lock(chip);
 
-       err = mv88e6xxx_bridge_map(chip, br);
+       err = mv88e6xxx_bridge_map(chip, bridge);
        if (err)
                goto unlock;
 
 }
 
 static void mv88e6xxx_port_bridge_leave(struct dsa_switch *ds, int port,
-                                       struct net_device *br)
+                                       struct dsa_bridge bridge)
 {
        struct mv88e6xxx_chip *chip = ds->priv;
        int err;
 
        mv88e6xxx_reg_lock(chip);
 
-       if (mv88e6xxx_bridge_map(chip, br) ||
+       if (mv88e6xxx_bridge_map(chip, bridge) ||
            mv88e6xxx_port_vlan_map(chip, port))
                dev_err(ds->dev, "failed to remap in-chip Port VLAN\n");
 
 
 static int mv88e6xxx_crosschip_bridge_join(struct dsa_switch *ds,
                                           int tree_index, int sw_index,
-                                          int port, struct net_device *br)
+                                          int port, struct dsa_bridge bridge)
 {
        struct mv88e6xxx_chip *chip = ds->priv;
        int err;
 
 static void mv88e6xxx_crosschip_bridge_leave(struct dsa_switch *ds,
                                             int tree_index, int sw_index,
-                                            int port, struct net_device *br)
+                                            int port, struct dsa_bridge bridge)
 {
        struct mv88e6xxx_chip *chip = ds->priv;
 
 }
 
 static int mv88e6xxx_bridge_tx_fwd_offload(struct dsa_switch *ds, int port,
-                                          struct net_device *br,
-                                          unsigned int bridge_num)
+                                          struct dsa_bridge bridge)
 {
-       return mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge_num);
+       return mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge.num);
 }
 
 static void mv88e6xxx_bridge_tx_fwd_unoffload(struct dsa_switch *ds, int port,
-                                             struct net_device *br,
-                                             unsigned int bridge_num)
+                                             struct dsa_bridge bridge)
 {
        int err;
 
-       err = mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge_num);
+       err = mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge.num);
        if (err) {
                dev_err(ds->dev, "failed to remap cross-chip Port VLAN: %pe\n",
                        ERR_PTR(err));
 
 }
 
 static int felix_bridge_join(struct dsa_switch *ds, int port,
-                            struct net_device *br)
+                            struct dsa_bridge bridge)
 {
        struct ocelot *ocelot = ds->priv;
 
-       ocelot_port_bridge_join(ocelot, port, br);
+       ocelot_port_bridge_join(ocelot, port, bridge.dev);
 
        return 0;
 }
 
 static void felix_bridge_leave(struct dsa_switch *ds, int port,
-                              struct net_device *br)
+                              struct dsa_bridge bridge)
 {
        struct ocelot *ocelot = ds->priv;
 
-       ocelot_port_bridge_leave(ocelot, port, br);
+       ocelot_port_bridge_leave(ocelot, port, bridge.dev);
 }
 
 static int felix_lag_join(struct dsa_switch *ds, int port,
 
                  QCA8K_PORT_LOOKUP_STATE_MASK, stp_state);
 }
 
-static int
-qca8k_port_bridge_join(struct dsa_switch *ds, int port, struct net_device *br)
+static int qca8k_port_bridge_join(struct dsa_switch *ds, int port,
+                                 struct dsa_bridge bridge)
 {
        struct qca8k_priv *priv = (struct qca8k_priv *)ds->priv;
        int port_mask, cpu_port;
        for (i = 0; i < QCA8K_NUM_PORTS; i++) {
                if (dsa_is_cpu_port(ds, i))
                        continue;
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != br)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
                /* Add this port to the portvlan mask of the other ports
                 * in the bridge
        return ret;
 }
 
-static void
-qca8k_port_bridge_leave(struct dsa_switch *ds, int port, struct net_device *br)
+static void qca8k_port_bridge_leave(struct dsa_switch *ds, int port,
+                                   struct dsa_bridge bridge)
 {
        struct qca8k_priv *priv = (struct qca8k_priv *)ds->priv;
        int cpu_port, i;
        for (i = 0; i < QCA8K_NUM_PORTS; i++) {
                if (dsa_is_cpu_port(ds, i))
                        continue;
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != br)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
                /* Remove this port to the portvlan mask of the other ports
                 * in the bridge
 
 
 static int
 rtl8366rb_port_bridge_join(struct dsa_switch *ds, int port,
-                          struct net_device *bridge)
+                          struct dsa_bridge bridge)
 {
        struct realtek_smi *smi = ds->priv;
        unsigned int port_bitmap = 0;
                if (i == port)
                        continue;
                /* Not on this bridge */
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != bridge)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
                /* Join this port to each other port on the bridge */
                ret = regmap_update_bits(smi->map, RTL8366RB_PORT_ISO(i),
 
 static void
 rtl8366rb_port_bridge_leave(struct dsa_switch *ds, int port,
-                           struct net_device *bridge)
+                           struct dsa_bridge bridge)
 {
        struct realtek_smi *smi = ds->priv;
        unsigned int port_bitmap = 0;
                if (i == port)
                        continue;
                /* Not on this bridge */
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != bridge)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
                /* Remove this port from any other port on the bridge */
                ret = regmap_update_bits(smi->map, RTL8366RB_PORT_ISO(i),
 
 }
 
 static int sja1105_bridge_member(struct dsa_switch *ds, int port,
-                                struct net_device *br, bool member)
+                                struct dsa_bridge bridge, bool member)
 {
        struct sja1105_l2_forwarding_entry *l2_fwd;
        struct sja1105_private *priv = ds->priv;
                 */
                if (i == port)
                        continue;
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != br)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
                sja1105_port_allow_traffic(l2_fwd, i, port, member);
                sja1105_port_allow_traffic(l2_fwd, port, i, member);
 }
 
 static int sja1105_bridge_join(struct dsa_switch *ds, int port,
-                              struct net_device *br)
+                              struct dsa_bridge bridge)
 {
-       return sja1105_bridge_member(ds, port, br, true);
+       return sja1105_bridge_member(ds, port, bridge, true);
 }
 
 static void sja1105_bridge_leave(struct dsa_switch *ds, int port,
-                                struct net_device *br)
+                                struct dsa_bridge bridge)
 {
-       sja1105_bridge_member(ds, port, br, false);
+       sja1105_bridge_member(ds, port, bridge, false);
 }
 
 #define BYTES_PER_KBIT (1000LL / 8)
 
 }
 
 static int xrs700x_bridge_common(struct dsa_switch *ds, int port,
-                                struct net_device *bridge, bool join)
+                                struct dsa_bridge bridge, bool join)
 {
        unsigned int i, cpu_mask = 0, mask = 0;
        struct xrs700x *priv = ds->priv;
 
                cpu_mask |= BIT(i);
 
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) == bridge)
+               if (dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
 
                mask |= BIT(i);
        }
 
        for (i = 0; i < ds->num_ports; i++) {
-               if (dsa_port_bridge_dev_get(dsa_to_port(ds, i)) != bridge)
+               if (!dsa_port_offloads_bridge(dsa_to_port(ds, i), &bridge))
                        continue;
 
                /* 1 = Disable forwarding to the port */
 }
 
 static int xrs700x_bridge_join(struct dsa_switch *ds, int port,
-                              struct net_device *bridge)
+                              struct dsa_bridge bridge)
 {
        return xrs700x_bridge_common(ds, port, bridge, true);
 }
 
 static void xrs700x_bridge_leave(struct dsa_switch *ds, int port,
-                                struct net_device *bridge)
+                                struct dsa_bridge bridge)
 {
        xrs700x_bridge_common(ds, port, bridge, false);
 }
 
 
 #include <linux/refcount.h>
 #include <linux/types.h>
+#include <net/dsa.h>
 
 struct dsa_switch;
 struct dsa_port;
 void dsa_8021q_rcv(struct sk_buff *skb, int *source_port, int *switch_id);
 
 int dsa_tag_8021q_bridge_tx_fwd_offload(struct dsa_switch *ds, int port,
-                                       struct net_device *br,
-                                       unsigned int bridge_num);
+                                       struct dsa_bridge bridge);
 
 void dsa_tag_8021q_bridge_tx_fwd_unoffload(struct dsa_switch *ds, int port,
-                                          struct net_device *br,
-                                          unsigned int bridge_num);
+                                          struct dsa_bridge bridge);
 
 u16 dsa_8021q_bridge_tx_fwd_offload_vid(unsigned int bridge_num);
 
 
        };
 };
 
+struct dsa_bridge {
+       struct net_device *dev;
+       unsigned int num;
+       refcount_t refcount;
+};
 
 struct dsa_port {
        /* A CPU port is physically connected to a master device.
        /* Managed by DSA on user ports and by drivers on CPU and DSA ports */
        bool                    learning;
        u8                      stp_state;
-       struct net_device       *bridge_dev;
-       unsigned int            bridge_num;
+       struct dsa_bridge       *bridge;
        struct devlink_port     devlink_port;
        bool                    devlink_port_setup;
        struct phylink          *pl;
 static inline
 struct net_device *dsa_port_to_bridge_port(const struct dsa_port *dp)
 {
-       if (!dp->bridge_dev)
+       if (!dp->bridge)
                return NULL;
 
        if (dp->lag_dev)
 static inline struct net_device *
 dsa_port_bridge_dev_get(const struct dsa_port *dp)
 {
-       return dp->bridge_dev;
+       return dp->bridge ? dp->bridge->dev : NULL;
 }
 
 static inline unsigned int dsa_port_bridge_num_get(struct dsa_port *dp)
 {
-       return dp->bridge_num;
+       return dp->bridge ? dp->bridge->num : 0;
 }
 
 static inline bool dsa_port_bridge_same(const struct dsa_port *a,
        return dsa_port_bridge_dev_get(dp) == bridge_dev;
 }
 
+static inline bool dsa_port_offloads_bridge(struct dsa_port *dp,
+                                           const struct dsa_bridge *bridge)
+{
+       return dsa_port_bridge_dev_get(dp) == bridge->dev;
+}
+
 /* Returns true if any port of this tree offloads the given net_device */
 static inline bool dsa_tree_offloads_bridge_port(struct dsa_switch_tree *dst,
                                                 const struct net_device *dev)
         */
        int     (*set_ageing_time)(struct dsa_switch *ds, unsigned int msecs);
        int     (*port_bridge_join)(struct dsa_switch *ds, int port,
-                                   struct net_device *bridge);
+                                   struct dsa_bridge bridge);
        void    (*port_bridge_leave)(struct dsa_switch *ds, int port,
-                                    struct net_device *bridge);
+                                    struct dsa_bridge bridge);
        /* Called right after .port_bridge_join() */
        int     (*port_bridge_tx_fwd_offload)(struct dsa_switch *ds, int port,
-                                             struct net_device *bridge,
-                                             unsigned int bridge_num);
+                                             struct dsa_bridge bridge);
        /* Called right before .port_bridge_leave() */
        void    (*port_bridge_tx_fwd_unoffload)(struct dsa_switch *ds, int port,
-                                               struct net_device *bridge,
-                                               unsigned int bridge_num);
+                                               struct dsa_bridge bridge);
        void    (*port_stp_state_set)(struct dsa_switch *ds, int port,
                                      u8 state);
        void    (*port_fast_age)(struct dsa_switch *ds, int port);
         */
        int     (*crosschip_bridge_join)(struct dsa_switch *ds, int tree_index,
                                         int sw_index, int port,
-                                        struct net_device *br);
+                                        struct dsa_bridge bridge);
        void    (*crosschip_bridge_leave)(struct dsa_switch *ds, int tree_index,
                                          int sw_index, int port,
-                                         struct net_device *br);
+                                         struct dsa_bridge bridge);
        int     (*crosschip_lag_change)(struct dsa_switch *ds, int sw_index,
                                        int port);
        int     (*crosschip_lag_join)(struct dsa_switch *ds, int sw_index,
 
        }
 }
 
+struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst,
+                                       const struct net_device *br)
+{
+       struct dsa_port *dp;
+
+       list_for_each_entry(dp, &dst->ports, list)
+               if (dsa_port_bridge_dev_get(dp) == br)
+                       return dp->bridge;
+
+       return NULL;
+}
+
 static int dsa_bridge_num_find(const struct net_device *bridge_dev)
 {
        struct dsa_switch_tree *dst;
-       struct dsa_port *dp;
 
-       /* When preparing the offload for a port, it will have a valid
-        * dp->bridge_dev pointer but a not yet valid dp->bridge_num.
-        * However there might be other ports having the same dp->bridge_dev
-        * and a valid dp->bridge_num, so just ignore this port.
-        */
-       list_for_each_entry(dst, &dsa_tree_list, list)
-               list_for_each_entry(dp, &dst->ports, list)
-                       if (dp->bridge_dev == bridge_dev && dp->bridge_num)
-                               return dp->bridge_num;
+       list_for_each_entry(dst, &dsa_tree_list, list) {
+               struct dsa_bridge *bridge;
+
+               bridge = dsa_tree_bridge_find(dst, bridge_dev);
+               if (bridge)
+                       return bridge->num;
+       }
 
        return 0;
 }
 {
        unsigned int bridge_num = dsa_bridge_num_find(bridge_dev);
 
+       /* Switches without FDB isolation support don't get unique
+        * bridge numbering
+        */
+       if (!max)
+               return 0;
+
        if (!bridge_num) {
                /* First port that requests FDB isolation or TX forwarding
                 * offload for this bridge
 void dsa_bridge_num_put(const struct net_device *bridge_dev,
                        unsigned int bridge_num)
 {
-       /* Check if the bridge is still in use, otherwise it is time
-        * to clean it up so we can reuse this bridge_num later.
+       /* Since we refcount bridges, we know that when we call this function
+        * it is no longer in use, so we can just go ahead and remove it from
+        * the bit mask.
         */
-       if (!dsa_bridge_num_find(bridge_dev))
-               clear_bit(bridge_num, &dsa_fwd_offloading_bridges);
+       clear_bit(bridge_num, &dsa_fwd_offloading_bridges);
 }
 
 struct dsa_switch *dsa_switch_find(int tree_index, int sw_index)
 
 
 /* DSA_NOTIFIER_BRIDGE_* */
 struct dsa_notifier_bridge_info {
-       struct net_device *br;
+       struct dsa_bridge bridge;
        int tree_index;
        int sw_index;
        int port;
                if (dp->type != DSA_PORT_TYPE_USER)
                        continue;
 
-               if (!dp->bridge_dev)
+               if (!dp->bridge)
                        continue;
 
                if (dp->stp_state != BR_STATE_LEARNING &&
 /* If the ingress port offloads the bridge, we mark the frame as autonomously
  * forwarded by hardware, so the software bridge doesn't forward in twice, back
  * to us, because we already did. However, if we're in fallback mode and we do
- * software bridging, we are not offloading it, therefore the dp->bridge_dev
+ * software bridging, we are not offloading it, therefore the dp->bridge
  * pointer is not populated, and flooding needs to be done by software (we are
  * effectively operating in standalone ports mode).
  */
 {
        struct dsa_port *dp = dsa_slave_to_port(skb->dev);
 
-       skb->offload_fwd_mark = !!(dp->bridge_dev);
+       skb->offload_fwd_mark = !!(dp->bridge);
 }
 
 /* Helper for removing DSA header tags from packets in the RX path.
 unsigned int dsa_bridge_num_get(const struct net_device *bridge_dev, int max);
 void dsa_bridge_num_put(const struct net_device *bridge_dev,
                        unsigned int bridge_num);
+struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst,
+                                       const struct net_device *br);
 
 /* tag_8021q.c */
 int dsa_tag_8021q_bridge_join(struct dsa_switch *ds,
 
                        return err;
        }
 
-       if (!dp->bridge_dev)
+       if (!dp->bridge)
                dsa_port_set_state_now(dp, BR_STATE_FORWARDING, false);
 
        if (dp->pl)
        if (dp->pl)
                phylink_stop(dp->pl);
 
-       if (!dp->bridge_dev)
+       if (!dp->bridge)
                dsa_port_set_state_now(dp, BR_STATE_DISABLED, false);
 
        if (ds->ops->port_disable)
 }
 
 static void dsa_port_bridge_tx_fwd_unoffload(struct dsa_port *dp,
-                                            struct net_device *bridge_dev,
-                                            unsigned int bridge_num)
+                                            struct dsa_bridge bridge)
 {
        struct dsa_switch *ds = dp->ds;
 
        /* No bridge TX forwarding offload => do nothing */
-       if (!ds->ops->port_bridge_tx_fwd_unoffload || !bridge_num)
+       if (!ds->ops->port_bridge_tx_fwd_unoffload || !bridge.num)
                return;
 
        /* Notify the chips only once the offload has been deactivated, so
         * that they can update their configuration accordingly.
         */
-       ds->ops->port_bridge_tx_fwd_unoffload(ds, dp->index, bridge_dev,
-                                             bridge_num);
+       ds->ops->port_bridge_tx_fwd_unoffload(ds, dp->index, bridge);
 }
 
 static bool dsa_port_bridge_tx_fwd_offload(struct dsa_port *dp,
-                                          struct net_device *bridge_dev,
-                                          unsigned int bridge_num)
+                                          struct dsa_bridge bridge)
 {
        struct dsa_switch *ds = dp->ds;
        int err;
 
        /* FDB isolation is required for TX forwarding offload */
-       if (!ds->ops->port_bridge_tx_fwd_offload || !bridge_num)
+       if (!ds->ops->port_bridge_tx_fwd_offload || !bridge.num)
                return false;
 
        /* Notify the driver */
-       err = ds->ops->port_bridge_tx_fwd_offload(ds, dp->index, bridge_dev,
-                                                 bridge_num);
+       err = ds->ops->port_bridge_tx_fwd_offload(ds, dp->index, bridge);
 
        return err ? false : true;
 }
                                  struct netlink_ext_ack *extack)
 {
        struct dsa_switch *ds = dp->ds;
-       unsigned int bridge_num;
+       struct dsa_bridge *bridge;
 
-       dp->bridge_dev = br;
-
-       if (!ds->max_num_bridges)
+       bridge = dsa_tree_bridge_find(ds->dst, br);
+       if (bridge) {
+               refcount_inc(&bridge->refcount);
+               dp->bridge = bridge;
                return 0;
+       }
+
+       bridge = kzalloc(sizeof(*bridge), GFP_KERNEL);
+       if (!bridge)
+               return -ENOMEM;
+
+       refcount_set(&bridge->refcount, 1);
+
+       bridge->dev = br;
 
-       bridge_num = dsa_bridge_num_get(br, ds->max_num_bridges);
-       if (!bridge_num) {
+       bridge->num = dsa_bridge_num_get(br, ds->max_num_bridges);
+       if (ds->max_num_bridges && !bridge->num) {
                NL_SET_ERR_MSG_MOD(extack,
                                   "Range of offloadable bridges exceeded");
+               kfree(bridge);
                return -EOPNOTSUPP;
        }
 
-       dp->bridge_num = bridge_num;
+       dp->bridge = bridge;
 
        return 0;
 }
 static void dsa_port_bridge_destroy(struct dsa_port *dp,
                                    const struct net_device *br)
 {
-       struct dsa_switch *ds = dp->ds;
+       struct dsa_bridge *bridge = dp->bridge;
+
+       dp->bridge = NULL;
 
-       dp->bridge_dev = NULL;
+       if (!refcount_dec_and_test(&bridge->refcount))
+               return;
 
-       if (ds->max_num_bridges) {
-               int bridge_num = dp->bridge_num;
+       if (bridge->num)
+               dsa_bridge_num_put(br, bridge->num);
 
-               dp->bridge_num = 0;
-               dsa_bridge_num_put(br, bridge_num);
-       }
+       kfree(bridge);
 }
 
 int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br,
                .tree_index = dp->ds->dst->index,
                .sw_index = dp->ds->index,
                .port = dp->index,
-               .br = br,
        };
        struct net_device *dev = dp->slave;
        struct net_device *brport_dev;
 
        brport_dev = dsa_port_to_bridge_port(dp);
 
+       info.bridge = *dp->bridge;
        err = dsa_broadcast(DSA_NOTIFIER_BRIDGE_JOIN, &info);
        if (err)
                goto out_rollback;
 
-       tx_fwd_offload = dsa_port_bridge_tx_fwd_offload(dp, br,
-                                                       dsa_port_bridge_num_get(dp));
+       tx_fwd_offload = dsa_port_bridge_tx_fwd_offload(dp, info.bridge);
 
        err = switchdev_bridge_port_offload(brport_dev, dev, dp,
                                            &dsa_slave_switchdev_notifier,
 
 void dsa_port_bridge_leave(struct dsa_port *dp, struct net_device *br)
 {
-       unsigned int bridge_num = dsa_port_bridge_num_get(dp);
        struct dsa_notifier_bridge_info info = {
                .tree_index = dp->ds->dst->index,
                .sw_index = dp->ds->index,
                .port = dp->index,
-               .br = br,
+               .bridge = *dp->bridge,
        };
        int err;
 
         */
        dsa_port_bridge_destroy(dp, br);
 
-       dsa_port_bridge_tx_fwd_unoffload(dp, br, bridge_num);
+       dsa_port_bridge_tx_fwd_unoffload(dp, info.bridge);
 
        err = dsa_broadcast(DSA_NOTIFIER_BRIDGE_LEAVE, &info);
        if (err)
 
        if (!dp->ds->mtu_enforcement_ingress)
                return;
 
-       if (!dp->bridge_dev)
+       if (!dp->bridge)
                return;
 
        INIT_LIST_HEAD(&hw_port_list);
 
                if (!ds->ops->port_bridge_join)
                        return -EOPNOTSUPP;
 
-               err = ds->ops->port_bridge_join(ds, info->port, info->br);
+               err = ds->ops->port_bridge_join(ds, info->port, info->bridge);
                if (err)
                        return err;
        }
            ds->ops->crosschip_bridge_join) {
                err = ds->ops->crosschip_bridge_join(ds, info->tree_index,
                                                     info->sw_index,
-                                                    info->port, info->br);
+                                                    info->port, info->bridge);
                if (err)
                        return err;
        }
 
        if (dst->index == info->tree_index && ds->index == info->sw_index &&
            ds->ops->port_bridge_leave)
-               ds->ops->port_bridge_leave(ds, info->port, info->br);
+               ds->ops->port_bridge_leave(ds, info->port, info->bridge);
 
        if ((dst->index != info->tree_index || ds->index != info->sw_index) &&
            ds->ops->crosschip_bridge_leave)
                ds->ops->crosschip_bridge_leave(ds, info->tree_index,
                                                info->sw_index, info->port,
-                                               info->br);
+                                               info->bridge);
 
-       if (ds->needs_standalone_vlan_filtering && !br_vlan_enabled(info->br)) {
+       if (ds->needs_standalone_vlan_filtering &&
+           !br_vlan_enabled(info->bridge.dev)) {
                change_vlan_filtering = true;
                vlan_filtering = true;
        } else if (!ds->needs_standalone_vlan_filtering &&
-                  br_vlan_enabled(info->br)) {
+                  br_vlan_enabled(info->bridge.dev)) {
                change_vlan_filtering = true;
                vlan_filtering = false;
        }
 
                return false;
 
        if (dsa_port_is_user(dp))
-               return dsa_port_bridge_dev_get(dp) == info->br;
+               return dsa_port_offloads_bridge(dp, &info->bridge);
 
        return false;
 }
 }
 
 int dsa_tag_8021q_bridge_tx_fwd_offload(struct dsa_switch *ds, int port,
-                                       struct net_device *br,
-                                       unsigned int bridge_num)
+                                       struct dsa_bridge bridge)
 {
-       u16 tx_vid = dsa_8021q_bridge_tx_fwd_offload_vid(bridge_num);
+       u16 tx_vid = dsa_8021q_bridge_tx_fwd_offload_vid(bridge.num);
 
        return dsa_port_tag_8021q_vlan_add(dsa_to_port(ds, port), tx_vid,
                                           true);
 EXPORT_SYMBOL_GPL(dsa_tag_8021q_bridge_tx_fwd_offload);
 
 void dsa_tag_8021q_bridge_tx_fwd_unoffload(struct dsa_switch *ds, int port,
-                                          struct net_device *br,
-                                          unsigned int bridge_num)
+                                          struct dsa_bridge bridge)
 {
-       u16 tx_vid = dsa_8021q_bridge_tx_fwd_offload_vid(bridge_num);
+       u16 tx_vid = dsa_8021q_bridge_tx_fwd_offload_vid(bridge.num);
 
        dsa_port_tag_8021q_vlan_del(dsa_to_port(ds, port), tx_vid, true);
 }