dev_put(mlxsw_sp_port->dev);
 }
 
+static void
+mlxsw_sp_port_lag_uppers_cleanup(struct mlxsw_sp_port *mlxsw_sp_port,
+                                struct net_device *lag_dev)
+{
+       struct net_device *br_dev = netdev_master_upper_dev_get(lag_dev);
+       struct net_device *upper_dev;
+       struct list_head *iter;
+
+       if (netif_is_bridge_port(lag_dev))
+               mlxsw_sp_port_bridge_leave(mlxsw_sp_port, lag_dev, br_dev);
+
+       netdev_for_each_upper_dev_rcu(lag_dev, upper_dev, iter) {
+               if (!netif_is_bridge_port(upper_dev))
+                       continue;
+               br_dev = netdev_master_upper_dev_get(upper_dev);
+               mlxsw_sp_port_bridge_leave(mlxsw_sp_port, upper_dev, br_dev);
+       }
+}
+
 static int mlxsw_sp_lag_create(struct mlxsw_sp *mlxsw_sp, u16 lag_id)
 {
        char sldr_pl[MLXSW_REG_SLDR_LEN];
 
        /* Any VLANs configured on the port are no longer valid */
        mlxsw_sp_port_vlan_flush(mlxsw_sp_port);
+       /* Make the LAG and its directly linked uppers leave bridges they
+        * are memeber in
+        */
+       mlxsw_sp_port_lag_uppers_cleanup(mlxsw_sp_port, lag_dev);
 
        if (lag->ref_count == 1)
                mlxsw_sp_lag_destroy(mlxsw_sp, lag_id);
 
        kfree(bridge_port);
 }
 
-static bool
-mlxsw_sp_bridge_port_should_destroy(const struct mlxsw_sp_bridge_port *
-                                   bridge_port)
-{
-       struct net_device *dev = bridge_port->dev;
-       struct mlxsw_sp *mlxsw_sp;
-
-       if (is_vlan_dev(dev))
-               mlxsw_sp = mlxsw_sp_lower_get(vlan_dev_real_dev(dev));
-       else
-               mlxsw_sp = mlxsw_sp_lower_get(dev);
-
-       /* In case ports were pulled from out of a bridged LAG, then
-        * it's possible the reference count isn't zero, yet the bridge
-        * port should be destroyed, as it's no longer an upper of ours.
-        */
-       if (!mlxsw_sp && list_empty(&bridge_port->vlans_list))
-               return true;
-       else if (bridge_port->ref_count == 0)
-               return true;
-       else
-               return false;
-}
-
 static struct mlxsw_sp_bridge_port *
 mlxsw_sp_bridge_port_get(struct mlxsw_sp_bridge *bridge,
                         struct net_device *brport_dev)
 {
        struct mlxsw_sp_bridge_device *bridge_device;
 
-       bridge_port->ref_count--;
-       if (!mlxsw_sp_bridge_port_should_destroy(bridge_port))
+       if (--bridge_port->ref_count != 0)
                return;
        bridge_device = bridge_port->bridge_device;
        mlxsw_sp_bridge_port_destroy(bridge_port);