int sock_wake_async(struct socket_wq *sk_wq, int how, int band);
 int sock_register(const struct net_proto_family *fam);
 void sock_unregister(int family);
+bool sock_is_registered(int family);
 int __sock_create(struct net *net, int family, int type, int proto,
                  struct socket **res, int kern);
 int sock_create(int family, int type, int proto, struct socket **res);
 
 
 int proto_register(struct proto *prot, int alloc_slab);
 void proto_unregister(struct proto *prot);
+int sock_load_diag_module(int family, int protocol);
 
 #ifdef SOCK_REFCNT_DEBUG
 static inline void sk_refcnt_debug_inc(struct sock *sk)
 
 }
 EXPORT_SYMBOL(proto_unregister);
 
+int sock_load_diag_module(int family, int protocol)
+{
+       if (!protocol) {
+               if (!sock_is_registered(family))
+                       return -ENOENT;
+
+               return request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
+                                     NETLINK_SOCK_DIAG, family);
+       }
+
+#ifdef CONFIG_INET
+       if (family == AF_INET &&
+           !rcu_access_pointer(inet_protos[protocol]))
+               return -ENOENT;
+#endif
+
+       return request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
+                             NETLINK_SOCK_DIAG, family, protocol);
+}
+EXPORT_SYMBOL(sock_load_diag_module);
+
 #ifdef CONFIG_PROC_FS
 static void *proto_seq_start(struct seq_file *seq, loff_t *pos)
        __acquires(proto_list_mutex)
 
                return -EINVAL;
 
        if (sock_diag_handlers[req->sdiag_family] == NULL)
-               request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-                               NETLINK_SOCK_DIAG, req->sdiag_family);
+               sock_load_diag_module(req->sdiag_family, 0);
 
        mutex_lock(&sock_diag_table_mutex);
        hndl = sock_diag_handlers[req->sdiag_family];
        case TCPDIAG_GETSOCK:
        case DCCPDIAG_GETSOCK:
                if (inet_rcv_compat == NULL)
-                       request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-                                       NETLINK_SOCK_DIAG, AF_INET);
+                       sock_load_diag_module(AF_INET, 0);
 
                mutex_lock(&sock_diag_table_mutex);
                if (inet_rcv_compat != NULL)
        case SKNLGRP_INET_TCP_DESTROY:
        case SKNLGRP_INET_UDP_DESTROY:
                if (!sock_diag_handlers[AF_INET])
-                       request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-                                      NETLINK_SOCK_DIAG, AF_INET);
+                       sock_load_diag_module(AF_INET, 0);
                break;
        case SKNLGRP_INET6_TCP_DESTROY:
        case SKNLGRP_INET6_UDP_DESTROY:
                if (!sock_diag_handlers[AF_INET6])
-                       request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
-                                      NETLINK_SOCK_DIAG, AF_INET6);
+                       sock_load_diag_module(AF_INET6, 0);
                break;
        }
        return 0;
 
 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
 {
        if (!inet_diag_table[proto])
-               request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK,
-                              NETLINK_SOCK_DIAG, AF_INET, proto);
+               sock_load_diag_module(AF_INET, proto);
 
        mutex_lock(&inet_diag_table_mutex);
        if (!inet_diag_table[proto])
 
 }
 EXPORT_SYMBOL(sock_unregister);
 
+bool sock_is_registered(int family)
+{
+       return family < NPROTO && rcu_access_pointer(net_families[family]);
+}
+
 static int __init sock_init(void)
 {
        int err;