// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
 /* Copyright (c) 2021 Mellanox Technologies. */
 
+#include <linux/skbuff.h>
+#include <net/psample.h>
 #include "esw/sample.h"
 #include "eswitch.h"
 #include "en_tc.h"
        struct mlx5_flow_handle *termtbl_rule;
        DECLARE_HASHTABLE(hashtbl, 8);
        struct mutex ht_lock; /* protect hashtbl */
+       DECLARE_HASHTABLE(restore_hashtbl, 8);
+       struct mutex restore_lock; /* protect restore_hashtbl */
 };
 
 struct mlx5_sampler {
 
 struct mlx5_sample_flow {
        struct mlx5_sampler *sampler;
+       struct mlx5_sample_restore *restore;
+};
+
+struct mlx5_sample_restore {
+       struct hlist_node hlist;
+       struct mlx5_modify_hdr *modify_hdr;
+       struct mlx5_flow_handle *rule;
+       u32 obj_id;
+       int count;
 };
 
 static int
        mutex_unlock(&esw_psample->ht_lock);
 }
 
+static struct mlx5_modify_hdr *
+sample_metadata_rule_get(struct mlx5_core_dev *mdev, u32 obj_id)
+{
+       struct mlx5e_tc_mod_hdr_acts mod_acts = {};
+       struct mlx5_modify_hdr *modify_hdr;
+       int err;
+
+       err = mlx5e_tc_match_to_reg_set(mdev, &mod_acts, MLX5_FLOW_NAMESPACE_FDB,
+                                       CHAIN_TO_REG, obj_id);
+       if (err)
+               goto err_set_regc0;
+
+       modify_hdr = mlx5_modify_header_alloc(mdev, MLX5_FLOW_NAMESPACE_FDB,
+                                             mod_acts.num_actions,
+                                             mod_acts.actions);
+       if (IS_ERR(modify_hdr)) {
+               err = PTR_ERR(modify_hdr);
+               goto err_modify_hdr;
+       }
+
+       dealloc_mod_hdr_actions(&mod_acts);
+       return modify_hdr;
+
+err_modify_hdr:
+       dealloc_mod_hdr_actions(&mod_acts);
+err_set_regc0:
+       return ERR_PTR(err);
+}
+
+static struct mlx5_sample_restore *
+sample_restore_get(struct mlx5_esw_psample *esw_psample, u32 obj_id)
+{
+       struct mlx5_core_dev *mdev = esw_psample->priv->mdev;
+       struct mlx5_eswitch *esw = mdev->priv.eswitch;
+       struct mlx5_sample_restore *restore;
+       struct mlx5_modify_hdr *modify_hdr;
+       int err;
+
+       mutex_lock(&esw_psample->restore_lock);
+       hash_for_each_possible(esw_psample->restore_hashtbl, restore, hlist, obj_id)
+               if (restore->obj_id == obj_id)
+                       goto add_ref;
+
+       restore = kzalloc(sizeof(*restore), GFP_KERNEL);
+       if (!restore) {
+               err = -ENOMEM;
+               goto err_alloc;
+       }
+       restore->obj_id = obj_id;
+
+       modify_hdr = sample_metadata_rule_get(mdev, obj_id);
+       if (IS_ERR(modify_hdr)) {
+               err = PTR_ERR(modify_hdr);
+               goto err_modify_hdr;
+       }
+       restore->modify_hdr = modify_hdr;
+
+       restore->rule = esw_add_restore_rule(esw, obj_id);
+       if (IS_ERR(restore->rule)) {
+               err = PTR_ERR(restore->rule);
+               goto err_restore;
+       }
+
+       hash_add(esw_psample->restore_hashtbl, &restore->hlist, obj_id);
+add_ref:
+       restore->count++;
+       mutex_unlock(&esw_psample->restore_lock);
+       return restore;
+
+err_restore:
+       mlx5_modify_header_dealloc(mdev, restore->modify_hdr);
+err_modify_hdr:
+       kfree(restore);
+err_alloc:
+       mutex_unlock(&esw_psample->restore_lock);
+       return ERR_PTR(err);
+}
+
+static void
+sample_restore_put(struct mlx5_esw_psample *esw_psample, struct mlx5_sample_restore *restore)
+{
+       mutex_lock(&esw_psample->restore_lock);
+       if (--restore->count == 0)
+               hash_del(&restore->hlist);
+       mutex_unlock(&esw_psample->restore_lock);
+
+       if (!restore->count) {
+               mlx5_del_flow_rules(restore->rule);
+               mlx5_modify_header_dealloc(esw_psample->priv->mdev, restore->modify_hdr);
+               kfree(restore);
+       }
+}
+
 struct mlx5_esw_psample *
 mlx5_esw_sample_init(struct mlx5e_priv *priv)
 {
                goto err_termtbl;
 
        mutex_init(&esw_psample->ht_lock);
+       mutex_init(&esw_psample->restore_lock);
 
        return esw_psample;
 
        if (IS_ERR_OR_NULL(esw_psample))
                return;
 
+       mutex_destroy(&esw_psample->restore_lock);
        mutex_destroy(&esw_psample->ht_lock);
        sampler_termtbl_destroy(esw_psample);
        kfree(esw_psample);