 // SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
 /* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. */
 
+#include "lib/devcom.h"
 #include "bridge.h"
 #include "eswitch.h"
 #include "bridge_priv.h"
 
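+/* Per-port multicast replication offload: each bridge port owns a small flow
+ * table with two groups. The filter group drops packets whose source vport
+ * metadata (reg_c_0) matches the port itself, so multicast traffic is not
+ * replicated back to its ingress port; the forward group sends the remaining
+ * packets to the port's vport.
+ */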
+static int mlx5_esw_bridge_port_mcast_fts_init(struct mlx5_esw_bridge_port *port,
+                                              struct mlx5_esw_bridge *bridge)
+{
+       struct mlx5_eswitch *esw = bridge->br_offloads->esw;
+       struct mlx5_flow_table *mcast_ft;
+
+       mcast_ft = mlx5_esw_bridge_table_create(MLX5_ESW_BRIDGE_MCAST_TABLE_SIZE,
+                                               MLX5_ESW_BRIDGE_LEVEL_MCAST_TABLE,
+                                               esw);
+       if (IS_ERR(mcast_ft))
+               return PTR_ERR(mcast_ft);
+
+       port->mcast.ft = mcast_ft;
+       return 0;
+}
+
+static void mlx5_esw_bridge_port_mcast_fts_cleanup(struct mlx5_esw_bridge_port *port)
+{
+       if (port->mcast.ft)
+               mlx5_destroy_flow_table(port->mcast.ft);
+       port->mcast.ft = NULL;
+}
+
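+/* Flow group for the per-port filter rule; matches on the source vport
+ * metadata carried in misc_parameters_2.metadata_reg_c_0.
+ */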
+static struct mlx5_flow_group *
+mlx5_esw_bridge_mcast_filter_fg_create(struct mlx5_eswitch *esw,
+                                      struct mlx5_flow_table *mcast_ft)
+{
+       int inlen = MLX5_ST_SZ_BYTES(create_flow_group_in);
+       struct mlx5_flow_group *fg;
+       u32 *in, *match;
+
+       in = kvzalloc(inlen, GFP_KERNEL);
+       if (!in)
+               return ERR_PTR(-ENOMEM);
+
+       MLX5_SET(create_flow_group_in, in, match_criteria_enable, MLX5_MATCH_MISC_PARAMETERS_2);
+       match = MLX5_ADDR_OF(create_flow_group_in, in, match_criteria);
+
+       MLX5_SET(fte_match_param, match, misc_parameters_2.metadata_reg_c_0,
+                mlx5_eswitch_get_vport_metadata_mask());
+
+       MLX5_SET(create_flow_group_in, in, start_flow_index,
+                MLX5_ESW_BRIDGE_MCAST_TABLE_FILTER_GRP_IDX_FROM);
+       MLX5_SET(create_flow_group_in, in, end_flow_index,
+                MLX5_ESW_BRIDGE_MCAST_TABLE_FILTER_GRP_IDX_TO);
+
+       fg = mlx5_create_flow_group(mcast_ft, in);
+       kvfree(in);
+       if (IS_ERR(fg))
+               esw_warn(esw->dev,
+                        "Failed to create filter flow group for bridge mcast table (err=%pe)\n",
+                        fg);
+
+       return fg;
+}
+
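+/* Flow group for the per-port forward rule; no match criteria is set, so it
+ * acts as a catch-all for packets not dropped by the filter group.
+ */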
+static struct mlx5_flow_group *
+mlx5_esw_bridge_mcast_fwd_fg_create(struct mlx5_eswitch *esw,
+                                   struct mlx5_flow_table *mcast_ft)
+{
+       int inlen = MLX5_ST_SZ_BYTES(create_flow_group_in);
+       struct mlx5_flow_group *fg;
+       u32 *in;
+
+       in = kvzalloc(inlen, GFP_KERNEL);
+       if (!in)
+               return ERR_PTR(-ENOMEM);
+
+       MLX5_SET(create_flow_group_in, in, start_flow_index,
+                MLX5_ESW_BRIDGE_MCAST_TABLE_FWD_GRP_IDX_FROM);
+       MLX5_SET(create_flow_group_in, in, end_flow_index,
+                MLX5_ESW_BRIDGE_MCAST_TABLE_FWD_GRP_IDX_TO);
+
+       fg = mlx5_create_flow_group(mcast_ft, in);
+       kvfree(in);
+       if (IS_ERR(fg))
+               esw_warn(esw->dev,
+                        "Failed to create forward flow group for bridge mcast table (err=%pe)\n",
+                        fg);
+
+       return fg;
+}
+
+static int mlx5_esw_bridge_port_mcast_fgs_init(struct mlx5_esw_bridge_port *port)
+{
+       struct mlx5_eswitch *esw = port->bridge->br_offloads->esw;
+       struct mlx5_flow_table *mcast_ft = port->mcast.ft;
+       struct mlx5_flow_group *fwd_fg, *filter_fg;
+       int err;
+
+       filter_fg = mlx5_esw_bridge_mcast_filter_fg_create(esw, mcast_ft);
+       if (IS_ERR(filter_fg))
+               return PTR_ERR(filter_fg);
+
+       fwd_fg = mlx5_esw_bridge_mcast_fwd_fg_create(esw, mcast_ft);
+       if (IS_ERR(fwd_fg)) {
+               err = PTR_ERR(fwd_fg);
+               goto err_fwd_fg;
+       }
+
+       port->mcast.filter_fg = filter_fg;
+       port->mcast.fwd_fg = fwd_fg;
+
+       return 0;
+
+err_fwd_fg:
+       mlx5_destroy_flow_group(filter_fg);
+       return err;
+}
+
+static void mlx5_esw_bridge_port_mcast_fgs_cleanup(struct mlx5_esw_bridge_port *port)
+{
+       if (port->mcast.fwd_fg)
+               mlx5_destroy_flow_group(port->mcast.fwd_fg);
+       port->mcast.fwd_fg = NULL;
+       if (port->mcast.filter_fg)
+               mlx5_destroy_flow_group(port->mcast.filter_fg);
+       port->mcast.filter_fg = NULL;
+}
+
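+/* Drop rule matching packets whose source vport metadata belongs to @port on
+ * the given eswitch.
+ */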
+static struct mlx5_flow_handle *
+mlx5_esw_bridge_mcast_flow_with_esw_create(struct mlx5_esw_bridge_port *port,
+                                          struct mlx5_eswitch *esw)
+{
+       struct mlx5_flow_act flow_act = {
+               .action = MLX5_FLOW_CONTEXT_ACTION_DROP,
+               .flags = FLOW_ACT_NO_APPEND,
+       };
+       struct mlx5_flow_spec *rule_spec;
+       struct mlx5_flow_handle *handle;
+
+       rule_spec = kvzalloc(sizeof(*rule_spec), GFP_KERNEL);
+       if (!rule_spec)
+               return ERR_PTR(-ENOMEM);
+
+       rule_spec->match_criteria_enable = MLX5_MATCH_MISC_PARAMETERS_2;
+
+       MLX5_SET(fte_match_param, rule_spec->match_criteria,
+                misc_parameters_2.metadata_reg_c_0, mlx5_eswitch_get_vport_metadata_mask());
+       MLX5_SET(fte_match_param, rule_spec->match_value, misc_parameters_2.metadata_reg_c_0,
+                mlx5_eswitch_get_vport_metadata_for_match(esw, port->vport_num));
+
+       handle = mlx5_add_flow_rules(port->mcast.ft, rule_spec, &flow_act, NULL, 0);
+
+       kvfree(rule_spec);
+       return handle;
+}
+
+static struct mlx5_flow_handle *
+mlx5_esw_bridge_mcast_filter_flow_create(struct mlx5_esw_bridge_port *port)
+{
+       return mlx5_esw_bridge_mcast_flow_with_esw_create(port, port->bridge->br_offloads->esw);
+}
+
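+/* For peer (MLX5_ESW_BRIDGE_PORT_FLAG_PEER) ports the source vport metadata
+ * is taken from the peer eswitch, which is obtained through devcom.
+ */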
+static struct mlx5_flow_handle *
+mlx5_esw_bridge_mcast_filter_flow_peer_create(struct mlx5_esw_bridge_port *port)
+{
+       struct mlx5_devcom *devcom = port->bridge->br_offloads->esw->dev->priv.devcom;
+       struct mlx5_flow_handle *handle;
+       struct mlx5_eswitch *peer_esw;
+
+       peer_esw = mlx5_devcom_get_peer_data(devcom, MLX5_DEVCOM_ESW_OFFLOADS);
+       if (!peer_esw)
+               return ERR_PTR(-ENODEV);
+
+       handle = mlx5_esw_bridge_mcast_flow_with_esw_create(port, peer_esw);
+
+       mlx5_devcom_release_peer_data(devcom, MLX5_DEVCOM_ESW_OFFLOADS);
+       return handle;
+}
+
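+/* Catch-all rule forwarding to the port's vport; on merged-eswitch devices the
+ * destination also carries the owning eswitch's vhca_id.
+ */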
+static struct mlx5_flow_handle *
+mlx5_esw_bridge_mcast_fwd_flow_create(struct mlx5_esw_bridge_port *port)
+{
+       struct mlx5_flow_act flow_act = {
+               .action = MLX5_FLOW_CONTEXT_ACTION_FWD_DEST,
+               .flags = FLOW_ACT_NO_APPEND,
+       };
+       struct mlx5_flow_destination dest = {
+               .type = MLX5_FLOW_DESTINATION_TYPE_VPORT,
+               .vport.num = port->vport_num,
+       };
+       struct mlx5_esw_bridge *bridge = port->bridge;
+       struct mlx5_flow_spec *rule_spec;
+       struct mlx5_flow_handle *handle;
+
+       rule_spec = kvzalloc(sizeof(*rule_spec), GFP_KERNEL);
+       if (!rule_spec)
+               return ERR_PTR(-ENOMEM);
+
+       if (MLX5_CAP_ESW_FLOWTABLE(bridge->br_offloads->esw->dev, flow_source) &&
+           port->vport_num == MLX5_VPORT_UPLINK)
+               rule_spec->flow_context.flow_source =
+                       MLX5_FLOW_CONTEXT_FLOW_SOURCE_LOCAL_VPORT;
+
+       if (MLX5_CAP_ESW(bridge->br_offloads->esw->dev, merged_eswitch)) {
+               dest.vport.flags = MLX5_FLOW_DEST_VPORT_VHCA_ID;
+               dest.vport.vhca_id = port->esw_owner_vhca_id;
+       }
+       handle = mlx5_add_flow_rules(port->mcast.ft, rule_spec, &flow_act, &dest, 1);
+
+       kvfree(rule_spec);
+       return handle;
+}
+
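+/* Install the per-port rules: the source-port filter (peer-aware) followed by
+ * the forward rule.
+ */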
+static int mlx5_esw_bridge_port_mcast_fhs_init(struct mlx5_esw_bridge_port *port)
+{
+       struct mlx5_flow_handle *filter_handle, *fwd_handle;
+
+       filter_handle = (port->flags & MLX5_ESW_BRIDGE_PORT_FLAG_PEER) ?
+               mlx5_esw_bridge_mcast_filter_flow_peer_create(port) :
+               mlx5_esw_bridge_mcast_filter_flow_create(port);
+       if (IS_ERR(filter_handle))
+               return PTR_ERR(filter_handle);
+
+       fwd_handle = mlx5_esw_bridge_mcast_fwd_flow_create(port);
+       if (IS_ERR(fwd_handle)) {
+               mlx5_del_flow_rules(filter_handle);
+               return PTR_ERR(fwd_handle);
+       }
+
+       port->mcast.filter_handle = filter_handle;
+       port->mcast.fwd_handle = fwd_handle;
+
+       return 0;
+}
+
+static void mlx5_esw_bridge_port_mcast_fhs_cleanup(struct mlx5_esw_bridge_port *port)
+{
+       if (port->mcast.fwd_handle)
+               mlx5_del_flow_rules(port->mcast.fwd_handle);
+       port->mcast.fwd_handle = NULL;
+       if (port->mcast.filter_handle)
+               mlx5_del_flow_rules(port->mcast.filter_handle);
+       port->mcast.filter_handle = NULL;
+}
+
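+/* Create the multicast offload objects (flow table, groups and rules) for a
+ * single port. Does nothing unless multicast offload is enabled on the bridge.
+ */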
+int mlx5_esw_bridge_port_mcast_init(struct mlx5_esw_bridge_port *port)
+{
+       struct mlx5_esw_bridge *bridge = port->bridge;
+       int err;
+
+       if (!(bridge->flags & MLX5_ESW_BRIDGE_MCAST_FLAG))
+               return 0;
+
+       err = mlx5_esw_bridge_port_mcast_fts_init(port, bridge);
+       if (err)
+               return err;
+
+       err = mlx5_esw_bridge_port_mcast_fgs_init(port);
+       if (err)
+               goto err_fgs;
+
+       err = mlx5_esw_bridge_port_mcast_fhs_init(port);
+       if (err)
+               goto err_fhs;
+       return err;
+
+err_fhs:
+       mlx5_esw_bridge_port_mcast_fgs_cleanup(port);
+err_fgs:
+       mlx5_esw_bridge_port_mcast_fts_cleanup(port);
+       return err;
+}
+
+void mlx5_esw_bridge_port_mcast_cleanup(struct mlx5_esw_bridge_port *port)
+{
+       mlx5_esw_bridge_port_mcast_fhs_cleanup(port);
+       mlx5_esw_bridge_port_mcast_fgs_cleanup(port);
+       mlx5_esw_bridge_port_mcast_fts_cleanup(port);
+}
+
 static struct mlx5_flow_group *
 mlx5_esw_bridge_ingress_igmp_fg_create(struct mlx5_eswitch *esw,
                                       struct mlx5_flow_table *ingress_ft)
        br_offloads->igmp_handle = NULL;
 }
 
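+/* Enable multicast offload on every port attached to this bridge, rolling
+ * back already-initialized ports on failure.
+ */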
+static int mlx5_esw_bridge_mcast_init(struct mlx5_esw_bridge *bridge)
+{
+       struct mlx5_esw_bridge_offloads *br_offloads = bridge->br_offloads;
+       struct mlx5_esw_bridge_port *port, *failed;
+       unsigned long i;
+       int err;
+
+       xa_for_each(&br_offloads->ports, i, port) {
+               if (port->bridge != bridge)
+                       continue;
+
+               err = mlx5_esw_bridge_port_mcast_init(port);
+               if (err) {
+                       failed = port;
+                       goto err_port;
+               }
+       }
+       return 0;
+
+err_port:
+       xa_for_each(&br_offloads->ports, i, port) {
+               if (port == failed)
+                       break;
+               if (port->bridge != bridge)
+                       continue;
+
+               mlx5_esw_bridge_port_mcast_cleanup(port);
+       }
+       return err;
+}
+
+static void mlx5_esw_bridge_mcast_cleanup(struct mlx5_esw_bridge *bridge)
+{
+       struct mlx5_esw_bridge_offloads *br_offloads = bridge->br_offloads;
+       struct mlx5_esw_bridge_port *port;
+       unsigned long i;
+
+       xa_for_each(&br_offloads->ports, i, port) {
+               if (port->bridge != bridge)
+                       continue;
+
+               mlx5_esw_bridge_port_mcast_cleanup(port);
+       }
+}
+
 static int mlx5_esw_brige_mcast_global_enable(struct mlx5_esw_bridge_offloads *br_offloads)
 {
        int err;
                return err;
 
        bridge->flags |= MLX5_ESW_BRIDGE_MCAST_FLAG;
-       return 0;
+
+       err = mlx5_esw_bridge_mcast_init(bridge);
+       if (err) {
+               esw_warn(bridge->br_offloads->esw->dev, "Failed to enable multicast (err=%d)\n",
+                        err);
+               bridge->flags &= ~MLX5_ESW_BRIDGE_MCAST_FLAG;
+               mlx5_esw_brige_mcast_global_disable(bridge->br_offloads);
+       }
+       return err;
 }
 
 void mlx5_esw_bridge_mcast_disable(struct mlx5_esw_bridge *bridge)
 {
+       mlx5_esw_bridge_mcast_cleanup(bridge);
        bridge->flags &= ~MLX5_ESW_BRIDGE_MCAST_FLAG;
        mlx5_esw_brige_mcast_global_disable(bridge->br_offloads);
 }