 #include <linux/bpf.h>
 #include <net/bpf_sk_storage.h>
 #include <net/sock.h>
+#include <uapi/linux/sock_diag.h>
 #include <uapi/linux/btf.h>
 
 static atomic_t cache_idx;
        kfree(map);
 }
 
+/* U16_MAX is much more than enough for sk local storage
+ * considering a tcp_sock is ~2k.
+ */
+#define MAX_VALUE_SIZE                                                 \
+       min_t(u32,                                                      \
+             (KMALLOC_MAX_SIZE - MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem)), \
+             (U16_MAX - sizeof(struct bpf_sk_storage_elem)))
+
 static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
 {
        if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
        if (!capable(CAP_SYS_ADMIN))
                return -EPERM;
 
-       if (attr->value_size >= KMALLOC_MAX_SIZE -
-           MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem) ||
-           /* U16_MAX is much more than enough for sk local storage
-            * considering a tcp_sock is ~2k.
-            */
-           attr->value_size > U16_MAX - sizeof(struct bpf_sk_storage_elem))
+       if (attr->value_size > MAX_VALUE_SIZE)
                return -E2BIG;
 
        return 0;
        .arg1_type      = ARG_CONST_MAP_PTR,
        .arg2_type      = ARG_PTR_TO_SOCKET,
 };
+
+struct bpf_sk_storage_diag {
+       u32 nr_maps;
+       struct bpf_map *maps[];
+};
+
+/* The reply will be like:
+ * INET_DIAG_BPF_SK_STORAGES (nla_nest)
+ *     SK_DIAG_BPF_STORAGE (nla_nest)
+ *             SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
+ *             SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
+ *     SK_DIAG_BPF_STORAGE (nla_nest)
+ *             SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
+ *             SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
+ *     ....
+ */
+static int nla_value_size(u32 value_size)
+{
+       /* SK_DIAG_BPF_STORAGE (nla_nest)
+        *      SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
+        *      SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
+        */
+       return nla_total_size(0) + nla_total_size(sizeof(u32)) +
+               nla_total_size_64bit(value_size);
+}
+
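+/* Release the map references held by @diag and free @diag itself. */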
+void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
+{
+       u32 i;
+
+       if (!diag)
+               return;
+
+       for (i = 0; i < diag->nr_maps; i++)
+               bpf_map_put(diag->maps[i]);
+
+       kfree(diag);
+}
+EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
+
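+/* Return true if @map has already been collected into @diag->maps[]. */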
+static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
+                          const struct bpf_map *map)
+{
+       u32 i;
+
+       for (i = 0; i < diag->nr_maps; i++) {
+               if (diag->maps[i] == map)
+                       return true;
+       }
+
+       return false;
+}
+
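+/* Parse the nested SK_DIAG_BPF_STORAGE_REQ_MAP_FD attributes in @nla_stgs,
+ * take a reference on each requested sk_storage map and collect them into
+ * a newly allocated bpf_sk_storage_diag.  Returns an ERR_PTR() on failure.
+ */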
+struct bpf_sk_storage_diag *
+bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
+{
+       struct bpf_sk_storage_diag *diag;
+       struct nlattr *nla;
+       u32 nr_maps = 0;
+       int rem, err;
+
+       /* bpf_sk_storage_map is currently limited to CAP_SYS_ADMIN,
+        * matching the check in bpf_sk_storage_map_alloc_check().
+        */
+       if (!capable(CAP_SYS_ADMIN))
+               return ERR_PTR(-EPERM);
+
+       nla_for_each_nested(nla, nla_stgs, rem) {
+               if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
+                       nr_maps++;
+       }
+
+       diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
+                      GFP_KERNEL);
+       if (!diag)
+               return ERR_PTR(-ENOMEM);
+
+       nla_for_each_nested(nla, nla_stgs, rem) {
+               struct bpf_map *map;
+               int map_fd;
+
+               if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
+                       continue;
+
+               map_fd = nla_get_u32(nla);
+               map = bpf_map_get(map_fd);
+               if (IS_ERR(map)) {
+                       err = PTR_ERR(map);
+                       goto err_free;
+               }
+               if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
+                       bpf_map_put(map);
+                       err = -EINVAL;
+                       goto err_free;
+               }
+               if (diag_check_dup(diag, map)) {
+                       bpf_map_put(map);
+                       err = -EEXIST;
+                       goto err_free;
+               }
+               diag->maps[diag->nr_maps++] = map;
+       }
+
+       return diag;
+
+err_free:
+       bpf_sk_storage_diag_free(diag);
+       return ERR_PTR(err);
+}
+EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
+
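+/* Dump one socket storage as a nested SK_DIAG_BPF_STORAGE attribute:
+ * the map id followed by the map value.  Called under rcu_read_lock().
+ */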
+static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
+{
+       struct nlattr *nla_stg, *nla_value;
+       struct bpf_sk_storage_map *smap;
+
+       /* The map value must fit within the max nlattr payload */
+       BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
+
+       nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
+       if (!nla_stg)
+               return -EMSGSIZE;
+
+       smap = rcu_dereference(sdata->smap);
+       if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
+               goto errout;
+
+       nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
+                                     smap->map.value_size,
+                                     SK_DIAG_BPF_STORAGE_PAD);
+       if (!nla_value)
+               goto errout;
+
+       if (map_value_has_spin_lock(&smap->map))
+               copy_map_value_locked(&smap->map, nla_data(nla_value),
+                                     sdata->data, true);
+       else
+               copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
+
+       nla_nest_end(skb, nla_stg);
+       return 0;
+
+errout:
+       nla_nest_cancel(skb, nla_stg);
+       return -EMSGSIZE;
+}
+
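+/* Dump every sk storage of @sk.  On -EMSGSIZE, keep iterating so that
+ * *res_diag_size still reflects the space needed for a complete dump.
+ */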
+static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
+                                      int stg_array_type,
+                                      unsigned int *res_diag_size)
+{
+       /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
+       unsigned int diag_size = nla_total_size(0);
+       struct bpf_sk_storage *sk_storage;
+       struct bpf_sk_storage_elem *selem;
+       struct bpf_sk_storage_map *smap;
+       struct nlattr *nla_stgs;
+       unsigned int saved_len;
+       int err = 0;
+
+       rcu_read_lock();
+
+       sk_storage = rcu_dereference(sk->sk_bpf_storage);
+       if (!sk_storage || hlist_empty(&sk_storage->list)) {
+               rcu_read_unlock();
+               return 0;
+       }
+
+       nla_stgs = nla_nest_start(skb, stg_array_type);
+       if (!nla_stgs)
+               /* Continue to learn diag_size */
+               err = -EMSGSIZE;
+
+       saved_len = skb->len;
+       hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
+               smap = rcu_dereference(SDATA(selem)->smap);
+               diag_size += nla_value_size(smap->map.value_size);
+
+               if (nla_stgs && diag_get(SDATA(selem), skb))
+                       /* Continue to learn diag_size */
+                       err = -EMSGSIZE;
+       }
+
+       rcu_read_unlock();
+
+       if (nla_stgs) {
+               if (saved_len == skb->len)
+                       nla_nest_cancel(skb, nla_stgs);
+               else
+                       nla_nest_end(skb, nla_stgs);
+       }
+
+       if (diag_size == nla_total_size(0)) {
+               *res_diag_size = 0;
+               return 0;
+       }
+
+       *res_diag_size = diag_size;
+       return err;
+}
+
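+/* Dump the sk storages of the maps collected in @diag, or all storages
+ * of @sk when no map was specified in the request.
+ */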
+int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
+                           struct sock *sk, struct sk_buff *skb,
+                           int stg_array_type,
+                           unsigned int *res_diag_size)
+{
+       /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
+       unsigned int diag_size = nla_total_size(0);
+       struct bpf_sk_storage *sk_storage;
+       struct bpf_sk_storage_data *sdata;
+       struct nlattr *nla_stgs;
+       unsigned int saved_len;
+       int err = 0;
+       u32 i;
+
+       *res_diag_size = 0;
+
+       /* No map has been specified.  Dump all. */
+       if (!diag->nr_maps)
+               return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
+                                                  res_diag_size);
+
+       rcu_read_lock();
+       sk_storage = rcu_dereference(sk->sk_bpf_storage);
+       if (!sk_storage || hlist_empty(&sk_storage->list)) {
+               rcu_read_unlock();
+               return 0;
+       }
+
+       nla_stgs = nla_nest_start(skb, stg_array_type);
+       if (!nla_stgs)
+               /* Continue to learn diag_size */
+               err = -EMSGSIZE;
+
+       saved_len = skb->len;
+       for (i = 0; i < diag->nr_maps; i++) {
+               sdata = __sk_storage_lookup(sk_storage,
+                               (struct bpf_sk_storage_map *)diag->maps[i],
+                               false);
+
+               if (!sdata)
+                       continue;
+
+               diag_size += nla_value_size(diag->maps[i]->value_size);
+
+               if (nla_stgs && diag_get(sdata, skb))
+                       /* Continue to learn diag_size */
+                       err = -EMSGSIZE;
+       }
+       rcu_read_unlock();
+
+       if (nla_stgs) {
+               if (saved_len == skb->len)
+                       nla_nest_cancel(skb, nla_stgs);
+               else
+                       nla_nest_end(skb, nla_stgs);
+       }
+
+       if (diag_size == nla_total_size(0)) {
+               *res_diag_size = 0;
+               return 0;
+       }
+
+       *res_diag_size = diag_size;
+       return err;
+}
+EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);