#include "../kselftest.h"
 #include "vm_util.h"
 
+#ifndef MADV_COLLAPSE
+#define MADV_COLLAPSE 25
+#endif
+
 static size_t pagesize;
 static int pagemap_fd;
 static size_t thpsize;
        return tests;
 }
 
+enum anon_thp_collapse_test {
+       ANON_THP_COLLAPSE_UNSHARED,
+       ANON_THP_COLLAPSE_FULLY_SHARED,
+       ANON_THP_COLLAPSE_LOWER_SHARED,
+       ANON_THP_COLLAPSE_UPPER_SHARED,
+};
+
+static void do_test_anon_thp_collapse(char *mem, size_t size,
+                                     enum anon_thp_collapse_test test)
+{
+       struct comm_pipes comm_pipes;
+       char buf;
+       int ret;
+
+       ret = setup_comm_pipes(&comm_pipes);
+       if (ret) {
+               ksft_test_result_fail("pipe() failed\n");
+               return;
+       }
+
+       /*
+        * Trigger PTE-mapping the THP by temporarily mapping a single subpage
+        * R/O, such that we can try collapsing it later.
+        */
+       ret = mprotect(mem + pagesize, pagesize, PROT_READ);
+       if (ret) {
+               ksft_test_result_fail("mprotect() failed\n");
+               goto close_comm_pipes;
+       }
+       ret = mprotect(mem + pagesize, pagesize, PROT_READ | PROT_WRITE);
+       if (ret) {
+               ksft_test_result_fail("mprotect() failed\n");
+               goto close_comm_pipes;
+       }
+
+       switch (test) {
+       case ANON_THP_COLLAPSE_UNSHARED:
+               /* Collapse before actually COW-sharing the page. */
+               ret = madvise(mem, size, MADV_COLLAPSE);
+               if (ret) {
+                       ksft_test_result_skip("MADV_COLLAPSE failed: %s\n",
+                                             strerror(errno));
+                       goto close_comm_pipes;
+               }
+               break;
+       case ANON_THP_COLLAPSE_FULLY_SHARED:
+               /* COW-share the full PTE-mapped THP. */
+               break;
+       case ANON_THP_COLLAPSE_LOWER_SHARED:
+               /* Don't COW-share the upper part of the THP. */
+               ret = madvise(mem + size / 2, size / 2, MADV_DONTFORK);
+               if (ret) {
+                       ksft_test_result_fail("MADV_DONTFORK failed\n");
+                       goto close_comm_pipes;
+               }
+               break;
+       case ANON_THP_COLLAPSE_UPPER_SHARED:
+               /* Don't COW-share the lower part of the THP. */
+               ret = madvise(mem, size / 2, MADV_DONTFORK);
+               if (ret) {
+                       ksft_test_result_fail("MADV_DONTFORK failed\n");
+                       goto close_comm_pipes;
+               }
+               break;
+       default:
+               assert(false);
+       }
+
+       ret = fork();
+       if (ret < 0) {
+               ksft_test_result_fail("fork() failed\n");
+               goto close_comm_pipes;
+       } else if (!ret) {
+               switch (test) {
+               case ANON_THP_COLLAPSE_UNSHARED:
+               case ANON_THP_COLLAPSE_FULLY_SHARED:
+                       exit(child_memcmp_fn(mem, size, &comm_pipes));
+                       break;
+               case ANON_THP_COLLAPSE_LOWER_SHARED:
+                       exit(child_memcmp_fn(mem, size / 2, &comm_pipes));
+                       break;
+               case ANON_THP_COLLAPSE_UPPER_SHARED:
+                       exit(child_memcmp_fn(mem + size / 2, size / 2,
+                                            &comm_pipes));
+                       break;
+               default:
+                       assert(false);
+               }
+       }
+
+       while (read(comm_pipes.child_ready[0], &buf, 1) != 1)
+               ;
+
+       switch (test) {
+       case ANON_THP_COLLAPSE_UNSHARED:
+               break;
+       case ANON_THP_COLLAPSE_UPPER_SHARED:
+       case ANON_THP_COLLAPSE_LOWER_SHARED:
+               /*
+                * Revert MADV_DONTFORK such that we merge the VMAs and are
+                * able to actually collapse.
+                */
+               ret = madvise(mem, size, MADV_DOFORK);
+               if (ret) {
+                       ksft_test_result_fail("MADV_DOFORK failed\n");
+                       write(comm_pipes.parent_ready[1], "0", 1);
+                       wait(&ret);
+                       goto close_comm_pipes;
+               }
+               /* FALLTHROUGH */
+       case ANON_THP_COLLAPSE_FULLY_SHARED:
+               /* Collapse before anyone modified the COW-shared page. */
+               ret = madvise(mem, size, MADV_COLLAPSE);
+               if (ret) {
+                       ksft_test_result_skip("MADV_COLLAPSE failed: %s\n",
+                                             strerror(errno));
+                       write(comm_pipes.parent_ready[1], "0", 1);
+                       wait(&ret);
+                       goto close_comm_pipes;
+               }
+               break;
+       default:
+               assert(false);
+       }
+
+       /* Modify the page. */
+       memset(mem, 0xff, size);
+       write(comm_pipes.parent_ready[1], "0", 1);
+
+       wait(&ret);
+       if (WIFEXITED(ret))
+               ret = WEXITSTATUS(ret);
+       else
+               ret = -EINVAL;
+
+       ksft_test_result(!ret, "No leak from parent into child\n");
+close_comm_pipes:
+       close_comm_pipes(&comm_pipes);
+}
+
+static void test_anon_thp_collapse_unshared(char *mem, size_t size)
+{
+       do_test_anon_thp_collapse(mem, size, ANON_THP_COLLAPSE_UNSHARED);
+}
+
+static void test_anon_thp_collapse_fully_shared(char *mem, size_t size)
+{
+       do_test_anon_thp_collapse(mem, size, ANON_THP_COLLAPSE_FULLY_SHARED);
+}
+
+static void test_anon_thp_collapse_lower_shared(char *mem, size_t size)
+{
+       do_test_anon_thp_collapse(mem, size, ANON_THP_COLLAPSE_LOWER_SHARED);
+}
+
+static void test_anon_thp_collapse_upper_shared(char *mem, size_t size)
+{
+       do_test_anon_thp_collapse(mem, size, ANON_THP_COLLAPSE_UPPER_SHARED);
+}
+
+/*
+ * Test cases that are specific to anonymous THP: pages in private mappings
+ * that may get shared via COW during fork().
+ */
+static const struct test_case anon_thp_test_cases[] = {
+       /*
+        * Basic COW test for fork() without any GUP when collapsing a THP
+        * before fork().
+        *
+        * Re-mapping a PTE-mapped anon THP using a single PMD ("in-place
+        * collapse") might easily get COW handling wrong when not collapsing
+        * exclusivity information properly.
+        */
+       {
+               "Basic COW after fork() when collapsing before fork()",
+               test_anon_thp_collapse_unshared,
+       },
+       /* Basic COW test, but collapse after COW-sharing a full THP. */
+       {
+               "Basic COW after fork() when collapsing after fork() (fully shared)",
+               test_anon_thp_collapse_fully_shared,
+       },
+       /*
+        * Basic COW test, but collapse after COW-sharing the lower half of a
+        * THP.
+        */
+       {
+               "Basic COW after fork() when collapsing after fork() (lower shared)",
+               test_anon_thp_collapse_lower_shared,
+       },
+       /*
+        * Basic COW test, but collapse after COW-sharing the upper half of a
+        * THP.
+        */
+       {
+               "Basic COW after fork() when collapsing after fork() (upper shared)",
+               test_anon_thp_collapse_upper_shared,
+       },
+};
+
+static void run_anon_thp_test_cases(void)
+{
+       int i;
+
+       if (!thpsize)
+               return;
+
+       ksft_print_msg("[INFO] Anonymous THP tests\n");
+
+       for (i = 0; i < ARRAY_SIZE(anon_thp_test_cases); i++) {
+               struct test_case const *test_case = &anon_thp_test_cases[i];
+
+               ksft_print_msg("[RUN] %s\n", test_case->desc);
+               do_run_with_thp(test_case->fn, THP_RUN_PMD);
+       }
+}
+
+static int tests_per_anon_thp_test_case(void)
+{
+       return thpsize ? 1 : 0;
+}
+
 typedef void (*non_anon_test_fn)(char *mem, const char *smem, size_t size);
 
 static void test_cow(char *mem, const char *smem, size_t size)
 
        ksft_print_header();
        ksft_set_plan(ARRAY_SIZE(anon_test_cases) * tests_per_anon_test_case() +
+                     ARRAY_SIZE(anon_thp_test_cases) * tests_per_anon_thp_test_case() +
                      ARRAY_SIZE(non_anon_test_cases) * tests_per_non_anon_test_case());
 
        gup_fd = open("/sys/kernel/debug/gup_test", O_RDWR);
                ksft_exit_fail_msg("opening pagemap failed\n");
 
        run_anon_test_cases();
+       run_anon_thp_test_cases();
        run_non_anon_test_cases();
 
        err = ksft_get_fail_cnt();