#include "ipvlan.h"
 
-static u32 ipvl_nf_hook_refcnt = 0;
+static unsigned int ipvlan_netid __read_mostly;
+
+struct ipvlan_netns {
+       unsigned int ipvl_nf_hook_refcnt;
+};
 
 static struct nf_hook_ops ipvl_nfops[] __read_mostly = {
        {
        ipvlan->dev->mtu = dev->mtu;
 }
 
-static int ipvlan_register_nf_hook(void)
+static int ipvlan_register_nf_hook(struct net *net)
 {
+       struct ipvlan_netns *vnet = net_generic(net, ipvlan_netid);
        int err = 0;
 
-       if (!ipvl_nf_hook_refcnt) {
-               err = _nf_register_hooks(ipvl_nfops, ARRAY_SIZE(ipvl_nfops));
+       if (!vnet->ipvl_nf_hook_refcnt) {
+               err = nf_register_net_hooks(net, ipvl_nfops,
+                                           ARRAY_SIZE(ipvl_nfops));
                if (!err)
-                       ipvl_nf_hook_refcnt = 1;
+                       vnet->ipvl_nf_hook_refcnt = 1;
        } else {
-               ipvl_nf_hook_refcnt++;
+               vnet->ipvl_nf_hook_refcnt++;
        }
 
        return err;
 }
 
-static void ipvlan_unregister_nf_hook(void)
+static void ipvlan_unregister_nf_hook(struct net *net)
 {
-       WARN_ON(!ipvl_nf_hook_refcnt);
+       struct ipvlan_netns *vnet = net_generic(net, ipvlan_netid);
+
+       if (WARN_ON(!vnet->ipvl_nf_hook_refcnt))
+               return;
 
-       ipvl_nf_hook_refcnt--;
-       if (!ipvl_nf_hook_refcnt)
-               _nf_unregister_hooks(ipvl_nfops, ARRAY_SIZE(ipvl_nfops));
+       vnet->ipvl_nf_hook_refcnt--;
+       if (!vnet->ipvl_nf_hook_refcnt)
+               nf_unregister_net_hooks(net, ipvl_nfops,
+                                       ARRAY_SIZE(ipvl_nfops));
 }
 
 static int ipvlan_set_port_mode(struct ipvl_port *port, u16 nval)
        if (port->mode != nval) {
                if (nval == IPVLAN_MODE_L3S) {
                        /* New mode is L3S */
-                       err = ipvlan_register_nf_hook();
+                       err = ipvlan_register_nf_hook(read_pnet(&port->pnet));
                        if (!err) {
                                mdev->l3mdev_ops = &ipvl_l3mdev_ops;
                                mdev->priv_flags |= IFF_L3MDEV_MASTER;
                } else if (port->mode == IPVLAN_MODE_L3S) {
                        /* Old mode was L3S */
                        mdev->priv_flags &= ~IFF_L3MDEV_MASTER;
-                       ipvlan_unregister_nf_hook();
+                       ipvlan_unregister_nf_hook(read_pnet(&port->pnet));
                        mdev->l3mdev_ops = NULL;
                }
                list_for_each_entry(ipvlan, &port->ipvlans, pnode) {
        if (!port)
                return -ENOMEM;
 
+       write_pnet(&port->pnet, dev_net(dev));
        port->dev = dev;
        port->mode = IPVLAN_MODE_L3;
        INIT_LIST_HEAD(&port->ipvlans);
        dev->priv_flags &= ~IFF_IPVLAN_MASTER;
        if (port->mode == IPVLAN_MODE_L3S) {
                dev->priv_flags &= ~IFF_L3MDEV_MASTER;
-               ipvlan_unregister_nf_hook();
+               ipvlan_unregister_nf_hook(dev_net(dev));
                dev->l3mdev_ops = NULL;
        }
        netdev_rx_handler_unregister(dev);
                                                         ipvlan->dev);
                break;
 
+       case NETDEV_REGISTER: {
+               struct net *oldnet, *newnet = dev_net(dev);
+               struct ipvlan_netns *old_vnet;
+
+               oldnet = read_pnet(&port->pnet);
+               if (net_eq(newnet, oldnet))
+                       break;
+
+               write_pnet(&port->pnet, newnet);
+
+               old_vnet = net_generic(oldnet, ipvlan_netid);
+               if (!old_vnet->ipvl_nf_hook_refcnt)
+                       break;
+
+               ipvlan_register_nf_hook(newnet);
+               ipvlan_unregister_nf_hook(oldnet);
+               break;
+       }
        case NETDEV_UNREGISTER:
                if (dev->reg_state != NETREG_UNREGISTERING)
                        break;
        .notifier_call = ipvlan_addr6_event,
 };
 
+static void ipvlan_ns_exit(struct net *net)
+{
+       struct ipvlan_netns *vnet = net_generic(net, ipvlan_netid);
+
+       if (WARN_ON_ONCE(vnet->ipvl_nf_hook_refcnt)) {
+               vnet->ipvl_nf_hook_refcnt = 0;
+               nf_unregister_net_hooks(net, ipvl_nfops,
+                                       ARRAY_SIZE(ipvl_nfops));
+       }
+}
+
+static struct pernet_operations ipvlan_net_ops = {
+       .id = &ipvlan_netid,
+       .size = sizeof(struct ipvlan_netns),
+       .exit = ipvlan_ns_exit,
+};
+
 static int __init ipvlan_init_module(void)
 {
        int err;
        register_inet6addr_notifier(&ipvlan_addr6_notifier_block);
        register_inetaddr_notifier(&ipvlan_addr4_notifier_block);
 
-       err = ipvlan_link_register(&ipvlan_link_ops);
+       err = register_pernet_subsys(&ipvlan_net_ops);
        if (err < 0)
                goto error;
 
+       err = ipvlan_link_register(&ipvlan_link_ops);
+       if (err < 0) {
+               unregister_pernet_subsys(&ipvlan_net_ops);
+               goto error;
+       }
+
        return 0;
 error:
        unregister_inetaddr_notifier(&ipvlan_addr4_notifier_block);
 static void __exit ipvlan_cleanup_module(void)
 {
        rtnl_link_unregister(&ipvlan_link_ops);
+       unregister_pernet_subsys(&ipvlan_net_ops);
        unregister_netdevice_notifier(&ipvlan_notifier_block);
        unregister_inetaddr_notifier(&ipvlan_addr4_notifier_block);
        unregister_inet6addr_notifier(&ipvlan_addr6_notifier_block);