rtnl_doit_func          doit;
        rtnl_dumpit_func        dumpit;
        unsigned int            flags;
+       struct rcu_head         rcu;
 };
 
 static DEFINE_MUTEX(rtnl_mutex);
 EXPORT_SYMBOL(lockdep_rtnl_is_held);
 #endif /* #ifdef CONFIG_PROVE_LOCKING */
 
-static struct rtnl_link __rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
+static struct rtnl_link __rcu **rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
 static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];
 
 static inline int rtm_msgindex(int msgtype)
        return msgindex;
 }
 
+static struct rtnl_link *rtnl_get_link(int protocol, int msgtype)
+{
+       struct rtnl_link **tab;
+
+       if (protocol >= ARRAY_SIZE(rtnl_msg_handlers))
+               protocol = PF_UNSPEC;
+
+       tab = rcu_dereference_rtnl(rtnl_msg_handlers[protocol]);
+       if (!tab)
+               tab = rcu_dereference_rtnl(rtnl_msg_handlers[PF_UNSPEC]);
+
+       return tab[msgtype];
+}
+
 /**
  * __rtnl_register - Register a rtnetlink message type
  * @protocol: Protocol family or PF_UNSPEC
                    rtnl_doit_func doit, rtnl_dumpit_func dumpit,
                    unsigned int flags)
 {
-       struct rtnl_link *tab;
+       struct rtnl_link **tab, *link, *old;
        int msgindex;
+       int ret = -ENOBUFS;
 
        BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
        msgindex = rtm_msgindex(msgtype);
 
-       tab = rcu_dereference_raw(rtnl_msg_handlers[protocol]);
+       rtnl_lock();
+       tab = rtnl_msg_handlers[protocol];
        if (tab == NULL) {
-               tab = kcalloc(RTM_NR_MSGTYPES, sizeof(*tab), GFP_KERNEL);
-               if (tab == NULL)
-                       return -ENOBUFS;
+               tab = kcalloc(RTM_NR_MSGTYPES, sizeof(void *), GFP_KERNEL);
+               if (!tab)
+                       goto unlock;
 
+               /* ensures we see the 0 stores */
                rcu_assign_pointer(rtnl_msg_handlers[protocol], tab);
        }
 
+       old = rtnl_dereference(tab[msgindex]);
+       if (old) {
+               link = kmemdup(old, sizeof(*old), GFP_KERNEL);
+               if (!link)
+                       goto unlock;
+       } else {
+               link = kzalloc(sizeof(*link), GFP_KERNEL);
+               if (!link)
+                       goto unlock;
+       }
+
+       WARN_ON(doit && link->doit && link->doit != doit);
        if (doit)
-               tab[msgindex].doit = doit;
+               link->doit = doit;
+       WARN_ON(dumpit && link->dumpit && link->dumpit != dumpit);
        if (dumpit)
-               tab[msgindex].dumpit = dumpit;
-       tab[msgindex].flags |= flags;
+               link->dumpit = dumpit;
 
-       return 0;
+       link->flags |= flags;
+
+       /* publish protocol:msgtype */
+       rcu_assign_pointer(tab[msgindex], link);
+       ret = 0;
+       if (old)
+               kfree_rcu(old, rcu);
+unlock:
+       rtnl_unlock();
+       return ret;
 }
 EXPORT_SYMBOL_GPL(__rtnl_register);
 
  */
 int rtnl_unregister(int protocol, int msgtype)
 {
-       struct rtnl_link *handlers;
+       struct rtnl_link **tab, *link;
        int msgindex;
 
        BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
        msgindex = rtm_msgindex(msgtype);
 
        rtnl_lock();
-       handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
-       if (!handlers) {
+       tab = rtnl_dereference(rtnl_msg_handlers[protocol]);
+       if (!tab) {
                rtnl_unlock();
                return -ENOENT;
        }
 
-       handlers[msgindex].doit = NULL;
-       handlers[msgindex].dumpit = NULL;
-       handlers[msgindex].flags = 0;
+       link = tab[msgindex];
+       rcu_assign_pointer(tab[msgindex], NULL);
        rtnl_unlock();
 
+       kfree_rcu(link, rcu);
+
        return 0;
 }
 EXPORT_SYMBOL_GPL(rtnl_unregister);
  */
 void rtnl_unregister_all(int protocol)
 {
-       struct rtnl_link *handlers;
+       struct rtnl_link **tab, *link;
+       int msgindex;
 
        BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
 
        rtnl_lock();
-       handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
+       tab = rtnl_msg_handlers[protocol];
        RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL);
+       for (msgindex = 0; msgindex < RTM_NR_MSGTYPES; msgindex++) {
+               link = tab[msgindex];
+               if (!link)
+                       continue;
+
+               rcu_assign_pointer(tab[msgindex], NULL);
+               kfree_rcu(link, rcu);
+       }
        rtnl_unlock();
 
        synchronize_net();
 
        while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 1)
                schedule();
-       kfree(handlers);
+       kfree(tab);
 }
 EXPORT_SYMBOL_GPL(rtnl_unregister_all);
 
                s_idx = 1;
 
        for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) {
+               struct rtnl_link **tab;
                int type = cb->nlh->nlmsg_type-RTM_BASE;
-               struct rtnl_link *handlers;
+               struct rtnl_link *link;
                rtnl_dumpit_func dumpit;
 
                if (idx < s_idx || idx == PF_PACKET)
                        continue;
 
-               handlers = rtnl_dereference(rtnl_msg_handlers[idx]);
-               if (!handlers)
+               if (type < 0 || type >= RTM_NR_MSGTYPES)
                        continue;
 
-               dumpit = READ_ONCE(handlers[type].dumpit);
+               tab = rcu_dereference_rtnl(rtnl_msg_handlers[idx]);
+               if (!tab)
+                       continue;
+
+               link = tab[type];
+               if (!link)
+                       continue;
+
+               dumpit = link->dumpit;
                if (!dumpit)
                        continue;
 
                             struct netlink_ext_ack *extack)
 {
        struct net *net = sock_net(skb->sk);
-       struct rtnl_link *handlers;
+       struct rtnl_link *link;
        int err = -EOPNOTSUPP;
        rtnl_doit_func doit;
        unsigned int flags;
        if (kind != 2 && !netlink_net_capable(skb, CAP_NET_ADMIN))
                return -EPERM;
 
-       if (family >= ARRAY_SIZE(rtnl_msg_handlers))
-               family = PF_UNSPEC;
-
        rcu_read_lock();
-       handlers = rcu_dereference(rtnl_msg_handlers[family]);
-       if (!handlers) {
-               family = PF_UNSPEC;
-               handlers = rcu_dereference(rtnl_msg_handlers[family]);
-       }
-
        if (kind == 2 && nlh->nlmsg_flags&NLM_F_DUMP) {
                struct sock *rtnl;
                rtnl_dumpit_func dumpit;
                u16 min_dump_alloc = 0;
 
-               dumpit = READ_ONCE(handlers[type].dumpit);
-               if (!dumpit) {
+               link = rtnl_get_link(family, type);
+               if (!link || !link->dumpit) {
                        family = PF_UNSPEC;
-                       handlers = rcu_dereference(rtnl_msg_handlers[PF_UNSPEC]);
-                       if (!handlers)
-                               goto err_unlock;
-
-                       dumpit = READ_ONCE(handlers[type].dumpit);
-                       if (!dumpit)
+                       link = rtnl_get_link(family, type);
+                       if (!link || !link->dumpit)
                                goto err_unlock;
                }
+               dumpit = link->dumpit;
 
                refcount_inc(&rtnl_msg_handlers_ref[family]);
 
                return err;
        }
 
-       doit = READ_ONCE(handlers[type].doit);
-       if (!doit) {
+       link = rtnl_get_link(family, type);
+       if (!link || !link->doit) {
                family = PF_UNSPEC;
-               handlers = rcu_dereference(rtnl_msg_handlers[family]);
+               link = rtnl_get_link(PF_UNSPEC, type);
+               if (!link || !link->doit)
+                       goto out_unlock;
        }
 
-       flags = READ_ONCE(handlers[type].flags);
+       flags = link->flags;
        if (flags & RTNL_FLAG_DOIT_UNLOCKED) {
                refcount_inc(&rtnl_msg_handlers_ref[family]);
-               doit = READ_ONCE(handlers[type].doit);
+               doit = link->doit;
                rcu_read_unlock();
                if (doit)
                        err = doit(skb, nlh, extack);
                refcount_dec(&rtnl_msg_handlers_ref[family]);
                return err;
        }
-
        rcu_read_unlock();
 
        rtnl_lock();
-       handlers = rtnl_dereference(rtnl_msg_handlers[family]);
-       if (handlers) {
-               doit = READ_ONCE(handlers[type].doit);
-               if (doit)
-                       err = doit(skb, nlh, extack);
-       }
+       link = rtnl_get_link(family, type);
+       if (link && link->doit)
+               err = link->doit(skb, nlh, extack);
        rtnl_unlock();
+
+       return err;
+
+out_unlock:
+       rcu_read_unlock();
        return err;
 
 err_unlock: