struct mlx5e_tc_update_priv *tc_priv)
 {
 #if IS_ENABLED(CONFIG_NET_TC_SKB_EXT)
-       u32 chain = 0, reg_c0, reg_c1, tunnel_id, zone_restore_id;
+       u32 reg_c0, reg_c1, tunnel_id, zone_restore_id;
        struct mlx5_rep_uplink_priv *uplink_priv;
        struct mlx5e_rep_priv *uplink_rpriv;
+       struct mlx5_mapped_obj mapped_obj;
        struct tc_skb_ext *tc_skb_ext;
        struct mlx5_eswitch *esw;
        struct mlx5e_priv *priv;
        priv = netdev_priv(skb->dev);
        esw = priv->mdev->priv.eswitch;
 
-       err = mlx5_get_chain_for_tag(esw_chains(esw), reg_c0, &chain);
+       err = mlx5_get_mapped_object(esw_chains(esw), reg_c0, &mapped_obj);
        if (err) {
                netdev_dbg(priv->netdev,
-                          "Couldn't find chain for chain tag: %d, err: %d\n",
+                          "Couldn't find mapped object for reg_c0: %d, err: %d\n",
                           reg_c0, err);
                return false;
        }
 
-       if (chain) {
-               tc_skb_ext = skb_ext_add(skb, TC_SKB_EXT);
-               if (!tc_skb_ext) {
-                       WARN_ON(1);
-                       return false;
-               }
+       if (mapped_obj.type == MLX5_MAPPED_OBJ_CHAIN) {
+               if (mapped_obj.chain) {
+                       tc_skb_ext = skb_ext_add(skb, TC_SKB_EXT);
+                       if (!tc_skb_ext) {
+                               WARN_ON(1);
+                               return false;
+                       }
 
-               tc_skb_ext->chain = chain;
+                       tc_skb_ext->chain = mapped_obj.chain;
 
-               zone_restore_id = reg_c1 & ESW_ZONE_ID_MASK;
+                       zone_restore_id = reg_c1 & ESW_ZONE_ID_MASK;
 
-               uplink_rpriv = mlx5_eswitch_get_uplink_priv(esw, REP_ETH);
-               uplink_priv = &uplink_rpriv->uplink_priv;
-               if (!mlx5e_tc_ct_restore_flow(uplink_priv->ct_priv, skb,
-                                             zone_restore_id))
-                       return false;
+                       uplink_rpriv = mlx5_eswitch_get_uplink_priv(esw, REP_ETH);
+                       uplink_priv = &uplink_rpriv->uplink_priv;
+                       if (!mlx5e_tc_ct_restore_flow(uplink_priv->ct_priv, skb,
+                                                     zone_restore_id))
+                               return false;
+               }
+       } else {
+               netdev_dbg(priv->netdev, "Invalid mapped object type: %d\n", mapped_obj.type);
+               return false;
        }
 
        tunnel_id = reg_c1 >> ESW_TUN_OFFSET;
 
        u32 chain = 0, chain_tag, reg_b, zone_restore_id;
        struct mlx5e_priv *priv = netdev_priv(skb->dev);
        struct mlx5e_tc_table *tc = &priv->fs.tc;
+       struct mlx5_mapped_obj mapped_obj;
        struct tc_skb_ext *tc_skb_ext;
        int err;
 
 
        chain_tag = reg_b & MLX5E_TC_TABLE_CHAIN_TAG_MASK;
 
-       err = mlx5_get_chain_for_tag(nic_chains(priv), chain_tag, &chain);
+       err = mlx5_get_mapped_object(nic_chains(priv), chain_tag, &mapped_obj);
        if (err) {
                netdev_dbg(priv->netdev,
                           "Couldn't find chain for chain tag: %d, err: %d\n",
                return false;
        }
 
-       if (chain) {
+       if (mapped_obj.type == MLX5_MAPPED_OBJ_CHAIN) {
+               chain = mapped_obj.chain;
                tc_skb_ext = skb_ext_add(skb, TC_SKB_EXT);
                if (WARN_ON(!tc_skb_ext))
                        return false;
                if (!mlx5e_tc_ct_restore_flow(tc->ct, skb,
                                              zone_restore_id))
                        return false;
+       } else {
+               netdev_dbg(priv->netdev, "Invalid mapped object type: %d\n", mapped_obj.type);
+               return false;
        }
 #endif /* CONFIG_NET_TC_SKB_EXT */
 
 
 #include "sf/sf.h"
 #include "en/tc_ct.h"
 
+enum mlx5_mapped_obj_type {
+       MLX5_MAPPED_OBJ_CHAIN,
+};
+
+struct mlx5_mapped_obj {
+       enum mlx5_mapped_obj_type type;
+       union {
+               u32 chain;
+       };
+};
+
 #ifdef CONFIG_MLX5_ESWITCH
 
 #define ESW_OFFLOADS_DEFAULT_NUM_GROUPS 15
 
 struct mlx5_flow_handle *
 esw_add_restore_rule(struct mlx5_eswitch *esw, u32 tag);
-u32
-esw_get_max_restore_tag(struct mlx5_eswitch *esw);
 
 int esw_offloads_load_rep(struct mlx5_eswitch *esw, u16 vport_num);
 void esw_offloads_unload_rep(struct mlx5_eswitch *esw, u16 vport_num);
 
        misc = MLX5_ADDR_OF(fte_match_param, spec->match_criteria,
                            misc_parameters_2);
        MLX5_SET(fte_match_set_misc2, misc, metadata_reg_c_0,
-                ESW_CHAIN_TAG_METADATA_MASK);
+                ESW_REG_C0_USER_DATA_METADATA_MASK);
        misc = MLX5_ADDR_OF(fte_match_param, spec->match_value,
                            misc_parameters_2);
        MLX5_SET(fte_match_set_misc2, misc, metadata_reg_c_0, tag);
        return flow_rule;
 }
 
-u32
-esw_get_max_restore_tag(struct mlx5_eswitch *esw)
-{
-       return ESW_CHAIN_TAG_METADATA_MASK;
-}
-
 #define MAX_PF_SQ 256
 #define MAX_SQ_NVPORTS 32
 
        attr.max_ft_sz = fdb_max;
        attr.max_grp_num = esw->params.large_group_num;
        attr.default_ft = miss_fdb;
-       attr.max_restore_tag = esw_get_max_restore_tag(esw);
+       attr.max_restore_tag = ESW_REG_C0_USER_DATA_METADATA_MASK;
 
        chains = mlx5_chains_create(dev, &attr);
        if (IS_ERR(chains)) {
                goto out_free;
        }
 
-       ft_attr.max_fte = 1 << ESW_CHAIN_TAG_METADATA_BITS;
+       ft_attr.max_fte = 1 << ESW_REG_C0_USER_DATA_METADATA_BITS;
        ft = mlx5_create_flow_table(ns, &ft_attr);
        if (IS_ERR(ft)) {
                err = PTR_ERR(ft);
                            misc_parameters_2);
 
        MLX5_SET(fte_match_set_misc2, misc, metadata_reg_c_0,
-                ESW_CHAIN_TAG_METADATA_MASK);
+                ESW_REG_C0_USER_DATA_METADATA_MASK);
        MLX5_SET(create_flow_group_in, flow_group_in, start_flow_index, 0);
        MLX5_SET(create_flow_group_in, flow_group_in, end_flow_index,
                 ft_attr.max_fte - 1);
 
        if (err)
                goto init_prios_ht_err;
 
-       mapping = mapping_create(sizeof(u32), attr->max_restore_tag,
-                                true);
+       mapping = mapping_create(sizeof(struct mlx5_mapped_obj), attr->max_restore_tag, true);
        if (IS_ERR(mapping)) {
                err = PTR_ERR(mapping);
                goto mapping_err;
 mlx5_chains_get_chain_mapping(struct mlx5_fs_chains *chains, u32 chain,
                              u32 *chain_mapping)
 {
-       return mapping_add(chains_mapping(chains), &chain, chain_mapping);
+       struct mapping_ctx *ctx = chains->chains_mapping;
+       struct mlx5_mapped_obj mapped_obj = {};
+
+       mapped_obj.type = MLX5_MAPPED_OBJ_CHAIN;
+       mapped_obj.chain = chain;
+       return mapping_add(ctx, &mapped_obj, chain_mapping);
 }
 
 int
 mlx5_chains_put_chain_mapping(struct mlx5_fs_chains *chains, u32 chain_mapping)
 {
-       return mapping_remove(chains_mapping(chains), chain_mapping);
+       struct mapping_ctx *ctx = chains->chains_mapping;
+
+       return mapping_remove(ctx, chain_mapping);
 }
 
-int mlx5_get_chain_for_tag(struct mlx5_fs_chains *chains, u32 tag,
-                          u32 *chain)
+int
+mlx5_get_mapped_object(struct mlx5_fs_chains *chains, u32 tag, struct mlx5_mapped_obj *obj)
 {
        int err;
 
-       err = mapping_find(chains_mapping(chains), tag, chain);
+       err = mapping_find(chains->chains_mapping, tag, obj);
        if (err) {
                mlx5_core_warn(chains->dev, "Can't find chain for tag: %d\n", tag);
                return -ENOENT;
 
 #include <linux/mlx5/fs.h>
 
 struct mlx5_fs_chains;
+struct mlx5_mapped_obj;
 
 enum mlx5_chains_flags {
        MLX5_CHAINS_AND_PRIOS_SUPPORTED = BIT(0),
 void mlx5_chains_destroy(struct mlx5_fs_chains *chains);
 
 int
-mlx5_get_chain_for_tag(struct mlx5_fs_chains *chains, u32 tag, u32 *chain);
+mlx5_get_mapped_object(struct mlx5_fs_chains *chains, u32 tag, struct mlx5_mapped_obj *obj);
 
 void
 mlx5_chains_set_end_ft(struct mlx5_fs_chains *chains,
 
 bool mlx5_eswitch_vport_match_metadata_enabled(const struct mlx5_eswitch *esw);
 
 /* Reg C0 usage:
- * Reg C0 = < ESW_PFNUM_BITS(4) | ESW_VPORT BITS(12) | ESW_CHAIN_TAG(16) >
+ * Reg C0 = < ESW_PFNUM_BITS(4) | ESW_VPORT BITS(12) | ESW_REG_C0_OBJ(16) >
  *
  * Highest 4 bits of the reg c0 is the PF_NUM (range 0-15), 12 bits of
  * unique non-zero vport id (range 1-4095). The rest (lowest 16 bits) is left
- * for tc chain tag restoration.
+ * for user data objects managed by a common mapping context.
  * PFNUM + VPORT comprise the SOURCE_PORT matching.
  */
 #define ESW_VPORT_BITS 12
 #define ESW_PFNUM_BITS 4
 #define ESW_SOURCE_PORT_METADATA_BITS (ESW_PFNUM_BITS + ESW_VPORT_BITS)
 #define ESW_SOURCE_PORT_METADATA_OFFSET (32 - ESW_SOURCE_PORT_METADATA_BITS)
-#define ESW_CHAIN_TAG_METADATA_BITS (32 - ESW_SOURCE_PORT_METADATA_BITS)
-#define ESW_CHAIN_TAG_METADATA_MASK GENMASK(ESW_CHAIN_TAG_METADATA_BITS - 1,\
-                                           0)
+#define ESW_REG_C0_USER_DATA_METADATA_BITS (32 - ESW_SOURCE_PORT_METADATA_BITS)
+#define ESW_REG_C0_USER_DATA_METADATA_MASK GENMASK(ESW_REG_C0_USER_DATA_METADATA_BITS - 1, 0)
 
 static inline u32 mlx5_eswitch_get_vport_metadata_mask(void)
 {