#include <sys/random.h>
 
 #include "../kselftest.h"
+#include "vm_util.h"
 
 #ifdef __NR_userfaultfd
 
-static unsigned long nr_cpus, nr_pages, nr_pages_per_cpu, page_size;
+static unsigned long nr_cpus, nr_pages, nr_pages_per_cpu, page_size, hpage_size;
 
 #define BOUNCE_RANDOM          (1<<0)
 #define BOUNCE_RACINGFAULTS    (1<<1)
 
 #define UFFD_FLAGS     (O_CLOEXEC | O_NONBLOCK | UFFD_USER_MODE_ONLY)
 
+#define BASE_PMD_ADDR ((void *)(1UL << 30))
+
 /* test using /dev/userfaultfd, instead of userfaultfd(2) */
 static bool test_dev_userfaultfd;
 
 static unsigned long long *count_verify;
 static int uffd = -1;
 static int uffd_flags, finished, *pipefd;
-static char *area_src, *area_src_alias, *area_dst, *area_dst_alias;
+static char *area_src, *area_src_alias, *area_dst, *area_dst_alias, *area_remap;
 static char *zeropage;
 pthread_attr_t attr;
+static bool test_collapse;
 
 /* Userfaultfd test statistics */
 struct uffd_stats {
 #define swap(a, b) \
        do { typeof(a) __tmp = (a); (a) = (b); (b) = __tmp; } while (0)
 
+#define factor_of_2(x) ((x) ^ ((x) & ((x) - 1)))
+
 const char *examples =
     "# Run anonymous memory test on 100MiB region with 99999 bounces:\n"
     "./userfaultfd anon 100 99999\n\n"
                "Supported mods:\n");
        fprintf(stderr, "\tsyscall - Use userfaultfd(2) (default)\n");
        fprintf(stderr, "\tdev - Use /dev/userfaultfd instead of userfaultfd(2)\n");
+       fprintf(stderr, "\tcollapse - Test MADV_COLLAPSE of UFFDIO_REGISTER_MODE_MINOR\n"
+               "memory\n");
        fprintf(stderr, "\nExample test mod usage:\n");
        fprintf(stderr, "# Run anonymous memory test with /dev/userfaultfd:\n");
        fprintf(stderr, "./userfaultfd anon:dev 100 99999\n\n");
                err("madvise(MADV_DONTNEED) failed");
 }
 
-static void anon_allocate_area(void **alloc_area)
+/*
+ * Map nr_pages * page_size bytes of private anonymous memory and store the
+ * address in *alloc_area.  is_src is not used for anonymous memory.
+ * Calls err() if mmap() fails, as the hugetlb and shmem allocators do.
+ */
+static void anon_allocate_area(void **alloc_area, bool is_src)
 {
        *alloc_area = mmap(NULL, nr_pages * page_size, PROT_READ | PROT_WRITE,
                           MAP_ANONYMOUS | MAP_PRIVATE, -1, 0);
        if (*alloc_area == MAP_FAILED)
                err("mmap of anonymous memory failed");
 }
 
 static void noop_alias_mapping(__u64 *start, size_t len, unsigned long offset)
        }
 }
 
-static void hugetlb_allocate_area(void **alloc_area)
+static void hugetlb_allocate_area(void **alloc_area, bool is_src)
 {
        void *area_alias = NULL;
        char **alloc_area_alias;
                        nr_pages * page_size,
                        PROT_READ | PROT_WRITE,
                        MAP_PRIVATE | MAP_ANONYMOUS | MAP_HUGETLB |
-                               (*alloc_area == area_src ? 0 : MAP_NORESERVE),
+                               (is_src ? 0 : MAP_NORESERVE),
                        -1,
                        0);
        else
                        nr_pages * page_size,
                        PROT_READ | PROT_WRITE,
                        MAP_SHARED |
-                               (*alloc_area == area_src ? 0 : MAP_NORESERVE),
+                               (is_src ? 0 : MAP_NORESERVE),
                        huge_fd,
-                       *alloc_area == area_src ? 0 : nr_pages * page_size);
+                       is_src ? 0 : nr_pages * page_size);
        if (*alloc_area == MAP_FAILED)
                err("mmap of hugetlbfs file failed");
 
                        PROT_READ | PROT_WRITE,
                        MAP_SHARED,
                        huge_fd,
-                       *alloc_area == area_src ? 0 : nr_pages * page_size);
+                       is_src ? 0 : nr_pages * page_size);
                if (area_alias == MAP_FAILED)
                        err("mmap of hugetlb file alias failed");
        }
 
-       if (*alloc_area == area_src) {
+       if (is_src) {
                alloc_area_alias = &area_src_alias;
        } else {
                alloc_area_alias = &area_dst_alias;
                err("madvise(MADV_REMOVE) failed");
 }
 
-static void shmem_allocate_area(void **alloc_area)
+static void shmem_allocate_area(void **alloc_area, bool is_src)
 {
        void *area_alias = NULL;
-       bool is_src = alloc_area == (void **)&area_src;
-       unsigned long offset = is_src ? 0 : nr_pages * page_size;
+       size_t bytes = nr_pages * page_size;
+       unsigned long offset = is_src ? 0 : bytes;
+       char *p = NULL, *p_alias = NULL;
+
+       if (test_collapse) {
+               p = BASE_PMD_ADDR;
+               if (!is_src)
+                       /* src map + alias + interleaved hpages */
+                       p += 2 * (bytes + hpage_size);
+               p_alias = p;
+               p_alias += bytes;
+               p_alias += hpage_size;  /* Prevent src/dst VMA merge */
+       }
 
-       *alloc_area = mmap(NULL, nr_pages * page_size, PROT_READ | PROT_WRITE,
-                          MAP_SHARED, shm_fd, offset);
+       *alloc_area = mmap(p, bytes, PROT_READ | PROT_WRITE, MAP_SHARED,
+                          shm_fd, offset);
        if (*alloc_area == MAP_FAILED)
                err("mmap of memfd failed");
+       if (test_collapse && *alloc_area != p)
+               err("mmap of memfd failed at %p", p);
 
-       area_alias = mmap(NULL, nr_pages * page_size, PROT_READ | PROT_WRITE,
-                         MAP_SHARED, shm_fd, offset);
+       area_alias = mmap(p_alias, bytes, PROT_READ | PROT_WRITE, MAP_SHARED,
+                         shm_fd, offset);
        if (area_alias == MAP_FAILED)
                err("mmap of memfd alias failed");
+       if (test_collapse && area_alias != p_alias)
+               err("mmap of anonymous memory failed at %p", p_alias);
 
        if (is_src)
                area_src_alias = area_alias;
        *start = (unsigned long)area_dst_alias + offset;
 }
 
+/*
+ * check_pmd_mapping hook for shmem: calls err() unless check_huge_shmem()
+ * succeeds for area_dst_alias with expect_nr_hpages and hpage_size.
+ * NOTE(review): p is not consulted; the check always targets
+ * area_dst_alias - confirm this is intended.
+ */
+static void shmem_check_pmd_mapping(void *p, int expect_nr_hpages)
+{
+       if (check_huge_shmem(area_dst_alias, expect_nr_hpages, hpage_size))
+               return;
+
+       err("Did not find expected %d number of hugepages", expect_nr_hpages);
+}
+
 struct uffd_test_ops {
-       void (*allocate_area)(void **alloc_area);
+       void (*allocate_area)(void **alloc_area, bool is_src); /* is_src: true for area_src, false for area_dst */
        void (*release_pages)(char *rel_area);
        void (*alias_mapping)(__u64 *start, size_t len, unsigned long offset);
+       void (*check_pmd_mapping)(void *p, int expect_nr_hpages); /* NULL when the memory type has no such check */
 };
 
 static struct uffd_test_ops anon_uffd_test_ops = {
        .allocate_area  = anon_allocate_area,
        .release_pages  = anon_release_pages,
        .alias_mapping = noop_alias_mapping,
+       .check_pmd_mapping = NULL, /* no PMD-mapping check for anonymous memory */
 };
 
 static struct uffd_test_ops shmem_uffd_test_ops = {
        .allocate_area  = shmem_allocate_area,
        .release_pages  = shmem_release_pages,
        .alias_mapping = shmem_alias_mapping,
+       .check_pmd_mapping = shmem_check_pmd_mapping, /* called by the minor test after MADV_COLLAPSE when the collapse mod is set */
 };
 
 static struct uffd_test_ops hugetlb_uffd_test_ops = {
        .allocate_area  = hugetlb_allocate_area,
        .release_pages  = hugetlb_release_pages,
        .alias_mapping = hugetlb_alias_mapping,
+       .check_pmd_mapping = NULL, /* no PMD-mapping check for hugetlb memory */
 };
 
 static struct uffd_test_ops *uffd_test_ops;
        munmap_area((void **)&area_src_alias);
        munmap_area((void **)&area_dst);
        munmap_area((void **)&area_dst_alias);
+       munmap_area((void **)&area_remap);
 }
 
 static void uffd_test_ctx_init(uint64_t features)
 
        uffd_test_ctx_clear();
 
-       uffd_test_ops->allocate_area((void **)&area_src);
-       uffd_test_ops->allocate_area((void **)&area_dst);
+       uffd_test_ops->allocate_area((void **)&area_src, true);
+       uffd_test_ops->allocate_area((void **)&area_dst, false);
 
        userfaultfd_open(&features);
 
                                err("remove failure");
                        break;
                case UFFD_EVENT_REMAP:
+                       area_remap = area_dst;  /* save for later unmap */
                        area_dst = (char *)(unsigned long)msg.arg.remap.to;
                        break;
                }
        return userfaults != 0;
 }
 
+/*
+ * Walk the nr_pages pages starting at p and compare each one against a
+ * reference page filled with the byte ~(index % 255), where index is the
+ * page's position in the area.  err() is called when my_bcmp() returns
+ * nonzero for a page, or when the reference page cannot be allocated.
+ */
+void check_memory_contents(char *p)
+{
+       void *want;
+       unsigned long idx;
+
+       if (posix_memalign(&want, page_size, page_size) != 0)
+               err("out of memory");
+
+       for (idx = 0; idx < nr_pages; idx++) {
+               char *page = p + (idx * page_size);
+               uint8_t fill = ~((uint8_t)(idx % ((uint8_t)-1)));
+
+               /* Rebuild the reference page for this index, then compare. */
+               memset(want, fill, page_size);
+               if (my_bcmp(want, page, page_size))
+                       err("unexpected page contents after minor fault");
+       }
+
+       free(want);
+}
+
 static int userfaultfd_minor_test(void)
 {
-       struct uffdio_register uffdio_register;
        unsigned long p;
+       struct uffdio_register uffdio_register;
        pthread_t uffd_mon;
-       uint8_t expected_byte;
-       void *expected_page;
        char c;
        struct uffd_stats stats = { 0 };
 
         * fault. uffd_poll_thread will resolve the fault by bit-flipping the
         * page's contents, and then issuing a CONTINUE ioctl.
         */
-
-       if (posix_memalign(&expected_page, page_size, page_size))
-               err("out of memory");
-
-       for (p = 0; p < nr_pages; ++p) {
-               expected_byte = ~((uint8_t)(p % ((uint8_t)-1)));
-               memset(expected_page, expected_byte, page_size);
-               if (my_bcmp(expected_page, area_dst_alias + (p * page_size),
-                           page_size))
-                       err("unexpected page contents after minor fault");
-       }
+       check_memory_contents(area_dst_alias);
 
        if (write(pipefd[1], &c, sizeof(c)) != sizeof(c))
                err("pipe write");
 
        uffd_stats_report(&stats, 1);
 
+       if (test_collapse) {
+               printf("testing collapse of uffd memory into PMD-mapped THPs:");
+               if (madvise(area_dst_alias, nr_pages * page_size,
+                           MADV_COLLAPSE))
+                       err("madvise(MADV_COLLAPSE)");
+
+               uffd_test_ops->check_pmd_mapping(area_dst,
+                                                nr_pages * page_size /
+                                                hpage_size);
+               /*
+                * This won't cause uffd-fault - it purely just makes sure there
+                * was no corruption.
+                */
+               check_memory_contents(area_dst_alias);
+               printf(" done.\n");
+       }
+
        return stats.missing_faults != 0 || stats.minor_faults != nr_pages;
 }
 
                        test_dev_userfaultfd = true;
                else if (!strcmp(token, "syscall"))
                        test_dev_userfaultfd = false;
+               else if (!strcmp(token, "collapse"))
+                       test_collapse = true;
                else
                        err("unrecognized test mod '%s'", token);
        }
        if (!test_type)
                err("failed to parse test type argument: '%s'", raw_type);
 
+       if (test_collapse && test_type != TEST_SHMEM)
+               err("Unsupported test: %s", raw_type);
+
        if (test_type == TEST_HUGETLB)
-               page_size = default_huge_page_size();
+               page_size = hpage_size;
        else
                page_size = sysconf(_SC_PAGE_SIZE);
 
 
 int main(int argc, char **argv)
 {
+       size_t bytes;
+
        if (argc < 4)
                usage();
 
                err("failed to arm SIGALRM");
        alarm(ALARM_INTERVAL_SECS);
 
+       hpage_size = default_huge_page_size();
        parse_test_type_arg(argv[1]);
+       bytes = atol(argv[2]) * 1024 * 1024;
+
+       if (test_collapse && bytes & (hpage_size - 1))
+               err("MiB must be multiple of %lu if :collapse mod set",
+                   hpage_size >> 20);
 
        nr_cpus = sysconf(_SC_NPROCESSORS_ONLN);
-       nr_pages_per_cpu = atol(argv[2]) * 1024*1024 / page_size /
-               nr_cpus;
+
+       if (test_collapse) {
+               /* nr_cpus must divide (bytes / page_size), otherwise,
+                * area allocations of (nr_pages * page_size) won't be a
+                * multiple of hpage_size, even if bytes is a multiple of
+                * hpage_size.
+                *
+                * This means that nr_cpus must divide (N * (1 << (H-P)))
+                * where:
+                *      bytes = hpage_size * N
+                *      hpage_size = 1 << H
+                *      page_size = 1 << P
+                *
+                * And we want to choose nr_cpus to be the largest value
+                * satisfying this constraint, not larger than the number
+                * of online CPUs. Unfortunately, prime factorization of
+                * N and nr_cpus may be arbitrary, so we would have to search
+                * for it. Instead, just use the highest power of 2 dividing
+                * both nr_cpus and (bytes / page_size).
+                */
+               int x = factor_of_2(nr_cpus);
+               int y = factor_of_2(bytes / page_size);
+
+               nr_cpus = x < y ? x : y;
+       }
+       nr_pages_per_cpu = bytes / page_size / nr_cpus;
        if (!nr_pages_per_cpu) {
                _err("invalid MiB");
                usage();