#include <linux/kthread.h>
 #include <linux/wait.h>
 #include <linux/kernel.h>
+#include <linux/sched/signal.h>
 
 #include <asm/unaligned.h>             /* Used for ntoh_seq and hton_seq */
 
 /*
  *      Specifiy default interface for outgoing multicasts
  */
-static int set_mcast_if(struct sock *sk, char *ifname)
+static int set_mcast_if(struct sock *sk, struct net_device *dev)
 {
-       struct net_device *dev;
        struct inet_sock *inet = inet_sk(sk);
-       struct net *net = sock_net(sk);
-
-       dev = __dev_get_by_name(net, ifname);
-       if (!dev)
-               return -ENODEV;
 
        if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if)
                return -EINVAL;
  *      in the in_addr structure passed in as a parameter.
  */
 static int
-join_mcast_group(struct sock *sk, struct in_addr *addr, char *ifname)
+join_mcast_group(struct sock *sk, struct in_addr *addr, struct net_device *dev)
 {
-       struct net *net = sock_net(sk);
        struct ip_mreqn mreq;
-       struct net_device *dev;
        int ret;
 
        memset(&mreq, 0, sizeof(mreq));
        memcpy(&mreq.imr_multiaddr, addr, sizeof(struct in_addr));
 
-       dev = __dev_get_by_name(net, ifname);
-       if (!dev)
-               return -ENODEV;
        if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if)
                return -EINVAL;
 
 
 #ifdef CONFIG_IP_VS_IPV6
 static int join_mcast_group6(struct sock *sk, struct in6_addr *addr,
-                            char *ifname)
+                            struct net_device *dev)
 {
-       struct net *net = sock_net(sk);
-       struct net_device *dev;
        int ret;
 
-       dev = __dev_get_by_name(net, ifname);
-       if (!dev)
-               return -ENODEV;
        if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if)
                return -EINVAL;
 
 }
 #endif
 
-static int bind_mcastif_addr(struct socket *sock, char *ifname)
+static int bind_mcastif_addr(struct socket *sock, struct net_device *dev)
 {
-       struct net *net = sock_net(sock->sk);
-       struct net_device *dev;
        __be32 addr;
        struct sockaddr_in sin;
 
-       dev = __dev_get_by_name(net, ifname);
-       if (!dev)
-               return -ENODEV;
-
        addr = inet_select_addr(dev, 0, RT_SCOPE_UNIVERSE);
        if (!addr)
                pr_err("You probably need to specify IP address on "
                       "multicast interface.\n");
 
        IP_VS_DBG(7, "binding socket with (%s) %pI4\n",
-                 ifname, &addr);
+                 dev->name, &addr);
 
        /* Now bind the socket with the address of multicast interface */
        sin.sin_family       = AF_INET;
 /*
  *      Set up sending multicast socket over UDP
  */
-static struct socket *make_send_sock(struct netns_ipvs *ipvs, int id)
+static int make_send_sock(struct netns_ipvs *ipvs, int id,
+                         struct net_device *dev, struct socket **sock_ret)
 {
        /* multicast addr */
        union ipvs_sockaddr mcast_addr;
                                  IPPROTO_UDP, &sock);
        if (result < 0) {
                pr_err("Error during creation of socket; terminating\n");
-               return ERR_PTR(result);
+               goto error;
        }
-       result = set_mcast_if(sock->sk, ipvs->mcfg.mcast_ifn);
+       *sock_ret = sock;
+       result = set_mcast_if(sock->sk, dev);
        if (result < 0) {
                pr_err("Error setting outbound mcast interface\n");
                goto error;
                set_sock_size(sock->sk, 1, result);
 
        if (AF_INET == ipvs->mcfg.mcast_af)
-               result = bind_mcastif_addr(sock, ipvs->mcfg.mcast_ifn);
+               result = bind_mcastif_addr(sock, dev);
        else
                result = 0;
        if (result < 0) {
                goto error;
        }
 
-       return sock;
+       return 0;
 
 error:
-       sock_release(sock);
-       return ERR_PTR(result);
+       return result;
 }
 
 
 /*
  *      Set up receiving multicast socket over UDP
  */
-static struct socket *make_receive_sock(struct netns_ipvs *ipvs, int id,
-                                       int ifindex)
+static int make_receive_sock(struct netns_ipvs *ipvs, int id,
+                            struct net_device *dev, struct socket **sock_ret)
 {
        /* multicast addr */
        union ipvs_sockaddr mcast_addr;
                                  IPPROTO_UDP, &sock);
        if (result < 0) {
                pr_err("Error during creation of socket; terminating\n");
-               return ERR_PTR(result);
+               goto error;
        }
+       *sock_ret = sock;
        /* it is equivalent to the REUSEADDR option in user-space */
        sock->sk->sk_reuse = SK_CAN_REUSE;
        result = sysctl_sync_sock_size(ipvs);
                set_sock_size(sock->sk, 0, result);
 
        get_mcast_sockaddr(&mcast_addr, &salen, &ipvs->bcfg, id);
-       sock->sk->sk_bound_dev_if = ifindex;
+       sock->sk->sk_bound_dev_if = dev->ifindex;
        result = sock->ops->bind(sock, (struct sockaddr *)&mcast_addr, salen);
        if (result < 0) {
                pr_err("Error binding to the multicast addr\n");
 #ifdef CONFIG_IP_VS_IPV6
        if (ipvs->bcfg.mcast_af == AF_INET6)
                result = join_mcast_group6(sock->sk, &mcast_addr.in6.sin6_addr,
-                                          ipvs->bcfg.mcast_ifn);
+                                          dev);
        else
 #endif
                result = join_mcast_group(sock->sk, &mcast_addr.in.sin_addr,
-                                         ipvs->bcfg.mcast_ifn);
+                                         dev);
        if (result < 0) {
                pr_err("Error joining to the multicast group\n");
                goto error;
        }
 
-       return sock;
+       return 0;
 
 error:
-       sock_release(sock);
-       return ERR_PTR(result);
+       return result;
 }
 
 
 int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c,
                      int state)
 {
-       struct ip_vs_sync_thread_data *tinfo;
+       struct ip_vs_sync_thread_data *tinfo = NULL;
        struct task_struct **array = NULL, *task;
-       struct socket *sock;
        struct net_device *dev;
        char *name;
        int (*threadfn)(void *data);
-       int id, count, hlen;
+       int id = 0, count, hlen;
        int result = -ENOMEM;
        u16 mtu, min_mtu;
 
        IP_VS_DBG(7, "Each ip_vs_sync_conn entry needs %zd bytes\n",
                  sizeof(struct ip_vs_sync_conn_v0));
 
+       /* Do not hold one mutex and then to block on another */
+       for (;;) {
+               rtnl_lock();
+               if (mutex_trylock(&ipvs->sync_mutex))
+                       break;
+               rtnl_unlock();
+               mutex_lock(&ipvs->sync_mutex);
+               if (rtnl_trylock())
+                       break;
+               mutex_unlock(&ipvs->sync_mutex);
+       }
+
        if (!ipvs->sync_state) {
                count = clamp(sysctl_sync_ports(ipvs), 1, IPVS_SYNC_PORTS_MAX);
                ipvs->threads_mask = count - 1;
        dev = __dev_get_by_name(ipvs->net, c->mcast_ifn);
        if (!dev) {
                pr_err("Unknown mcast interface: %s\n", c->mcast_ifn);
-               return -ENODEV;
+               result = -ENODEV;
+               goto out_early;
        }
        hlen = (AF_INET6 == c->mcast_af) ?
               sizeof(struct ipv6hdr) + sizeof(struct udphdr) :
                c->sync_maxlen = mtu - hlen;
 
        if (state == IP_VS_STATE_MASTER) {
+               result = -EEXIST;
                if (ipvs->ms)
-                       return -EEXIST;
+                       goto out_early;
 
                ipvs->mcfg = *c;
                name = "ipvs-m:%d:%d";
                threadfn = sync_thread_master;
        } else if (state == IP_VS_STATE_BACKUP) {
+               result = -EEXIST;
                if (ipvs->backup_threads)
-                       return -EEXIST;
+                       goto out_early;
 
                ipvs->bcfg = *c;
                name = "ipvs-b:%d:%d";
                threadfn = sync_thread_backup;
        } else {
-               return -EINVAL;
+               result = -EINVAL;
+               goto out_early;
        }
 
        if (state == IP_VS_STATE_MASTER) {
                struct ipvs_master_sync_state *ms;
 
+               result = -ENOMEM;
                ipvs->ms = kcalloc(count, sizeof(ipvs->ms[0]), GFP_KERNEL);
                if (!ipvs->ms)
                        goto out;
        } else {
                array = kcalloc(count, sizeof(struct task_struct *),
                                GFP_KERNEL);
+               result = -ENOMEM;
                if (!array)
                        goto out;
        }
 
-       tinfo = NULL;
        for (id = 0; id < count; id++) {
-               if (state == IP_VS_STATE_MASTER)
-                       sock = make_send_sock(ipvs, id);
-               else
-                       sock = make_receive_sock(ipvs, id, dev->ifindex);
-               if (IS_ERR(sock)) {
-                       result = PTR_ERR(sock);
-                       goto outtinfo;
-               }
+               result = -ENOMEM;
                tinfo = kmalloc(sizeof(*tinfo), GFP_KERNEL);
                if (!tinfo)
-                       goto outsocket;
+                       goto out;
                tinfo->ipvs = ipvs;
-               tinfo->sock = sock;
+               tinfo->sock = NULL;
                if (state == IP_VS_STATE_BACKUP) {
                        tinfo->buf = kmalloc(ipvs->bcfg.sync_maxlen,
                                             GFP_KERNEL);
                        if (!tinfo->buf)
-                               goto outtinfo;
+                               goto out;
                } else {
                        tinfo->buf = NULL;
                }
                tinfo->id = id;
+               if (state == IP_VS_STATE_MASTER)
+                       result = make_send_sock(ipvs, id, dev, &tinfo->sock);
+               else
+                       result = make_receive_sock(ipvs, id, dev, &tinfo->sock);
+               if (result < 0)
+                       goto out;
 
                task = kthread_run(threadfn, tinfo, name, ipvs->gen, id);
                if (IS_ERR(task)) {
                        result = PTR_ERR(task);
-                       goto outtinfo;
+                       goto out;
                }
                tinfo = NULL;
                if (state == IP_VS_STATE_MASTER)
        ipvs->sync_state |= state;
        spin_unlock_bh(&ipvs->sync_buff_lock);
 
+       mutex_unlock(&ipvs->sync_mutex);
+       rtnl_unlock();
+
        /* increase the module use count */
        ip_vs_use_count_inc();
 
        return 0;
 
-outsocket:
-       sock_release(sock);
-
-outtinfo:
-       if (tinfo) {
-               sock_release(tinfo->sock);
-               kfree(tinfo->buf);
-               kfree(tinfo);
-       }
+out:
+       /* We do not need RTNL lock anymore, release it here so that
+        * sock_release below and in the kthreads can use rtnl_lock
+        * to leave the mcast group.
+        */
+       rtnl_unlock();
        count = id;
        while (count-- > 0) {
                if (state == IP_VS_STATE_MASTER)
                else
                        kthread_stop(array[count]);
        }
-       kfree(array);
-
-out:
        if (!(ipvs->sync_state & IP_VS_STATE_MASTER)) {
                kfree(ipvs->ms);
                ipvs->ms = NULL;
        }
+       mutex_unlock(&ipvs->sync_mutex);
+       if (tinfo) {
+               if (tinfo->sock)
+                       sock_release(tinfo->sock);
+               kfree(tinfo->buf);
+               kfree(tinfo);
+       }
+       kfree(array);
+       return result;
+
+out_early:
+       mutex_unlock(&ipvs->sync_mutex);
+       rtnl_unlock();
        return result;
 }