#include <net/inet_hashtables.h>
 #include <net/inet_timewait_sock.h>
 #include <net/inet6_hashtables.h>
+#include <net/bpf_sk_storage.h>
 #include <net/netlink.h>
 
 #include <linux/inet.h>
 }
 EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
 
+#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
+
 int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
                      struct sk_buff *skb, struct netlink_callback *cb,
                      const struct inet_diag_req_v2 *req,
 {
        const struct tcp_congestion_ops *ca_ops;
        const struct inet_diag_handler *handler;
+       struct inet_diag_dump_data *cb_data;
        int ext = req->idiag_ext;
        struct inet_diag_msg *r;
        struct nlmsghdr  *nlh;
        struct nlattr *attr;
        void *info = NULL;
 
+       cb_data = cb->data;
        handler = inet_diag_table[req->sdiag_protocol];
        BUG_ON(!handler);
 
                        goto errout;
        }
 
+       /* Keep it at the end for potential retry with a larger skb,
+        * or else do best-effort fitting, which is only done for the
+        * first_nlmsg.
+        */
+       if (cb_data->bpf_stg_diag) {
+               bool first_nlmsg = ((unsigned char *)nlh == skb->data);
+               unsigned int prev_min_dump_alloc;
+               unsigned int total_nla_size = 0;
+               unsigned int msg_len;
+               int err;
+
+               msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
+               err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
+                                             INET_DIAG_SK_BPF_STORAGES,
+                                             &total_nla_size);
+
+               if (!err)
+                       goto out;
+
+               total_nla_size += msg_len;
+               prev_min_dump_alloc = cb->min_dump_alloc;
+               if (total_nla_size > prev_min_dump_alloc)
+                       cb->min_dump_alloc = min_t(u32, total_nla_size,
+                                                  MAX_DUMP_ALLOC_SIZE);
+
+               if (!first_nlmsg)
+                       goto errout;
+
+               if (cb->min_dump_alloc > prev_min_dump_alloc)
+                       /* Retry with pskb_expand_head() with
+                        * __GFP_DIRECT_RECLAIM
+                        */
+                       goto errout;
+
+               WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
+
+               /* Send what we have for this sk
+                * and move on to the next sk in the following
+                * dump()
+                */
+       }
+
 out:
        nlmsg_end(skb, nlh);
        return 0;
                            const struct inet_diag_req_v2 *r)
 {
        const struct inet_diag_handler *handler;
+       u32 prev_min_dump_alloc;
        int err = 0;
 
+again:
+       prev_min_dump_alloc = cb->min_dump_alloc;
        handler = inet_diag_lock_handler(r->sdiag_protocol);
        if (!IS_ERR(handler))
                handler->dump(skb, cb, r);
                err = PTR_ERR(handler);
        inet_diag_unlock_handler(handler);
 
+       /* The skb is not large enough to fit one sk info and
+        * inet_sk_diag_fill() has requested for a larger skb.
+        */
+       if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
+               err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
+               if (!err)
+                       goto again;
+       }
+
        return err ? : skb->len;
 }
 
                }
        }
 
+       nla = cb_data->inet_diag_nla_bpf_stgs;
+       if (nla) {
+               struct bpf_sk_storage_diag *bpf_stg_diag;
+
+               bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
+               if (IS_ERR(bpf_stg_diag)) {
+                       kfree(cb_data);
+                       return PTR_ERR(bpf_stg_diag);
+               }
+               cb_data->bpf_stg_diag = bpf_stg_diag;
+       }
+
        cb->data = cb_data;
        return 0;
 }
 
 static int inet_diag_dump_done(struct netlink_callback *cb)
 {
+       struct inet_diag_dump_data *cb_data = cb->data;
+
+       bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
        kfree(cb->data);
 
        return 0;