};
 
 struct dsa_switch;
-struct dsa_switch_tree;
 
 struct dsa_device_ops {
        struct sk_buff *(*xmit)(struct sk_buff *skb, struct net_device *dev);
        struct sk_buff *(*rcv)(struct sk_buff *skb, struct net_device *dev);
        void (*flow_dissect)(const struct sk_buff *skb, __be16 *proto,
                             int *offset);
-       int (*connect)(struct dsa_switch_tree *dst);
-       void (*disconnect)(struct dsa_switch_tree *dst);
+       int (*connect)(struct dsa_switch *ds);
+       void (*disconnect)(struct dsa_switch *ds);
        unsigned int needed_headroom;
        unsigned int needed_tailroom;
        const char *name;
 
 
 static void dsa_tree_free(struct dsa_switch_tree *dst)
 {
-       if (dst->tag_ops) {
-               if (dst->tag_ops->disconnect)
-                       dst->tag_ops->disconnect(dst);
-
+       if (dst->tag_ops)
                dsa_tag_driver_put(dst->tag_ops);
-       }
        list_del(&dst->list);
        kfree(dst);
 }
        }
 
 connect:
+       if (tag_ops->connect) {
+               err = tag_ops->connect(ds);
+               if (err)
+                       return err;
+       }
+
        if (ds->ops->connect_tag_protocol) {
                err = ds->ops->connect_tag_protocol(ds, tag_ops->proto);
                if (err) {
                        dev_err(ds->dev,
                                "Unable to connect to tag protocol \"%s\": %pe\n",
                                tag_ops->name, ERR_PTR(err));
-                       return err;
+                       goto disconnect;
                }
        }
 
        return 0;
+
+disconnect:
+       if (tag_ops->disconnect)
+               tag_ops->disconnect(ds);
+
+       return err;
 }
 
 static int dsa_switch_setup(struct dsa_switch *ds)
 
        dst->tag_ops = tag_ops;
 
-       /* Notify the new tagger about the connection to this tree */
-       if (tag_ops->connect) {
-               err = tag_ops->connect(dst);
-               if (err)
-                       goto out_revert;
-       }
-
        /* Notify the switches from this tree about the connection
         * to the new tagger
         */
                goto out_disconnect;
 
        /* Notify the old tagger about the disconnection from this tree */
-       if (old_tag_ops->disconnect)
-               old_tag_ops->disconnect(dst);
+       info.tag_ops = old_tag_ops;
+       dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info);
 
        return 0;
 
 out_disconnect:
-       /* Revert the new tagger's connection to this tree */
-       if (tag_ops->disconnect)
-               tag_ops->disconnect(dst);
-out_revert:
+       info.tag_ops = tag_ops;
+       dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info);
        dst->tag_ops = old_tag_ops;
 
        return err;
        struct dsa_switch_tree *dst = ds->dst;
        const struct dsa_device_ops *tag_ops;
        enum dsa_tag_protocol default_proto;
-       int err;
 
        /* Find out which protocol the switch would prefer. */
        default_proto = dsa_get_tag_protocol(dp, master);
                 */
                dsa_tag_driver_put(tag_ops);
        } else {
-               if (tag_ops->connect) {
-                       err = tag_ops->connect(dst);
-                       if (err)
-                               return err;
-               }
-
                dst->tag_ops = tag_ops;
        }
 
 
        DSA_NOTIFIER_MTU,
        DSA_NOTIFIER_TAG_PROTO,
        DSA_NOTIFIER_TAG_PROTO_CONNECT,
+       DSA_NOTIFIER_TAG_PROTO_DISCONNECT,
        DSA_NOTIFIER_MRP_ADD,
        DSA_NOTIFIER_MRP_DEL,
        DSA_NOTIFIER_MRP_ADD_RING_ROLE,
 
        return 0;
 }
 
-static int dsa_switch_connect_tag_proto(struct dsa_switch *ds,
-                                       struct dsa_notifier_tag_proto_info *info)
+/* We use the same cross-chip notifiers to inform both the tagger side, as well
+ * as the switch side, of connection and disconnection events.
+ * Since ds->tagger_data is owned by the tagger, it isn't a hard error if the
+ * switch side doesn't support connecting to this tagger, and therefore, the
+ * fact that we don't disconnect the tagger side doesn't constitute a memory
+ * leak: the tagger will still operate with persistent per-switch memory, just
+ * with the switch side unconnected to it. What does constitute a hard error is
+ * when the switch side supports connecting but fails.
+ */
+static int
+dsa_switch_connect_tag_proto(struct dsa_switch *ds,
+                            struct dsa_notifier_tag_proto_info *info)
 {
        const struct dsa_device_ops *tag_ops = info->tag_ops;
+       int err;
+
+       /* Notify the new tagger about the connection to this switch */
+       if (tag_ops->connect) {
+               err = tag_ops->connect(ds);
+               if (err)
+                       return err;
+       }
 
        if (!ds->ops->connect_tag_protocol)
                return -EOPNOTSUPP;
 
-       return ds->ops->connect_tag_protocol(ds, tag_ops->proto);
+       /* Notify the switch about the connection to the new tagger */
+       err = ds->ops->connect_tag_protocol(ds, tag_ops->proto);
+       if (err) {
+               /* Revert the new tagger's connection to this tree */
+               if (tag_ops->disconnect)
+                       tag_ops->disconnect(ds);
+               return err;
+       }
+
+       return 0;
+}
+
+static int
+dsa_switch_disconnect_tag_proto(struct dsa_switch *ds,
+                               struct dsa_notifier_tag_proto_info *info)
+{
+       const struct dsa_device_ops *tag_ops = info->tag_ops;
+
+       /* Notify the tagger about the disconnection from this switch */
+       if (tag_ops->disconnect && ds->tagger_data)
+               tag_ops->disconnect(ds);
+
+       /* No need to notify the switch, since it shouldn't have any
+        * resources to tear down
+        */
+       return 0;
 }
 
 static int dsa_switch_mrp_add(struct dsa_switch *ds,
        case DSA_NOTIFIER_TAG_PROTO_CONNECT:
                err = dsa_switch_connect_tag_proto(ds, info);
                break;
+       case DSA_NOTIFIER_TAG_PROTO_DISCONNECT:
+               err = dsa_switch_disconnect_tag_proto(ds, info);
+               break;
        case DSA_NOTIFIER_MRP_ADD:
                err = dsa_switch_mrp_add(ds, info);
                break;
 
        return skb;
 }
 
-static void ocelot_disconnect(struct dsa_switch_tree *dst)
+static void ocelot_disconnect(struct dsa_switch *ds)
 {
-       struct ocelot_8021q_tagger_private *priv;
-       struct dsa_port *dp;
-
-       list_for_each_entry(dp, &dst->ports, list) {
-               priv = dp->ds->tagger_data;
-
-               if (!priv)
-                       continue;
+       struct ocelot_8021q_tagger_private *priv = ds->tagger_data;
 
-               if (priv->xmit_worker)
-                       kthread_destroy_worker(priv->xmit_worker);
-
-               kfree(priv);
-               dp->ds->tagger_data = NULL;
-       }
+       kthread_destroy_worker(priv->xmit_worker);
+       kfree(priv);
+       ds->tagger_data = NULL;
 }
 
-static int ocelot_connect(struct dsa_switch_tree *dst)
+static int ocelot_connect(struct dsa_switch *ds)
 {
        struct ocelot_8021q_tagger_private *priv;
-       struct dsa_port *dp;
        int err;
 
-       list_for_each_entry(dp, &dst->ports, list) {
-               if (dp->ds->tagger_data)
-                       continue;
+       priv = kzalloc(sizeof(*priv), GFP_KERNEL);
+       if (!priv)
+               return -ENOMEM;
 
-               priv = kzalloc(sizeof(*priv), GFP_KERNEL);
-               if (!priv) {
-                       err = -ENOMEM;
-                       goto out;
-               }
-
-               priv->xmit_worker = kthread_create_worker(0, "felix_xmit");
-               if (IS_ERR(priv->xmit_worker)) {
-                       err = PTR_ERR(priv->xmit_worker);
-                       goto out;
-               }
-
-               dp->ds->tagger_data = priv;
+       priv->xmit_worker = kthread_create_worker(0, "felix_xmit");
+       if (IS_ERR(priv->xmit_worker)) {
+               err = PTR_ERR(priv->xmit_worker);
+               kfree(priv);
+               return err;
        }
 
-       return 0;
+       ds->tagger_data = priv;
 
-out:
-       ocelot_disconnect(dst);
-       return err;
+       return 0;
 }
 
 static const struct dsa_device_ops ocelot_8021q_netdev_ops = {
 
        *proto = ((__be16 *)skb->data)[(VLAN_HLEN / 2) - 1];
 }
 
-static void sja1105_disconnect(struct dsa_switch_tree *dst)
+static void sja1105_disconnect(struct dsa_switch *ds)
 {
-       struct sja1105_tagger_private *priv;
-       struct dsa_port *dp;
-
-       list_for_each_entry(dp, &dst->ports, list) {
-               priv = dp->ds->tagger_data;
-
-               if (!priv)
-                       continue;
+       struct sja1105_tagger_private *priv = ds->tagger_data;
 
-               if (priv->xmit_worker)
-                       kthread_destroy_worker(priv->xmit_worker);
-
-               kfree(priv);
-               dp->ds->tagger_data = NULL;
-       }
+       kthread_destroy_worker(priv->xmit_worker);
+       kfree(priv);
+       ds->tagger_data = NULL;
 }
 
-static int sja1105_connect(struct dsa_switch_tree *dst)
+static int sja1105_connect(struct dsa_switch *ds)
 {
        struct sja1105_tagger_data *tagger_data;
        struct sja1105_tagger_private *priv;
        struct kthread_worker *xmit_worker;
-       struct dsa_port *dp;
        int err;
 
-       list_for_each_entry(dp, &dst->ports, list) {
-               if (dp->ds->tagger_data)
-                       continue;
+       priv = kzalloc(sizeof(*priv), GFP_KERNEL);
+       if (!priv)
+               return -ENOMEM;
 
-               priv = kzalloc(sizeof(*priv), GFP_KERNEL);
-               if (!priv) {
-                       err = -ENOMEM;
-                       goto out;
-               }
-
-               spin_lock_init(&priv->meta_lock);
-
-               xmit_worker = kthread_create_worker(0, "dsa%d:%d_xmit",
-                                                   dst->index, dp->ds->index);
-               if (IS_ERR(xmit_worker)) {
-                       err = PTR_ERR(xmit_worker);
-                       goto out;
-               }
+       spin_lock_init(&priv->meta_lock);
 
-               priv->xmit_worker = xmit_worker;
-               /* Export functions for switch driver use */
-               tagger_data = &priv->data;
-               tagger_data->rxtstamp_get_state = sja1105_rxtstamp_get_state;
-               tagger_data->rxtstamp_set_state = sja1105_rxtstamp_set_state;
-               dp->ds->tagger_data = priv;
+       xmit_worker = kthread_create_worker(0, "dsa%d:%d_xmit",
+                                           ds->dst->index, ds->index);
+       if (IS_ERR(xmit_worker)) {
+               err = PTR_ERR(xmit_worker);
+               kfree(priv);
+               return err;
        }
 
-       return 0;
+       priv->xmit_worker = xmit_worker;
+       /* Export functions for switch driver use */
+       tagger_data = &priv->data;
+       tagger_data->rxtstamp_get_state = sja1105_rxtstamp_get_state;
+       tagger_data->rxtstamp_set_state = sja1105_rxtstamp_set_state;
+       ds->tagger_data = priv;
 
-out:
-       sja1105_disconnect(dst);
-       return err;
+       return 0;
 }
 
 static const struct dsa_device_ops sja1105_netdev_ops = {