mask = GENMASK(ocelot->num_phys_ports - 1, 0);
                        mask &= ~cpu_fwd_mask;
                } else if (ocelot->bridge_fwd_mask & BIT(port)) {
-                       int lag;
+                       struct net_device *bond = ocelot_port->bond;
 
                        mask = ocelot->bridge_fwd_mask & ~BIT(port);
-
-                       for (lag = 0; lag < ocelot->num_phys_ports; lag++) {
-                               unsigned long bond_mask = ocelot->lags[lag];
-
-                               if (!bond_mask)
-                                       continue;
-
-                               if (bond_mask & BIT(port)) {
-                                       mask &= ~bond_mask;
-                                       break;
-                               }
-                       }
+                       if (bond)
+                               mask &= ~ocelot_get_bond_mask(ocelot, bond);
                } else {
                        /* Standalone ports forward only to DSA tag_8021q CPU
                         * ports (if those exist), or to the hardware CPU port
 
 static void ocelot_set_aggr_pgids(struct ocelot *ocelot)
 {
+       unsigned long visited = GENMASK(ocelot->num_phys_ports - 1, 0);
        int i, port, lag;
 
        /* Reset destination and aggregation PGIDS */
                ocelot_write_rix(ocelot, GENMASK(ocelot->num_phys_ports - 1, 0),
                                 ANA_PGID_PGID, i);
 
-       /* Now, set PGIDs for each LAG */
+       /* The visited ports bitmask holds the list of ports offloading any
+        * bonding interface. Initially we mark all these ports as unvisited,
+        * then every time we visit a port in this bitmask, we know that it is
+        * the lowest numbered port, i.e. the one whose logical ID == physical
+        * port ID == LAG ID. So we mark as visited all further ports in the
+        * bitmask that are offloading the same bonding interface. This way,
+        * we set up the aggregation PGIDs only once per bonding interface.
+        */
+       for (port = 0; port < ocelot->num_phys_ports; port++) {
+               struct ocelot_port *ocelot_port = ocelot->ports[port];
+
+               if (!ocelot_port || !ocelot_port->bond)
+                       continue;
+
+               visited &= ~BIT(port);
+       }
+
+       /* Now, set PGIDs for each active LAG */
        for (lag = 0; lag < ocelot->num_phys_ports; lag++) {
+               struct net_device *bond = ocelot->ports[lag]->bond;
                unsigned long bond_mask;
                int aggr_count = 0;
                u8 aggr_idx[16];
 
-               bond_mask = ocelot->lags[lag];
-               if (!bond_mask)
+               if (!bond || (visited & BIT(lag)))
                        continue;
 
+               bond_mask = ocelot_get_bond_mask(ocelot, bond);
+
                for_each_set_bit(port, &bond_mask, ocelot->num_phys_ports) {
                        // Destination mask
                        ocelot_write_rix(ocelot, bond_mask,
                        ac |= BIT(aggr_idx[i % aggr_count]);
                        ocelot_write_rix(ocelot, ac, ANA_PGID_PGID, i);
                }
+
+               /* Mark all ports in the same LAG as visited to avoid applying
+                * the same config again.
+                */
+               for (port = lag; port < ocelot->num_phys_ports; port++) {
+                       struct ocelot_port *ocelot_port = ocelot->ports[port];
+
+                       if (!ocelot_port)
+                               continue;
+
+                       if (ocelot_port->bond == bond)
+                               visited |= BIT(port);
+               }
        }
 }
 
                         struct net_device *bond,
                         struct netdev_lag_upper_info *info)
 {
-       u32 bond_mask = 0;
-       int lag;
-
        if (info->tx_type != NETDEV_LAG_TX_TYPE_HASH)
                return -EOPNOTSUPP;
 
        ocelot->ports[port]->bond = bond;
 
-       bond_mask = ocelot_get_bond_mask(ocelot, bond);
-
-       lag = __ffs(bond_mask);
-
-       /* If the new port is the lowest one, use it as the logical port from
-        * now on
-        */
-       if (port == lag) {
-               ocelot->lags[port] = bond_mask;
-               bond_mask &= ~BIT(port);
-               if (bond_mask)
-                       ocelot->lags[__ffs(bond_mask)] = 0;
-       } else {
-               ocelot->lags[lag] |= BIT(port);
-       }
-
        ocelot_setup_logical_port_ids(ocelot);
        ocelot_apply_bridge_fwd_mask(ocelot);
        ocelot_set_aggr_pgids(ocelot);
 void ocelot_port_lag_leave(struct ocelot *ocelot, int port,
                           struct net_device *bond)
 {
-       int i;
-
        ocelot->ports[port]->bond = NULL;
 
-       /* Remove port from any lag */
-       for (i = 0; i < ocelot->num_phys_ports; i++)
-               ocelot->lags[i] &= ~BIT(port);
-
-       /* if it was the logical port of the lag, move the lag config to the
-        * next port
-        */
-       if (ocelot->lags[port]) {
-               int n = __ffs(ocelot->lags[port]);
-
-               ocelot->lags[n] = ocelot->lags[port];
-               ocelot->lags[port] = 0;
-       }
-
        ocelot_setup_logical_port_ids(ocelot);
        ocelot_apply_bridge_fwd_mask(ocelot);
        ocelot_set_aggr_pgids(ocelot);
                }
        }
 
-       ocelot->lags = devm_kcalloc(ocelot->dev, ocelot->num_phys_ports,
-                                   sizeof(u32), GFP_KERNEL);
-       if (!ocelot->lags)
-               return -ENOMEM;
-
        ocelot->stats = devm_kcalloc(ocelot->dev,
                                     ocelot->num_phys_ports * ocelot->num_stats,
                                     sizeof(u64), GFP_KERNEL);