return 0;
 }
 
+/* Treat the software bridge as a virtual single-port switch behind the
+ * CPU and map in the PVT. First dst->last_switch elements are taken by
+ * physical switches, so start from beyond that range.
+ */
+static int mv88e6xxx_map_virtual_bridge_to_pvt(struct dsa_switch *ds,
+                                              unsigned int bridge_num)
+{
+       u8 dev = bridge_num + ds->dst->last_switch;
+       struct mv88e6xxx_chip *chip = ds->priv;
+
+       return mv88e6xxx_pvt_map(chip, dev, 0);
+}
+
 static int mv88e6xxx_port_bridge_join(struct dsa_switch *ds, int port,
                                      struct dsa_bridge bridge,
                                      bool *tx_fwd_offload)
        if (err)
                goto unlock;
 
+       if (mv88e6xxx_has_pvt(chip)) {
+               err = mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge.num);
+               if (err)
+                       goto unlock;
+
+               *tx_fwd_offload = true;
+       }
+
 unlock:
        mv88e6xxx_reg_unlock(chip);
 
 
        mv88e6xxx_reg_lock(chip);
 
+       if (bridge.tx_fwd_offload &&
+           mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge.num))
+               dev_err(ds->dev, "failed to remap cross-chip Port VLAN\n");
+
        if (mv88e6xxx_bridge_map(chip, bridge) ||
            mv88e6xxx_port_vlan_map(chip, port))
                dev_err(ds->dev, "failed to remap in-chip Port VLAN\n");
        mv88e6xxx_reg_unlock(chip);
 }
 
-/* Treat the software bridge as a virtual single-port switch behind the
- * CPU and map in the PVT. First dst->last_switch elements are taken by
- * physical switches, so start from beyond that range.
- */
-static int mv88e6xxx_map_virtual_bridge_to_pvt(struct dsa_switch *ds,
-                                              unsigned int bridge_num)
-{
-       u8 dev = bridge_num + ds->dst->last_switch;
-       struct mv88e6xxx_chip *chip = ds->priv;
-       int err;
-
-       mv88e6xxx_reg_lock(chip);
-       err = mv88e6xxx_pvt_map(chip, dev, 0);
-       mv88e6xxx_reg_unlock(chip);
-
-       return err;
-}
-
-static int mv88e6xxx_bridge_tx_fwd_offload(struct dsa_switch *ds, int port,
-                                          struct dsa_bridge bridge)
-{
-       return mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge.num);
-}
-
-static void mv88e6xxx_bridge_tx_fwd_unoffload(struct dsa_switch *ds, int port,
-                                             struct dsa_bridge bridge)
-{
-       int err;
-
-       err = mv88e6xxx_map_virtual_bridge_to_pvt(ds, bridge.num);
-       if (err) {
-               dev_err(ds->dev, "failed to remap cross-chip Port VLAN: %pe\n",
-                       ERR_PTR(err));
-       }
-}
-
 static int mv88e6xxx_software_reset(struct mv88e6xxx_chip *chip)
 {
        if (chip->info->ops->reset)
        .crosschip_lag_change   = mv88e6xxx_crosschip_lag_change,
        .crosschip_lag_join     = mv88e6xxx_crosschip_lag_join,
        .crosschip_lag_leave    = mv88e6xxx_crosschip_lag_leave,
-       .port_bridge_tx_fwd_offload = mv88e6xxx_bridge_tx_fwd_offload,
-       .port_bridge_tx_fwd_unoffload = mv88e6xxx_bridge_tx_fwd_unoffload,
 };
 
 static int mv88e6xxx_register_switch(struct mv88e6xxx_chip *chip)
 
                               struct dsa_bridge bridge,
                               bool *tx_fwd_offload)
 {
-       return sja1105_bridge_member(ds, port, bridge, true);
+       int rc;
+
+       rc = sja1105_bridge_member(ds, port, bridge, true);
+       if (rc)
+               return rc;
+
+       rc = dsa_tag_8021q_bridge_tx_fwd_offload(ds, port, bridge);
+       if (rc) {
+               sja1105_bridge_member(ds, port, bridge, false);
+               return rc;
+       }
+
+       *tx_fwd_offload = true;
+
+       return 0;
 }
 
 static void sja1105_bridge_leave(struct dsa_switch *ds, int port,
                                 struct dsa_bridge bridge)
 {
+       dsa_tag_8021q_bridge_tx_fwd_unoffload(ds, port, bridge);
        sja1105_bridge_member(ds, port, bridge, false);
 }
 
        .tag_8021q_vlan_add     = sja1105_dsa_8021q_vlan_add,
        .tag_8021q_vlan_del     = sja1105_dsa_8021q_vlan_del,
        .port_prechangeupper    = sja1105_prechangeupper,
-       .port_bridge_tx_fwd_offload = dsa_tag_8021q_bridge_tx_fwd_offload,
-       .port_bridge_tx_fwd_unoffload = dsa_tag_8021q_bridge_tx_fwd_unoffload,
 };
 
 static const struct of_device_id sja1105_dt_ids[];
 
 struct dsa_bridge {
        struct net_device *dev;
        unsigned int num;
+       bool tx_fwd_offload;
        refcount_t refcount;
 };
 
                                    bool *tx_fwd_offload);
        void    (*port_bridge_leave)(struct dsa_switch *ds, int port,
                                     struct dsa_bridge bridge);
-       /* Called right after .port_bridge_join() */
-       int     (*port_bridge_tx_fwd_offload)(struct dsa_switch *ds, int port,
-                                             struct dsa_bridge bridge);
-       /* Called right before .port_bridge_leave() */
-       void    (*port_bridge_tx_fwd_unoffload)(struct dsa_switch *ds, int port,
-                                               struct dsa_bridge bridge);
        void    (*port_stp_state_set)(struct dsa_switch *ds, int port,
                                      u8 state);
        void    (*port_fast_age)(struct dsa_switch *ds, int port);
 
         */
 }
 
-static void dsa_port_bridge_tx_fwd_unoffload(struct dsa_port *dp,
-                                            struct dsa_bridge bridge)
-{
-       struct dsa_switch *ds = dp->ds;
-
-       /* No bridge TX forwarding offload => do nothing */
-       if (!ds->ops->port_bridge_tx_fwd_unoffload || !bridge.num)
-               return;
-
-       /* Notify the chips only once the offload has been deactivated, so
-        * that they can update their configuration accordingly.
-        */
-       ds->ops->port_bridge_tx_fwd_unoffload(ds, dp->index, bridge);
-}
-
-static bool dsa_port_bridge_tx_fwd_offload(struct dsa_port *dp,
-                                          struct dsa_bridge bridge)
-{
-       struct dsa_switch *ds = dp->ds;
-       int err;
-
-       /* FDB isolation is required for TX forwarding offload */
-       if (!ds->ops->port_bridge_tx_fwd_offload || !bridge.num)
-               return false;
-
-       /* Notify the driver */
-       err = ds->ops->port_bridge_tx_fwd_offload(ds, dp->index, bridge);
-
-       return err ? false : true;
-}
-
 static int dsa_port_bridge_create(struct dsa_port *dp,
                                  struct net_device *br,
                                  struct netlink_ext_ack *extack)
        };
        struct net_device *dev = dp->slave;
        struct net_device *brport_dev;
-       bool tx_fwd_offload;
        int err;
 
        /* Here the interface is already bridged. Reflect the current
        if (err)
                goto out_rollback;
 
-       tx_fwd_offload = dsa_port_bridge_tx_fwd_offload(dp, info.bridge);
+       /* Drivers which support bridge TX forwarding should set this */
+       dp->bridge->tx_fwd_offload = info.tx_fwd_offload;
 
        err = switchdev_bridge_port_offload(brport_dev, dev, dp,
                                            &dsa_slave_switchdev_notifier,
                                            &dsa_slave_switchdev_blocking_notifier,
-                                           tx_fwd_offload, extack);
+                                           dp->bridge->tx_fwd_offload, extack);
        if (err)
                goto out_rollback_unbridge;
 
         */
        dsa_port_bridge_destroy(dp, br);
 
-       dsa_port_bridge_tx_fwd_unoffload(dp, info.bridge);
-
        err = dsa_broadcast(DSA_NOTIFIER_BRIDGE_LEAVE, &info);
        if (err)
                dev_err(dp->ds->dev,