#define mlxsw_sp_prefix_usage_for_each(prefix, prefix_usage) \
        for_each_set_bit(prefix, (prefix_usage)->b, MLXSW_SP_PREFIX_COUNT)
 
-static bool
-mlxsw_sp_prefix_usage_subset(struct mlxsw_sp_prefix_usage *prefix_usage1,
-                            struct mlxsw_sp_prefix_usage *prefix_usage2)
-{
-       unsigned char prefix;
-
-       mlxsw_sp_prefix_usage_for_each(prefix, prefix_usage1) {
-               if (!test_bit(prefix, prefix_usage2->b))
-                       return false;
-       }
-       return true;
-}
-
 static bool
 mlxsw_sp_prefix_usage_eq(struct mlxsw_sp_prefix_usage *prefix_usage1,
                         struct mlxsw_sp_prefix_usage *prefix_usage2)
                    lpm_tree->proto == proto &&
                    mlxsw_sp_prefix_usage_eq(&lpm_tree->prefix_usage,
                                             prefix_usage))
-                       goto inc_ref_count;
+                       return lpm_tree;
        }
-       lpm_tree = mlxsw_sp_lpm_tree_create(mlxsw_sp, prefix_usage,
-                                           proto);
-       if (IS_ERR(lpm_tree))
-               return lpm_tree;
+       return mlxsw_sp_lpm_tree_create(mlxsw_sp, prefix_usage, proto);
+}
 
-inc_ref_count:
+static void mlxsw_sp_lpm_tree_hold(struct mlxsw_sp_lpm_tree *lpm_tree)
+{
        lpm_tree->ref_count++;
-       return lpm_tree;
 }
 
 static void mlxsw_sp_lpm_tree_put(struct mlxsw_sp *mlxsw_sp,
        vr->fib4 = NULL;
 }
 
-static int
-mlxsw_sp_vr_lpm_tree_check(struct mlxsw_sp *mlxsw_sp, struct mlxsw_sp_fib *fib,
-                          struct mlxsw_sp_prefix_usage *req_prefix_usage)
-{
-       struct mlxsw_sp_lpm_tree *lpm_tree = fib->lpm_tree;
-       struct mlxsw_sp_lpm_tree *new_tree;
-       int err;
-
-       if (mlxsw_sp_prefix_usage_eq(req_prefix_usage, &lpm_tree->prefix_usage))
-               return 0;
-
-       new_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, req_prefix_usage,
-                                        fib->proto);
-       if (IS_ERR(new_tree)) {
-               /* We failed to get a tree according to the required
-                * prefix usage. However, the current tree might be still good
-                * for us if our requirement is subset of the prefixes used
-                * in the tree.
-                */
-               if (mlxsw_sp_prefix_usage_subset(req_prefix_usage,
-                                                &lpm_tree->prefix_usage))
-                       return 0;
-               return PTR_ERR(new_tree);
-       }
-
-       /* Prevent packet loss by overwriting existing binding */
-       fib->lpm_tree = new_tree;
-       err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, new_tree->id);
-       if (err)
-               goto err_tree_bind;
-       mlxsw_sp_lpm_tree_put(mlxsw_sp, lpm_tree);
-
-       return 0;
-
-err_tree_bind:
-       fib->lpm_tree = lpm_tree;
-       mlxsw_sp_lpm_tree_put(mlxsw_sp, new_tree);
-       return err;
-}
-
 static struct mlxsw_sp_vr *mlxsw_sp_vr_get(struct mlxsw_sp *mlxsw_sp, u32 tb_id)
 {
        struct mlxsw_sp_vr *vr;
                mlxsw_sp_vr_destroy(vr);
 }
 
+static bool
+mlxsw_sp_vr_lpm_tree_should_replace(struct mlxsw_sp_vr *vr,
+                                   enum mlxsw_sp_l3proto proto, u8 tree_id)
+{
+       struct mlxsw_sp_fib *fib = mlxsw_sp_vr_fib(vr, proto);
+
+       if (!mlxsw_sp_vr_is_used(vr))
+               return false;
+       if (fib->lpm_tree && fib->lpm_tree->id == tree_id)
+               return true;
+       return false;
+}
+
+static int mlxsw_sp_vr_lpm_tree_replace(struct mlxsw_sp *mlxsw_sp,
+                                       struct mlxsw_sp_fib *fib,
+                                       struct mlxsw_sp_lpm_tree *new_tree)
+{
+       struct mlxsw_sp_lpm_tree *old_tree = fib->lpm_tree;
+       int err;
+
+       err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, new_tree->id);
+       if (err)
+               return err;
+       fib->lpm_tree = new_tree;
+       mlxsw_sp_lpm_tree_hold(new_tree);
+       mlxsw_sp_lpm_tree_put(mlxsw_sp, old_tree);
+       return 0;
+}
+
+static int mlxsw_sp_vrs_lpm_tree_replace(struct mlxsw_sp *mlxsw_sp,
+                                        struct mlxsw_sp_fib *fib,
+                                        struct mlxsw_sp_lpm_tree *new_tree)
+{
+       struct mlxsw_sp_lpm_tree *old_tree = fib->lpm_tree;
+       enum mlxsw_sp_l3proto proto = fib->proto;
+       u8 old_id, new_id = new_tree->id;
+       struct mlxsw_sp_vr *vr;
+       int i, err;
+
+       if (!old_tree)
+               goto no_replace;
+       old_id = old_tree->id;
+
+       for (i = 0; i < MLXSW_CORE_RES_GET(mlxsw_sp->core, MAX_VRS); i++) {
+               vr = &mlxsw_sp->router->vrs[i];
+               if (!mlxsw_sp_vr_lpm_tree_should_replace(vr, proto, old_id))
+                       continue;
+               err = mlxsw_sp_vr_lpm_tree_replace(mlxsw_sp,
+                                                  mlxsw_sp_vr_fib(vr, proto),
+                                                  new_tree);
+               if (err)
+                       goto err_tree_replace;
+       }
+
+       return 0;
+
+err_tree_replace:
+       for (i--; i >= 0; i--) {
+               if (!mlxsw_sp_vr_lpm_tree_should_replace(vr, proto, new_id))
+                       continue;
+               mlxsw_sp_vr_lpm_tree_replace(mlxsw_sp,
+                                            mlxsw_sp_vr_fib(vr, proto),
+                                            old_tree);
+       }
+       return err;
+
+no_replace:
+       err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, new_tree->id);
+       if (err)
+               return err;
+       fib->lpm_tree = new_tree;
+       mlxsw_sp_lpm_tree_hold(new_tree);
+       return 0;
+}
+
+static void
+mlxsw_sp_vrs_prefixes(struct mlxsw_sp *mlxsw_sp,
+                     enum mlxsw_sp_l3proto proto,
+                     struct mlxsw_sp_prefix_usage *req_prefix_usage)
+{
+       int i;
+
+       for (i = 0; i < MLXSW_CORE_RES_GET(mlxsw_sp->core, MAX_VRS); i++) {
+               struct mlxsw_sp_vr *vr = &mlxsw_sp->router->vrs[i];
+               struct mlxsw_sp_fib *fib = mlxsw_sp_vr_fib(vr, proto);
+               unsigned char prefix;
+
+               if (!mlxsw_sp_vr_is_used(vr))
+                       continue;
+               mlxsw_sp_prefix_usage_for_each(prefix, &fib->prefix_usage)
+                       mlxsw_sp_prefix_usage_set(req_prefix_usage, prefix);
+       }
+}
+
 static int mlxsw_sp_vrs_init(struct mlxsw_sp *mlxsw_sp)
 {
        struct mlxsw_sp_vr *vr;
                                struct mlxsw_sp_fib_entry, list) == fib_entry;
 }
 
+static int mlxsw_sp_fib_lpm_tree_link(struct mlxsw_sp *mlxsw_sp,
+                                     struct mlxsw_sp_fib *fib,
+                                     struct mlxsw_sp_fib_node *fib_node)
+{
+       struct mlxsw_sp_prefix_usage req_prefix_usage = {{ 0 } };
+       struct mlxsw_sp_lpm_tree *lpm_tree;
+       int err;
+
+       /* Since the tree is shared between all virtual routers we must
+        * make sure it contains all the required prefix lengths. This
+        * can be computed by either adding the new prefix length to the
+        * existing prefix usage of a bound tree, or by aggregating the
+        * prefix lengths across all virtual routers and adding the new
+        * one as well.
+        */
+       if (fib->lpm_tree)
+               mlxsw_sp_prefix_usage_cpy(&req_prefix_usage,
+                                         &fib->lpm_tree->prefix_usage);
+       else
+               mlxsw_sp_vrs_prefixes(mlxsw_sp, fib->proto, &req_prefix_usage);
+       mlxsw_sp_prefix_usage_set(&req_prefix_usage, fib_node->key.prefix_len);
+
+       lpm_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, &req_prefix_usage,
+                                        fib->proto);
+       if (IS_ERR(lpm_tree))
+               return PTR_ERR(lpm_tree);
+
+       if (fib->lpm_tree && fib->lpm_tree->id == lpm_tree->id)
+               return 0;
+
+       err = mlxsw_sp_vrs_lpm_tree_replace(mlxsw_sp, fib, lpm_tree);
+       if (err)
+               return err;
+
+       return 0;
+}
+
+static void mlxsw_sp_fib_lpm_tree_unlink(struct mlxsw_sp *mlxsw_sp,
+                                        struct mlxsw_sp_fib *fib)
+{
+       struct mlxsw_sp_prefix_usage req_prefix_usage = {{ 0 } };
+       struct mlxsw_sp_lpm_tree *lpm_tree;
+
+       /* Aggregate prefix lengths across all virtual routers to make
+        * sure we only have used prefix lengths in the LPM tree.
+        */
+       mlxsw_sp_vrs_prefixes(mlxsw_sp, fib->proto, &req_prefix_usage);
+       lpm_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, &req_prefix_usage,
+                                        fib->proto);
+       if (IS_ERR(lpm_tree))
+               goto err_tree_get;
+       mlxsw_sp_vrs_lpm_tree_replace(mlxsw_sp, fib, lpm_tree);
+
+err_tree_get:
+       if (!mlxsw_sp_prefix_usage_none(&fib->prefix_usage))
+               return;
+       mlxsw_sp_vr_lpm_tree_unbind(mlxsw_sp, fib);
+       mlxsw_sp_lpm_tree_put(mlxsw_sp, fib->lpm_tree);
+       fib->lpm_tree = NULL;
+}
+
 static void mlxsw_sp_fib_node_prefix_inc(struct mlxsw_sp_fib_node *fib_node)
 {
        unsigned char prefix_len = fib_node->key.prefix_len;
                                  struct mlxsw_sp_fib_node *fib_node,
                                  struct mlxsw_sp_fib *fib)
 {
-       struct mlxsw_sp_prefix_usage req_prefix_usage;
-       struct mlxsw_sp_lpm_tree *lpm_tree;
        int err;
 
        err = mlxsw_sp_fib_node_insert(fib, fib_node);
                return err;
        fib_node->fib = fib;
 
-       mlxsw_sp_prefix_usage_cpy(&req_prefix_usage, &fib->prefix_usage);
-       mlxsw_sp_prefix_usage_set(&req_prefix_usage, fib_node->key.prefix_len);
-
-       if (!mlxsw_sp_prefix_usage_none(&fib->prefix_usage)) {
-               err = mlxsw_sp_vr_lpm_tree_check(mlxsw_sp, fib,
-                                                &req_prefix_usage);
-               if (err)
-                       goto err_tree_check;
-       } else {
-               lpm_tree = mlxsw_sp_lpm_tree_get(mlxsw_sp, &req_prefix_usage,
-                                                fib->proto);
-               if (IS_ERR(lpm_tree))
-                       return PTR_ERR(lpm_tree);
-               fib->lpm_tree = lpm_tree;
-               err = mlxsw_sp_vr_lpm_tree_bind(mlxsw_sp, fib, lpm_tree->id);
-               if (err)
-                       goto err_tree_bind;
-       }
+       err = mlxsw_sp_fib_lpm_tree_link(mlxsw_sp, fib, fib_node);
+       if (err)
+               goto err_fib_lpm_tree_link;
 
        mlxsw_sp_fib_node_prefix_inc(fib_node);
 
        return 0;
 
-err_tree_bind:
-       fib->lpm_tree = NULL;
-       mlxsw_sp_lpm_tree_put(mlxsw_sp, lpm_tree);
-err_tree_check:
+err_fib_lpm_tree_link:
        fib_node->fib = NULL;
        mlxsw_sp_fib_node_remove(fib, fib_node);
        return err;
 static void mlxsw_sp_fib_node_fini(struct mlxsw_sp *mlxsw_sp,
                                   struct mlxsw_sp_fib_node *fib_node)
 {
-       struct mlxsw_sp_lpm_tree *lpm_tree = fib_node->fib->lpm_tree;
        struct mlxsw_sp_fib *fib = fib_node->fib;
 
        mlxsw_sp_fib_node_prefix_dec(fib_node);
-
-       if (mlxsw_sp_prefix_usage_none(&fib->prefix_usage)) {
-               mlxsw_sp_vr_lpm_tree_unbind(mlxsw_sp, fib);
-               fib->lpm_tree = NULL;
-               mlxsw_sp_lpm_tree_put(mlxsw_sp, lpm_tree);
-       } else {
-               mlxsw_sp_vr_lpm_tree_check(mlxsw_sp, fib, &fib->prefix_usage);
-       }
-
+       mlxsw_sp_fib_lpm_tree_unlink(mlxsw_sp, fib);
        fib_node->fib = NULL;
        mlxsw_sp_fib_node_remove(fib, fib_node);
 }