]> www.infradead.org Git - users/jedix/linux-maple.git/commitdiff
selftests/bpf: Create established sockets in socket iterator tests
authorJordan Rife <jordan@jrife.io>
Mon, 14 Jul 2025 18:09:14 +0000 (11:09 -0700)
committerMartin KaFai Lau <martin.lau@kernel.org>
Mon, 14 Jul 2025 22:12:52 +0000 (15:12 -0700)
Prepare for bucket resume tests for established TCP sockets by creating
established sockets. Collect socket fds from connect() and accept()
sides and pass them to test cases.

Signed-off-by: Jordan Rife <jordan@jrife.io>
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
Acked-by: Stanislav Fomichev <sdf@fomichev.me>
tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c

index 60b45685ef72a33a6bb710ccba3d821688973cfc..0ccc90d1d1ef00f83129dcdc1e40025871f6d7bc 100644 (file)
@@ -1,6 +1,7 @@
 // SPDX-License-Identifier: GPL-2.0
 // Copyright (c) 2024 Meta
 
+#include <poll.h>
 #include <test_progs.h>
 #include "network_helpers.h"
 #include "sock_iter_batch.skel.h"
@@ -153,8 +154,71 @@ static void check_n_were_seen_once(int *fds, int fds_len, int n,
        ASSERT_EQ(seen_once, n, "seen_once");
 }
 
+static int accept_from_one(struct pollfd *server_poll_fds,
+                          int server_poll_fds_len)
+{
+       static const int poll_timeout_ms = 5000; /* 5s */
+       int ret;
+       int i;
+
+       ret = poll(server_poll_fds, server_poll_fds_len, poll_timeout_ms);
+       if (!ASSERT_EQ(ret, 1, "poll"))
+               return -1;
+
+       for (i = 0; i < server_poll_fds_len; i++)
+               if (server_poll_fds[i].revents & POLLIN)
+                       return accept(server_poll_fds[i].fd, NULL, NULL);
+
+       return -1;
+}
+
+static int *connect_to_server(int family, int sock_type, const char *addr,
+                             __u16 port, int nr_connects, int *server_fds,
+                             int server_fds_len)
+{
+       struct pollfd *server_poll_fds = NULL;
+       int *established_socks = NULL;
+       int i;
+
+       server_poll_fds = calloc(server_fds_len, sizeof(*server_poll_fds));
+       if (!ASSERT_OK_PTR(server_poll_fds, "server_poll_fds"))
+               return NULL;
+
+       for (i = 0; i < server_fds_len; i++) {
+               server_poll_fds[i].fd = server_fds[i];
+               server_poll_fds[i].events = POLLIN;
+       }
+
+       i = 0;
+
+       established_socks = malloc(sizeof(*established_socks) * nr_connects*2);
+       if (!ASSERT_OK_PTR(established_socks, "established_socks"))
+               goto error;
+
+       while (nr_connects--) {
+               established_socks[i] = connect_to_addr_str(family, sock_type,
+                                                          addr, port, NULL);
+               if (!ASSERT_OK_FD(established_socks[i], "connect_to_addr_str"))
+                       goto error;
+               i++;
+               established_socks[i] = accept_from_one(server_poll_fds,
+                                                      server_fds_len);
+               if (!ASSERT_OK_FD(established_socks[i], "accept_from_one"))
+                       goto error;
+               i++;
+       }
+
+       free(server_poll_fds);
+       return established_socks;
+error:
+       free_fds(established_socks, i);
+       free(server_poll_fds);
+       return NULL;
+}
+
 static void remove_seen(int family, int sock_type, const char *addr, __u16 port,
-                       int *socks, int socks_len, struct sock_count *counts,
+                       int *socks, int socks_len, int *established_socks,
+                       int established_socks_len, struct sock_count *counts,
                        int counts_len, struct bpf_link *link, int iter_fd)
 {
        int close_idx;
@@ -185,6 +249,7 @@ static void remove_seen(int family, int sock_type, const char *addr, __u16 port,
 
 static void remove_unseen(int family, int sock_type, const char *addr,
                          __u16 port, int *socks, int socks_len,
+                         int *established_socks, int established_socks_len,
                          struct sock_count *counts, int counts_len,
                          struct bpf_link *link, int iter_fd)
 {
@@ -217,6 +282,7 @@ static void remove_unseen(int family, int sock_type, const char *addr,
 
 static void remove_all(int family, int sock_type, const char *addr,
                       __u16 port, int *socks, int socks_len,
+                      int *established_socks, int established_socks_len,
                       struct sock_count *counts, int counts_len,
                       struct bpf_link *link, int iter_fd)
 {
@@ -244,7 +310,8 @@ static void remove_all(int family, int sock_type, const char *addr,
 }
 
 static void add_some(int family, int sock_type, const char *addr, __u16 port,
-                    int *socks, int socks_len, struct sock_count *counts,
+                    int *socks, int socks_len, int *established_socks,
+                    int established_socks_len, struct sock_count *counts,
                     int counts_len, struct bpf_link *link, int iter_fd)
 {
        int *new_socks = NULL;
@@ -274,6 +341,7 @@ done:
 
 static void force_realloc(int family, int sock_type, const char *addr,
                          __u16 port, int *socks, int socks_len,
+                         int *established_socks, int established_socks_len,
                          struct sock_count *counts, int counts_len,
                          struct bpf_link *link, int iter_fd)
 {
@@ -302,10 +370,12 @@ done:
 
 struct test_case {
        void (*test)(int family, int sock_type, const char *addr, __u16 port,
-                    int *socks, int socks_len, struct sock_count *counts,
+                    int *socks, int socks_len, int *established_socks,
+                    int established_socks_len, struct sock_count *counts,
                     int counts_len, struct bpf_link *link, int iter_fd);
        const char *description;
        int ehash_buckets;
+       int connections;
        int init_socks;
        int max_socks;
        int sock_type;
@@ -416,6 +486,7 @@ static void do_resume_test(struct test_case *tc)
        static const __u16 port = 10001;
        struct nstoken *nstoken = NULL;
        struct bpf_link *link = NULL;
+       int *established_fds = NULL;
        int err, iter_fd = -1;
        const char *addr;
        int *fds = NULL;
@@ -444,6 +515,14 @@ static void do_resume_test(struct test_case *tc)
                                     tc->init_socks);
        if (!ASSERT_OK_PTR(fds, "start_reuseport_server"))
                goto done;
+       if (tc->connections) {
+               established_fds = connect_to_server(tc->family, tc->sock_type,
+                                                   addr, port,
+                                                   tc->connections, fds,
+                                                   tc->init_socks);
+               if (!ASSERT_OK_PTR(established_fds, "connect_to_server"))
+                       goto done;
+       }
        skel->rodata->ports[0] = 0;
        skel->rodata->ports[1] = 0;
        skel->rodata->sf = tc->family;
@@ -465,13 +544,15 @@ static void do_resume_test(struct test_case *tc)
                goto done;
 
        tc->test(tc->family, tc->sock_type, addr, port, fds, tc->init_socks,
-                counts, tc->max_socks, link, iter_fd);
+                established_fds, tc->connections*2, counts, tc->max_socks,
+                link, iter_fd);
 done:
        close_netns(nstoken);
        SYS_NOFAIL("ip netns del " TEST_CHILD_NS);
        SYS_NOFAIL("sysctl -w net.ipv4.tcp_child_ehash_entries=0");
        free(counts);
        free_fds(fds, tc->init_socks);
+       free_fds(established_fds, tc->connections*2);
        if (iter_fd >= 0)
                close(iter_fd);
        bpf_link__destroy(link);