--- /dev/null
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2023 Oracle and/or its affiliates.
+ *
+ * KUnit test of the handshake upcall mechanism.
+ */
+
+#include <kunit/test.h>
+#include <kunit/visibility.h>
+
+#include <linux/kernel.h>
+
+#include <net/sock.h>
+#include <net/genetlink.h>
+#include <net/netns/generic.h>
+
+#include <uapi/linux/handshake.h>
+#include "handshake.h"
+
+MODULE_IMPORT_NS(EXPORTED_FOR_KUNIT_TESTING);
+
+static int test_accept_func(struct handshake_req *req, struct genl_info *info,
+                           int fd)
+{
+       return 0;
+}
+
+static void test_done_func(struct handshake_req *req, unsigned int status,
+                          struct genl_info *info)
+{
+}
+
+struct handshake_req_alloc_test_param {
+       const char                      *desc;
+       struct handshake_proto          *proto;
+       gfp_t                           gfp;
+       bool                            expect_success;
+};
+
+static struct handshake_proto handshake_req_alloc_proto_2 = {
+       .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_NONE,
+};
+
+static struct handshake_proto handshake_req_alloc_proto_3 = {
+       .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_MAX,
+};
+
+static struct handshake_proto handshake_req_alloc_proto_4 = {
+       .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_TLSHD,
+};
+
+static struct handshake_proto handshake_req_alloc_proto_5 = {
+       .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_TLSHD,
+       .hp_accept              = test_accept_func,
+};
+
+static struct handshake_proto handshake_req_alloc_proto_6 = {
+       .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_TLSHD,
+       .hp_privsize            = UINT_MAX,
+       .hp_accept              = test_accept_func,
+       .hp_done                = test_done_func,
+};
+
+static struct handshake_proto handshake_req_alloc_proto_good = {
+       .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_TLSHD,
+       .hp_accept              = test_accept_func,
+       .hp_done                = test_done_func,
+};
+
+static const
+struct handshake_req_alloc_test_param handshake_req_alloc_params[] = {
+       {
+               .desc                   = "handshake_req_alloc NULL proto",
+               .proto                  = NULL,
+               .gfp                    = GFP_KERNEL,
+               .expect_success         = false,
+       },
+       {
+               .desc                   = "handshake_req_alloc CLASS_NONE",
+               .proto                  = &handshake_req_alloc_proto_2,
+               .gfp                    = GFP_KERNEL,
+               .expect_success         = false,
+       },
+       {
+               .desc                   = "handshake_req_alloc CLASS_MAX",
+               .proto                  = &handshake_req_alloc_proto_3,
+               .gfp                    = GFP_KERNEL,
+               .expect_success         = false,
+       },
+       {
+               .desc                   = "handshake_req_alloc no callbacks",
+               .proto                  = &handshake_req_alloc_proto_4,
+               .gfp                    = GFP_KERNEL,
+               .expect_success         = false,
+       },
+       {
+               .desc                   = "handshake_req_alloc no done callback",
+               .proto                  = &handshake_req_alloc_proto_5,
+               .gfp                    = GFP_KERNEL,
+               .expect_success         = false,
+       },
+       {
+               .desc                   = "handshake_req_alloc excessive privsize",
+               .proto                  = &handshake_req_alloc_proto_6,
+               .gfp                    = GFP_KERNEL,
+               .expect_success         = false,
+       },
+       {
+               .desc                   = "handshake_req_alloc all good",
+               .proto                  = &handshake_req_alloc_proto_good,
+               .gfp                    = GFP_KERNEL,
+               .expect_success         = true,
+       },
+};
+
+static void
+handshake_req_alloc_get_desc(const struct handshake_req_alloc_test_param *param,
+                            char *desc)
+{
+       strscpy(desc, param->desc, KUNIT_PARAM_DESC_SIZE);
+}
+
+/* Creates the function handshake_req_alloc_gen_params */
+KUNIT_ARRAY_PARAM(handshake_req_alloc, handshake_req_alloc_params,
+                 handshake_req_alloc_get_desc);
+
+static void handshake_req_alloc_case(struct kunit *test)
+{
+       const struct handshake_req_alloc_test_param *param = test->param_value;
+       struct handshake_req *result;
+
+       /* Arrange */
+
+       /* Act */
+       result = handshake_req_alloc(param->proto, param->gfp);
+
+       /* Assert */
+       if (param->expect_success)
+               KUNIT_EXPECT_NOT_NULL(test, result);
+       else
+               KUNIT_EXPECT_NULL(test, result);
+
+       kfree(result);
+}
+
+static void handshake_req_submit_test1(struct kunit *test)
+{
+       struct socket *sock;
+       int err, result;
+
+       /* Arrange */
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       /* Act */
+       result = handshake_req_submit(sock, NULL, GFP_KERNEL);
+
+       /* Assert */
+       KUNIT_EXPECT_EQ(test, result, -EINVAL);
+
+       sock_release(sock);
+}
+
+static void handshake_req_submit_test2(struct kunit *test)
+{
+       struct handshake_req *req;
+       int result;
+
+       /* Arrange */
+       req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       /* Act */
+       result = handshake_req_submit(NULL, req, GFP_KERNEL);
+
+       /* Assert */
+       KUNIT_EXPECT_EQ(test, result, -EINVAL);
+
+       /* handshake_req_submit() destroys @req on error */
+}
+
+static void handshake_req_submit_test3(struct kunit *test)
+{
+       struct handshake_req *req;
+       struct socket *sock;
+       int err, result;
+
+       /* Arrange */
+       req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+       sock->file = NULL;
+
+       /* Act */
+       result = handshake_req_submit(sock, req, GFP_KERNEL);
+
+       /* Assert */
+       KUNIT_EXPECT_EQ(test, result, -EINVAL);
+
+       /* handshake_req_submit() destroys @req on error */
+       sock_release(sock);
+}
+
+static void handshake_req_submit_test4(struct kunit *test)
+{
+       struct handshake_req *req, *result;
+       struct socket *sock;
+       int err;
+
+       /* Arrange */
+       req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+       sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
+       KUNIT_ASSERT_NOT_NULL(test, sock->sk);
+
+       err = handshake_req_submit(sock, req, GFP_KERNEL);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       /* Act */
+       result = handshake_req_hash_lookup(sock->sk);
+
+       /* Assert */
+       KUNIT_EXPECT_NOT_NULL(test, result);
+       KUNIT_EXPECT_PTR_EQ(test, req, result);
+
+       handshake_req_cancel(sock->sk);
+       sock_release(sock);
+}
+
+static void handshake_req_submit_test5(struct kunit *test)
+{
+       struct handshake_req *req;
+       struct handshake_net *hn;
+       struct socket *sock;
+       struct net *net;
+       int saved, err;
+
+       /* Arrange */
+       req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+       sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
+       KUNIT_ASSERT_NOT_NULL(test, sock->sk);
+
+       net = sock_net(sock->sk);
+       hn = handshake_pernet(net);
+       KUNIT_ASSERT_NOT_NULL(test, hn);
+
+       saved = hn->hn_pending;
+       hn->hn_pending = hn->hn_pending_max + 1;
+
+       /* Act */
+       err = handshake_req_submit(sock, req, GFP_KERNEL);
+
+       /* Assert */
+       KUNIT_EXPECT_EQ(test, err, -EAGAIN);
+
+       sock_release(sock);
+       hn->hn_pending = saved;
+}
+
+static void handshake_req_submit_test6(struct kunit *test)
+{
+       struct handshake_req *req1, *req2;
+       struct socket *sock;
+       int err;
+
+       /* Arrange */
+       req1 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req1);
+       req2 = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req2);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+       sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
+       KUNIT_ASSERT_NOT_NULL(test, sock->sk);
+
+       /* Act */
+       err = handshake_req_submit(sock, req1, GFP_KERNEL);
+       KUNIT_ASSERT_EQ(test, err, 0);
+       err = handshake_req_submit(sock, req2, GFP_KERNEL);
+
+       /* Assert */
+       KUNIT_EXPECT_EQ(test, err, -EBUSY);
+
+       handshake_req_cancel(sock->sk);
+       sock_release(sock);
+}
+
+static void handshake_req_cancel_test1(struct kunit *test)
+{
+       struct handshake_req *req;
+       struct socket *sock;
+       bool result;
+       int err;
+
+       /* Arrange */
+       req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
+
+       err = handshake_req_submit(sock, req, GFP_KERNEL);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       /* NB: handshake_req hasn't been accepted */
+
+       /* Act */
+       result = handshake_req_cancel(sock->sk);
+
+       /* Assert */
+       KUNIT_EXPECT_TRUE(test, result);
+
+       sock_release(sock);
+}
+
+static void handshake_req_cancel_test2(struct kunit *test)
+{
+       struct handshake_req *req, *next;
+       struct handshake_net *hn;
+       struct socket *sock;
+       struct net *net;
+       bool result;
+       int err;
+
+       /* Arrange */
+       req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
+
+       err = handshake_req_submit(sock, req, GFP_KERNEL);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       net = sock_net(sock->sk);
+       hn = handshake_pernet(net);
+       KUNIT_ASSERT_NOT_NULL(test, hn);
+
+       /* Pretend to accept this request */
+       next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
+       KUNIT_ASSERT_PTR_EQ(test, req, next);
+
+       /* Act */
+       result = handshake_req_cancel(sock->sk);
+
+       /* Assert */
+       KUNIT_EXPECT_TRUE(test, result);
+
+       sock_release(sock);
+}
+
+static void handshake_req_cancel_test3(struct kunit *test)
+{
+       struct handshake_req *req, *next;
+       struct handshake_net *hn;
+       struct socket *sock;
+       struct net *net;
+       bool result;
+       int err;
+
+       /* Arrange */
+       req = handshake_req_alloc(&handshake_req_alloc_proto_good, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
+
+       err = handshake_req_submit(sock, req, GFP_KERNEL);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       net = sock_net(sock->sk);
+       hn = handshake_pernet(net);
+       KUNIT_ASSERT_NOT_NULL(test, hn);
+
+       /* Pretend to accept this request */
+       next = handshake_req_next(hn, HANDSHAKE_HANDLER_CLASS_TLSHD);
+       KUNIT_ASSERT_PTR_EQ(test, req, next);
+
+       /* Pretend to complete this request */
+       handshake_complete(next, -ETIMEDOUT, NULL);
+
+       /* Act */
+       result = handshake_req_cancel(sock->sk);
+
+       /* Assert */
+       KUNIT_EXPECT_FALSE(test, result);
+
+       sock_release(sock);
+}
+
+static struct handshake_req *handshake_req_destroy_test;
+
+static void test_destroy_func(struct handshake_req *req)
+{
+       handshake_req_destroy_test = req;
+}
+
+static struct handshake_proto handshake_req_alloc_proto_destroy = {
+       .hp_handler_class       = HANDSHAKE_HANDLER_CLASS_TLSHD,
+       .hp_accept              = test_accept_func,
+       .hp_done                = test_done_func,
+       .hp_destroy             = test_destroy_func,
+};
+
+static void handshake_req_destroy_test1(struct kunit *test)
+{
+       struct handshake_req *req;
+       struct socket *sock;
+       int err;
+
+       /* Arrange */
+       handshake_req_destroy_test = NULL;
+
+       req = handshake_req_alloc(&handshake_req_alloc_proto_destroy, GFP_KERNEL);
+       KUNIT_ASSERT_NOT_NULL(test, req);
+
+       err = __sock_create(&init_net, PF_INET, SOCK_STREAM, IPPROTO_TCP,
+                           &sock, 1);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       sock->file = sock_alloc_file(sock, O_NONBLOCK, NULL);
+       KUNIT_ASSERT_NOT_ERR_OR_NULL(test, sock->file);
+
+       err = handshake_req_submit(sock, req, GFP_KERNEL);
+       KUNIT_ASSERT_EQ(test, err, 0);
+
+       handshake_req_cancel(sock->sk);
+
+       /* Act */
+       sock_release(sock);
+
+       /* Assert */
+       KUNIT_EXPECT_PTR_EQ(test, handshake_req_destroy_test, req);
+}
+
+static struct kunit_case handshake_api_test_cases[] = {
+       {
+               .name                   = "req_alloc API fuzzing",
+               .run_case               = handshake_req_alloc_case,
+               .generate_params        = handshake_req_alloc_gen_params,
+       },
+       {
+               .name                   = "req_submit NULL req arg",
+               .run_case               = handshake_req_submit_test1,
+       },
+       {
+               .name                   = "req_submit NULL sock arg",
+               .run_case               = handshake_req_submit_test2,
+       },
+       {
+               .name                   = "req_submit NULL sock->file",
+               .run_case               = handshake_req_submit_test3,
+       },
+       {
+               .name                   = "req_lookup works",
+               .run_case               = handshake_req_submit_test4,
+       },
+       {
+               .name                   = "req_submit max pending",
+               .run_case               = handshake_req_submit_test5,
+       },
+       {
+               .name                   = "req_submit multiple",
+               .run_case               = handshake_req_submit_test6,
+       },
+       {
+               .name                   = "req_cancel before accept",
+               .run_case               = handshake_req_cancel_test1,
+       },
+       {
+               .name                   = "req_cancel after accept",
+               .run_case               = handshake_req_cancel_test2,
+       },
+       {
+               .name                   = "req_cancel after done",
+               .run_case               = handshake_req_cancel_test3,
+       },
+       {
+               .name                   = "req_destroy works",
+               .run_case               = handshake_req_destroy_test1,
+       },
+       {}
+};
+
+static struct kunit_suite handshake_api_suite = {
+       .name                   = "Handshake API tests",
+       .test_cases             = handshake_api_test_cases,
+};
+
+kunit_test_suites(&handshake_api_suite);
+
+MODULE_DESCRIPTION("Test handshake upcall API functions");
+MODULE_LICENSE("GPL");