 #include "mlx5_core.h"
 #include "lib/fs_ttc.h"
 
-#define MLX5_TTC_NUM_GROUPS    3
-#define MLX5_TTC_GROUP1_SIZE   (BIT(3) + MLX5_NUM_TUNNEL_TT)
-#define MLX5_TTC_GROUP2_SIZE    BIT(1)
-#define MLX5_TTC_GROUP3_SIZE    BIT(0)
-#define MLX5_TTC_TABLE_SIZE    (MLX5_TTC_GROUP1_SIZE +\
-                                MLX5_TTC_GROUP2_SIZE +\
-                                MLX5_TTC_GROUP3_SIZE)
-
-#define MLX5_INNER_TTC_NUM_GROUPS      3
-#define MLX5_INNER_TTC_GROUP1_SIZE     BIT(3)
-#define MLX5_INNER_TTC_GROUP2_SIZE     BIT(1)
-#define MLX5_INNER_TTC_GROUP3_SIZE     BIT(0)
-#define MLX5_INNER_TTC_TABLE_SIZE      (MLX5_INNER_TTC_GROUP1_SIZE +\
-                                        MLX5_INNER_TTC_GROUP2_SIZE +\
-                                        MLX5_INNER_TTC_GROUP3_SIZE)
+#define MLX5_TTC_MAX_NUM_GROUPS                4
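+/* The TCP/UDP group assumes the TCP and UDP traffic types
+ * (MLX5_TT_IPV4_TCP..MLX5_TT_IPV6_UDP) are the first entries of the
+ * traffic type enum, hence the group size of MLX5_TT_IPV6_UDP + 1.
+ */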
+#define MLX5_TTC_GROUP_TCPUDP_SIZE     (MLX5_TT_IPV6_UDP + 1)
+
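+/* Per-table flow group layout: group_size[] holds the number of flow
+ * table entries reserved for each group, in creation order.
+ */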
+struct mlx5_fs_ttc_groups {
+       bool use_l4_type;
+       int num_groups;
+       int group_size[MLX5_TTC_MAX_NUM_GROUPS];
+};
+
+static int mlx5_fs_ttc_table_size(const struct mlx5_fs_ttc_groups *groups)
+{
+       int i, sz = 0;
+
+       for (i = 0; i < groups->num_groups; i++)
+               sz += groups->group_size[i];
+
+       return sz;
+}
 
 /* L3/L4 traffic type classifier */
 struct mlx5_ttc_table {
 
 };
 
+enum TTC_GROUP_TYPE {
+       TTC_GROUPS_DEFAULT = 0,
+       TTC_GROUPS_USE_L4_TYPE = 1,
+};
+
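+/* Group layouts for the TTC tables. The default layout keeps all L4 rules
+ * in a single group matched on ip_protocol. When l4_type matching is
+ * supported, TCP/UDP rules move to a dedicated leading group matched on
+ * l4_type, while the remaining L4 (and, for the outer table, tunnel) rules
+ * stay in the ip_protocol group.
+ */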
+static const struct mlx5_fs_ttc_groups ttc_groups[] = {
+       [TTC_GROUPS_DEFAULT] = {
+               .num_groups = 3,
+               .group_size = {
+                       BIT(3) + MLX5_NUM_TUNNEL_TT,
+                       BIT(1),
+                       BIT(0),
+               },
+       },
+       [TTC_GROUPS_USE_L4_TYPE] = {
+               .use_l4_type = true,
+               .num_groups = 4,
+               .group_size = {
+                       MLX5_TTC_GROUP_TCPUDP_SIZE,
+                       BIT(3) + MLX5_NUM_TUNNEL_TT - MLX5_TTC_GROUP_TCPUDP_SIZE,
+                       BIT(1),
+                       BIT(0),
+               },
+       },
+};
+
+static const struct mlx5_fs_ttc_groups inner_ttc_groups[] = {
+       [TTC_GROUPS_DEFAULT] = {
+               .num_groups = 3,
+               .group_size = {
+                       BIT(3),
+                       BIT(1),
+                       BIT(0),
+               },
+       },
+       [TTC_GROUPS_USE_L4_TYPE] = {
+               .use_l4_type = true,
+               .num_groups = 4,
+               .group_size = {
+                       MLX5_TTC_GROUP_TCPUDP_SIZE,
+                       BIT(3) - MLX5_TTC_GROUP_TCPUDP_SIZE,
+                       BIT(1),
+                       BIT(0),
+               },
+       },
+};
+
 u8 mlx5_get_proto_by_tunnel_type(enum mlx5_tunnel_types tt)
 {
        return ttc_tunnel_rules[tt].proto;
 }
 
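+/* Set the protocol match for a TTC rule: match TCP/UDP through the
+ * dedicated l4_type field when the device supports it, otherwise fall
+ * back to matching on ip_protocol.
+ */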
+static void mlx5_fs_ttc_set_match_proto(void *headers_c, void *headers_v,
+                                       u8 proto, bool use_l4_type)
+{
+       int l4_type;
+
+       if (use_l4_type && (proto == IPPROTO_TCP || proto == IPPROTO_UDP)) {
+               if (proto == IPPROTO_TCP)
+                       l4_type = MLX5_PACKET_L4_TYPE_TCP;
+               else
+                       l4_type = MLX5_PACKET_L4_TYPE_UDP;
+
+               MLX5_SET_TO_ONES(fte_match_set_lyr_2_4, headers_c, l4_type);
+               MLX5_SET(fte_match_set_lyr_2_4, headers_v, l4_type, l4_type);
+       } else {
+               MLX5_SET_TO_ONES(fte_match_set_lyr_2_4, headers_c, ip_protocol);
+               MLX5_SET(fte_match_set_lyr_2_4, headers_v, ip_protocol, proto);
+       }
+}
+
 static struct mlx5_flow_handle *
 mlx5_generate_ttc_rule(struct mlx5_core_dev *dev, struct mlx5_flow_table *ft,
-                      struct mlx5_flow_destination *dest, u16 etype, u8 proto)
+                      struct mlx5_flow_destination *dest, u16 etype, u8 proto,
+                      bool use_l4_type)
 {
        int match_ipv_outer =
                MLX5_CAP_FLOWTABLE_NIC_RX(dev,
                                          ft_field_support.outer_ip_version);
 
        if (proto) {
                spec->match_criteria_enable = MLX5_MATCH_OUTER_HEADERS;
-               MLX5_SET_TO_ONES(fte_match_param, spec->match_criteria, outer_headers.ip_protocol);
-               MLX5_SET(fte_match_param, spec->match_value, outer_headers.ip_protocol, proto);
+               mlx5_fs_ttc_set_match_proto(MLX5_ADDR_OF(fte_match_param,
+                                                        spec->match_criteria,
+                                                        outer_headers),
+                                           MLX5_ADDR_OF(fte_match_param,
+                                                        spec->match_value,
+                                                        outer_headers),
+                                           proto, use_l4_type);
        }
 
        ipv = mlx5_etype_to_ipv(etype);
 
 static int mlx5_generate_ttc_table_rules(struct mlx5_core_dev *dev,
                                         struct ttc_params *params,
-                                        struct mlx5_ttc_table *ttc)
+                                        struct mlx5_ttc_table *ttc,
+                                        bool use_l4_type)
 {
        struct mlx5_flow_handle **trules;
        struct mlx5_ttc_rule *rules;
                        continue;
                rule->rule = mlx5_generate_ttc_rule(dev, ft, &params->dests[tt],
                                                    ttc_rules[tt].etype,
-                                                   ttc_rules[tt].proto);
+                                                   ttc_rules[tt].proto,
+                                                   use_l4_type);
                if (IS_ERR(rule->rule)) {
                        err = PTR_ERR(rule->rule);
                        rule->rule = NULL;
                trules[tt] = mlx5_generate_ttc_rule(dev, ft,
                                                    &params->tunnel_dests[tt],
                                                    ttc_tunnel_rules[tt].etype,
-                                                   ttc_tunnel_rules[tt].proto);
+                                                   ttc_tunnel_rules[tt].proto,
+                                                   use_l4_type);
                if (IS_ERR(trules[tt])) {
                        err = PTR_ERR(trules[tt]);
                        trules[tt] = NULL;
 }
 
 static int mlx5_create_ttc_table_groups(struct mlx5_ttc_table *ttc,
-                                       bool use_ipv)
+                                       bool use_ipv,
+                                       const struct mlx5_fs_ttc_groups *groups)
 {
        int inlen = MLX5_ST_SZ_BYTES(create_flow_group_in);
        int ix = 0;
        int err;
        u8 *mc;
 
-       ttc->g = kcalloc(MLX5_TTC_NUM_GROUPS, sizeof(*ttc->g), GFP_KERNEL);
+       ttc->g = kcalloc(groups->num_groups, sizeof(*ttc->g), GFP_KERNEL);
        if (!ttc->g)
                return -ENOMEM;
        in = kvzalloc(inlen, GFP_KERNEL);
                return -ENOMEM;
        }
 
-       /* L4 Group */
        mc = MLX5_ADDR_OF(create_flow_group_in, in, match_criteria);
-       MLX5_SET_TO_ONES(fte_match_param, mc, outer_headers.ip_protocol);
        if (use_ipv)
                MLX5_SET_TO_ONES(fte_match_param, mc, outer_headers.ip_version);
        else
                MLX5_SET_TO_ONES(fte_match_param, mc, outer_headers.ethertype);
        MLX5_SET_CFG(in, match_criteria_enable, MLX5_MATCH_OUTER_HEADERS);
+
+       /* TCP UDP group */
+       if (groups->use_l4_type) {
+               MLX5_SET_TO_ONES(fte_match_param, mc, outer_headers.l4_type);
+               MLX5_SET_CFG(in, start_flow_index, ix);
+               ix += groups->group_size[ttc->num_groups];
+               MLX5_SET_CFG(in, end_flow_index, ix - 1);
+               ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
+               if (IS_ERR(ttc->g[ttc->num_groups]))
+                       goto err;
+               ttc->num_groups++;
+
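+               /* Clear l4_type from the shared match criteria so the
+                * following groups do not match on it.
+                */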
+               MLX5_SET(fte_match_param, mc, outer_headers.l4_type, 0);
+       }
+
+       /* L4 Group */
+       MLX5_SET_TO_ONES(fte_match_param, mc, outer_headers.ip_protocol);
        MLX5_SET_CFG(in, start_flow_index, ix);
-       ix += MLX5_TTC_GROUP1_SIZE;
+       ix += groups->group_size[ttc->num_groups];
        MLX5_SET_CFG(in, end_flow_index, ix - 1);
        ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
        if (IS_ERR(ttc->g[ttc->num_groups]))
        /* L3 Group */
        MLX5_SET(fte_match_param, mc, outer_headers.ip_protocol, 0);
        MLX5_SET_CFG(in, start_flow_index, ix);
-       ix += MLX5_TTC_GROUP2_SIZE;
+       ix += groups->group_size[ttc->num_groups];
        MLX5_SET_CFG(in, end_flow_index, ix - 1);
        ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
        if (IS_ERR(ttc->g[ttc->num_groups]))
        /* Any Group */
        memset(in, 0, inlen);
        MLX5_SET_CFG(in, start_flow_index, ix);
-       ix += MLX5_TTC_GROUP3_SIZE;
+       ix += groups->group_size[ttc->num_groups];
        MLX5_SET_CFG(in, end_flow_index, ix - 1);
        ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
        if (IS_ERR(ttc->g[ttc->num_groups]))
 mlx5_generate_inner_ttc_rule(struct mlx5_core_dev *dev,
                             struct mlx5_flow_table *ft,
                             struct mlx5_flow_destination *dest,
-                            u16 etype, u8 proto)
+                            u16 etype, u8 proto, bool use_l4_type)
 {
        MLX5_DECLARE_FLOW_ACT(flow_act);
        struct mlx5_flow_handle *rule;
 
        if (proto) {
                spec->match_criteria_enable = MLX5_MATCH_INNER_HEADERS;
-               MLX5_SET_TO_ONES(fte_match_param, spec->match_criteria, inner_headers.ip_protocol);
-               MLX5_SET(fte_match_param, spec->match_value, inner_headers.ip_protocol, proto);
+               mlx5_fs_ttc_set_match_proto(MLX5_ADDR_OF(fte_match_param,
+                                                        spec->match_criteria,
+                                                        inner_headers),
+                                           MLX5_ADDR_OF(fte_match_param,
+                                                        spec->match_value,
+                                                        inner_headers),
+                                           proto, use_l4_type);
        }
 
        rule = mlx5_add_flow_rules(ft, spec, &flow_act, dest, 1);
 
 static int mlx5_generate_inner_ttc_table_rules(struct mlx5_core_dev *dev,
                                               struct ttc_params *params,
-                                              struct mlx5_ttc_table *ttc)
+                                              struct mlx5_ttc_table *ttc,
+                                              bool use_l4_type)
 {
        struct mlx5_ttc_rule *rules;
        struct mlx5_flow_table *ft;
                rule->rule = mlx5_generate_inner_ttc_rule(dev, ft,
                                                          &params->dests[tt],
                                                          ttc_rules[tt].etype,
-                                                         ttc_rules[tt].proto);
+                                                         ttc_rules[tt].proto,
+                                                         use_l4_type);
                if (IS_ERR(rule->rule)) {
                        err = PTR_ERR(rule->rule);
                        rule->rule = NULL;
        return err;
 }
 
-static int mlx5_create_inner_ttc_table_groups(struct mlx5_ttc_table *ttc)
+static int mlx5_create_inner_ttc_table_groups(struct mlx5_ttc_table *ttc,
+                                             const struct mlx5_fs_ttc_groups *groups)
 {
        int inlen = MLX5_ST_SZ_BYTES(create_flow_group_in);
        int ix = 0;
        int err;
        u8 *mc;
 
-       ttc->g = kcalloc(MLX5_INNER_TTC_NUM_GROUPS, sizeof(*ttc->g),
-                        GFP_KERNEL);
+       ttc->g = kcalloc(groups->num_groups, sizeof(*ttc->g), GFP_KERNEL);
        if (!ttc->g)
                return -ENOMEM;
        in = kvzalloc(inlen, GFP_KERNEL);
                return -ENOMEM;
        }
 
-       /* L4 Group */
        mc = MLX5_ADDR_OF(create_flow_group_in, in, match_criteria);
-       MLX5_SET_TO_ONES(fte_match_param, mc, inner_headers.ip_protocol);
        MLX5_SET_TO_ONES(fte_match_param, mc, inner_headers.ip_version);
        MLX5_SET_CFG(in, match_criteria_enable, MLX5_MATCH_INNER_HEADERS);
+
+       /* TCP UDP group */
+       if (groups->use_l4_type) {
+               MLX5_SET_TO_ONES(fte_match_param, mc, inner_headers.l4_type);
+               MLX5_SET_CFG(in, start_flow_index, ix);
+               ix += groups->group_size[ttc->num_groups];
+               MLX5_SET_CFG(in, end_flow_index, ix - 1);
+               ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
+               if (IS_ERR(ttc->g[ttc->num_groups]))
+                       goto err;
+               ttc->num_groups++;
+
+               MLX5_SET(fte_match_param, mc, inner_headers.l4_type, 0);
+       }
+
+       /* L4 Group */
+       MLX5_SET_TO_ONES(fte_match_param, mc, inner_headers.ip_protocol);
        MLX5_SET_CFG(in, start_flow_index, ix);
-       ix += MLX5_INNER_TTC_GROUP1_SIZE;
+       ix += groups->group_size[ttc->num_groups];
        MLX5_SET_CFG(in, end_flow_index, ix - 1);
        ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
        if (IS_ERR(ttc->g[ttc->num_groups]))
        /* L3 Group */
        MLX5_SET(fte_match_param, mc, inner_headers.ip_protocol, 0);
        MLX5_SET_CFG(in, start_flow_index, ix);
-       ix += MLX5_INNER_TTC_GROUP2_SIZE;
+       ix += groups->group_size[ttc->num_groups];
        MLX5_SET_CFG(in, end_flow_index, ix - 1);
        ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
        if (IS_ERR(ttc->g[ttc->num_groups]))
        /* Any Group */
        memset(in, 0, inlen);
        MLX5_SET_CFG(in, start_flow_index, ix);
-       ix += MLX5_INNER_TTC_GROUP3_SIZE;
+       ix += groups->group_size[ttc->num_groups];
        MLX5_SET_CFG(in, end_flow_index, ix - 1);
        ttc->g[ttc->num_groups] = mlx5_create_flow_group(ttc->t, in);
        if (IS_ERR(ttc->g[ttc->num_groups]))
 struct mlx5_ttc_table *mlx5_create_inner_ttc_table(struct mlx5_core_dev *dev,
                                                   struct ttc_params *params)
 {
+       const struct mlx5_fs_ttc_groups *groups;
+       struct mlx5_flow_namespace *ns;
        struct mlx5_ttc_table *ttc;
+       bool use_l4_type;
        int err;
 
        ttc = kvzalloc(sizeof(*ttc), GFP_KERNEL);
        if (!ttc)
                return ERR_PTR(-ENOMEM);
 
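+       /* Inner l4_type matching requires both the pcc_ifa2 general
+        * capability and inner_l4_type field support for this namespace;
+        * otherwise the rules fall back to matching on ip_protocol.
+        */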
+       switch (params->ns_type) {
+       case MLX5_FLOW_NAMESPACE_PORT_SEL:
+               use_l4_type = MLX5_CAP_GEN_2(dev, pcc_ifa2) &&
+                       MLX5_CAP_PORT_SELECTION_FT_FIELD_SUPPORT_2(dev, inner_l4_type);
+               break;
+       case MLX5_FLOW_NAMESPACE_KERNEL:
+               use_l4_type = MLX5_CAP_GEN_2(dev, pcc_ifa2) &&
+                       MLX5_CAP_NIC_RX_FT_FIELD_SUPPORT_2(dev, inner_l4_type);
+               break;
+       default:
+               kvfree(ttc);
+               return ERR_PTR(-EINVAL);
+       }
+
+       ns = mlx5_get_flow_namespace(dev, params->ns_type);
+       groups = use_l4_type ? &inner_ttc_groups[TTC_GROUPS_USE_L4_TYPE] :
+                              &inner_ttc_groups[TTC_GROUPS_DEFAULT];
+
        WARN_ON_ONCE(params->ft_attr.max_fte);
-       params->ft_attr.max_fte = MLX5_INNER_TTC_TABLE_SIZE;
-       ttc->t = mlx5_create_flow_table(params->ns, &params->ft_attr);
+       params->ft_attr.max_fte = mlx5_fs_ttc_table_size(groups);
+       ttc->t = mlx5_create_flow_table(ns, &params->ft_attr);
        if (IS_ERR(ttc->t)) {
                err = PTR_ERR(ttc->t);
                kvfree(ttc);
                return ERR_PTR(err);
        }
 
-       err = mlx5_create_inner_ttc_table_groups(ttc);
+       err = mlx5_create_inner_ttc_table_groups(ttc, groups);
        if (err)
                goto destroy_ft;
 
-       err = mlx5_generate_inner_ttc_table_rules(dev, params, ttc);
+       err = mlx5_generate_inner_ttc_table_rules(dev, params, ttc, use_l4_type);
        if (err)
                goto destroy_ft;
 
        bool match_ipv_outer =
                MLX5_CAP_FLOWTABLE_NIC_RX(dev,
                                          ft_field_support.outer_ip_version);
+       const struct mlx5_fs_ttc_groups *groups;
+       struct mlx5_flow_namespace *ns;
        struct mlx5_ttc_table *ttc;
+       bool use_l4_type;
        int err;
 
        ttc = kvzalloc(sizeof(*ttc), GFP_KERNEL);
        if (!ttc)
                return ERR_PTR(-ENOMEM);
 
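+       /* Same capability check as for the inner TTC table, here for the
+        * outer l4_type field.
+        */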
+       switch (params->ns_type) {
+       case MLX5_FLOW_NAMESPACE_PORT_SEL:
+               use_l4_type = MLX5_CAP_GEN_2(dev, pcc_ifa2) &&
+                       MLX5_CAP_PORT_SELECTION_FT_FIELD_SUPPORT_2(dev, outer_l4_type);
+               break;
+       case MLX5_FLOW_NAMESPACE_KERNEL:
+               use_l4_type = MLX5_CAP_GEN_2(dev, pcc_ifa2) &&
+                       MLX5_CAP_NIC_RX_FT_FIELD_SUPPORT_2(dev, outer_l4_type);
+               break;
+       default:
+               kvfree(ttc);
+               return ERR_PTR(-EINVAL);
+       }
+
+       ns = mlx5_get_flow_namespace(dev, params->ns_type);
+       groups = use_l4_type ? &ttc_groups[TTC_GROUPS_USE_L4_TYPE] :
+                              &ttc_groups[TTC_GROUPS_DEFAULT];
+
        WARN_ON_ONCE(params->ft_attr.max_fte);
-       params->ft_attr.max_fte = MLX5_TTC_TABLE_SIZE;
-       ttc->t = mlx5_create_flow_table(params->ns, &params->ft_attr);
+       params->ft_attr.max_fte = mlx5_fs_ttc_table_size(groups);
+       ttc->t = mlx5_create_flow_table(ns, &params->ft_attr);
        if (IS_ERR(ttc->t)) {
                err = PTR_ERR(ttc->t);
                kvfree(ttc);
                return ERR_PTR(err);
        }
 
-       err = mlx5_create_ttc_table_groups(ttc, match_ipv_outer);
+       err = mlx5_create_ttc_table_groups(ttc, match_ipv_outer, groups);
        if (err)
                goto destroy_ft;
 
-       err = mlx5_generate_ttc_table_rules(dev, params, ttc);
+       err = mlx5_generate_ttc_table_rules(dev, params, ttc, use_l4_type);
        if (err)
                goto destroy_ft;