#define BOUNCE_POLL            (1<<3)
 static int bounces;
 
-#ifdef HUGETLB_TEST
+#define TEST_ANON      1
+#define TEST_HUGETLB   2
+#define TEST_SHMEM     3
+static int test_type;
+
 static int huge_fd;
 static char *huge_fd_off0;
-#endif
 static unsigned long long *count_verify;
 static int uffd, uffd_flags, finished, *pipefd;
 static char *area_src, *area_dst;
                                 ~(unsigned long)(sizeof(unsigned long long) \
                                                  -  1)))
 
-#if !defined(HUGETLB_TEST) && !defined(SHMEM_TEST)
-
-/* Anonymous memory */
-#define EXPECTED_IOCTLS                ((1 << _UFFDIO_WAKE) | \
-                                (1 << _UFFDIO_COPY) | \
-                                (1 << _UFFDIO_ZEROPAGE))
-
-static int release_pages(char *rel_area)
+static int anon_release_pages(char *rel_area)
 {
        int ret = 0;
 
        return ret;
 }
 
-static void allocate_area(void **alloc_area)
+static void anon_allocate_area(void **alloc_area)
 {
        if (posix_memalign(alloc_area, page_size, nr_pages * page_size)) {
                fprintf(stderr, "out of memory\n");
        }
 }
 
-#else /* HUGETLB_TEST or SHMEM_TEST */
-
-#define EXPECTED_IOCTLS                UFFD_API_RANGE_IOCTLS_BASIC
-
-#ifdef HUGETLB_TEST
 
 /* HugeTLB memory */
-static int release_pages(char *rel_area)
+static int hugetlb_release_pages(char *rel_area)
 {
        int ret = 0;
 
 }
 
 
-static void allocate_area(void **alloc_area)
+static void hugetlb_allocate_area(void **alloc_area)
 {
        *alloc_area = mmap(NULL, nr_pages * page_size, PROT_READ | PROT_WRITE,
                                MAP_PRIVATE | MAP_HUGETLB, huge_fd,
                huge_fd_off0 = *alloc_area;
 }
 
-#elif defined(SHMEM_TEST)
-
 /* Shared memory */
-static int release_pages(char *rel_area)
+static int shmem_release_pages(char *rel_area)
 {
        int ret = 0;
 
        return ret;
 }
 
-static void allocate_area(void **alloc_area)
+static void shmem_allocate_area(void **alloc_area)
 {
        *alloc_area = mmap(NULL, nr_pages * page_size, PROT_READ | PROT_WRITE,
                           MAP_ANONYMOUS | MAP_SHARED, -1, 0);
        }
 }
 
-#else /* SHMEM_TEST */
-#error "Undefined test type"
-#endif /* HUGETLB_TEST */
-
-#endif /* !defined(HUGETLB_TEST) && !defined(SHMEM_TEST) */
+struct uffd_test_ops {
+       unsigned long expected_ioctls;
+       void (*allocate_area)(void **alloc_area);
+       int (*release_pages)(char *rel_area);
+};
+
+#define ANON_EXPECTED_IOCTLS           ((1 << _UFFDIO_WAKE) | \
+                                        (1 << _UFFDIO_COPY) | \
+                                        (1 << _UFFDIO_ZEROPAGE))
+
+static struct uffd_test_ops anon_uffd_test_ops = {
+       .expected_ioctls = ANON_EXPECTED_IOCTLS,
+       .allocate_area  = anon_allocate_area,
+       .release_pages  = anon_release_pages,
+};
+
+static struct uffd_test_ops shmem_uffd_test_ops = {
+       .expected_ioctls = UFFD_API_RANGE_IOCTLS_BASIC,
+       .allocate_area  = shmem_allocate_area,
+       .release_pages  = shmem_release_pages,
+};
+
+static struct uffd_test_ops hugetlb_uffd_test_ops = {
+       .expected_ioctls = UFFD_API_RANGE_IOCTLS_BASIC,
+       .allocate_area  = hugetlb_allocate_area,
+       .release_pages  = hugetlb_release_pages,
+};
+
+static struct uffd_test_ops *uffd_test_ops;
 
 static int my_bcmp(char *str1, char *str2, size_t n)
 {
         * UFFDIO_COPY without writing zero pages into area_dst
         * because the background threads already completed).
         */
-       if (release_pages(area_src))
+       if (uffd_test_ops->release_pages(area_src))
                return 1;
 
        for (cpu = 0; cpu < nr_cpus; cpu++) {
 {
        unsigned long nr;
        unsigned long long count;
+       unsigned long split_nr_pages;
 
-#ifndef HUGETLB_TEST
-       unsigned long split_nr_pages = (nr_pages + 1) / 2;
-#else
-       unsigned long split_nr_pages = nr_pages;
-#endif
+       if (test_type != TEST_HUGETLB)
+               split_nr_pages = (nr_pages + 1) / 2;
+       else
+               split_nr_pages = nr_pages;
 
        for (nr = 0; nr < split_nr_pages; nr++) {
                count = *area_count(area_dst, nr);
                }
        }
 
-#ifndef HUGETLB_TEST
+       if (test_type == TEST_HUGETLB)
+               return 0;
+
        area_dst = mremap(area_dst, nr_pages * page_size,  nr_pages * page_size,
                          MREMAP_MAYMOVE | MREMAP_FIXED, area_src);
        if (area_dst == MAP_FAILED)
                }
        }
 
-       if (release_pages(area_dst))
+       if (uffd_test_ops->release_pages(area_dst))
                return 1;
 
        for (nr = 0; nr < nr_pages; nr++) {
                        fprintf(stderr, "nr %lu is not zero\n", nr), exit(1);
        }
 
-#endif /* HUGETLB_TEST */
-
        return 0;
 }
 
 {
        struct uffdio_zeropage uffdio_zeropage;
        int ret;
-       unsigned long has_zeropage = EXPECTED_IOCTLS & (1 << _UFFDIO_ZEROPAGE);
+       unsigned long has_zeropage;
+
+       has_zeropage = uffd_test_ops->expected_ioctls & (1 << _UFFDIO_ZEROPAGE);
 
        if (offset >= nr_pages * page_size)
                fprintf(stderr, "unexpected offset %lu\n",
        printf("testing UFFDIO_ZEROPAGE: ");
        fflush(stdout);
 
-       if (release_pages(area_dst))
+       if (uffd_test_ops->release_pages(area_dst))
                return 1;
 
        if (userfaultfd_open(0) < 0)
        if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
                fprintf(stderr, "register failure\n"), exit(1);
 
-       expected_ioctls = EXPECTED_IOCTLS;
+       expected_ioctls = uffd_test_ops->expected_ioctls;
        if ((uffdio_register.ioctls & expected_ioctls) !=
            expected_ioctls)
                fprintf(stderr,
        printf("testing events (fork, remap, remove): ");
        fflush(stdout);
 
-       if (release_pages(area_dst))
+       if (uffd_test_ops->release_pages(area_dst))
                return 1;
 
        features = UFFD_FEATURE_EVENT_FORK | UFFD_FEATURE_EVENT_REMAP |
        if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
                fprintf(stderr, "register failure\n"), exit(1);
 
-       expected_ioctls = EXPECTED_IOCTLS;
+       expected_ioctls = uffd_test_ops->expected_ioctls;
        if ((uffdio_register.ioctls & expected_ioctls) !=
            expected_ioctls)
                fprintf(stderr,
        int err;
        unsigned long userfaults[nr_cpus];
 
-       allocate_area((void **)&area_src);
+       uffd_test_ops->allocate_area((void **)&area_src);
        if (!area_src)
                return 1;
-       allocate_area((void **)&area_dst);
+       uffd_test_ops->allocate_area((void **)&area_dst);
        if (!area_dst)
                return 1;
 
                        fprintf(stderr, "register failure\n");
                        return 1;
                }
-               expected_ioctls = EXPECTED_IOCTLS;
+               expected_ioctls = uffd_test_ops->expected_ioctls;
                if ((uffdio_register.ioctls & expected_ioctls) !=
                    expected_ioctls) {
                        fprintf(stderr,
                 * MADV_DONTNEED only after the UFFDIO_REGISTER, so it's
                 * required to MADV_DONTNEED here.
                 */
-               if (release_pages(area_dst))
+               if (uffd_test_ops->release_pages(area_dst))
                        return 1;
 
                /* bounce pass */
        return userfaultfd_zeropage_test() || userfaultfd_events_test();
 }
 
-#ifndef HUGETLB_TEST
-
-int main(int argc, char **argv)
-{
-       if (argc < 3)
-               fprintf(stderr, "Usage: <MiB> <bounces>\n"), exit(1);
-       nr_cpus = sysconf(_SC_NPROCESSORS_ONLN);
-       page_size = sysconf(_SC_PAGE_SIZE);
-       if ((unsigned long) area_count(NULL, 0) + sizeof(unsigned long long) * 2
-           > page_size)
-               fprintf(stderr, "Impossible to run this test\n"), exit(2);
-       nr_pages_per_cpu = atol(argv[1]) * 1024*1024 / page_size /
-               nr_cpus;
-       if (!nr_pages_per_cpu) {
-               fprintf(stderr, "invalid MiB\n");
-               fprintf(stderr, "Usage: <MiB> <bounces>\n"), exit(1);
-       }
-       bounces = atoi(argv[2]);
-       if (bounces <= 0) {
-               fprintf(stderr, "invalid bounces\n");
-               fprintf(stderr, "Usage: <MiB> <bounces>\n"), exit(1);
-       }
-       nr_pages = nr_pages_per_cpu * nr_cpus;
-       printf("nr_pages: %lu, nr_pages_per_cpu: %lu\n",
-              nr_pages, nr_pages_per_cpu);
-       return userfaultfd_stress();
-}
-
-#else /* HUGETLB_TEST */
-
 /*
  * Copied from mlock2-tests.c
  */
        return hps;
 }
 
-int main(int argc, char **argv)
+static void set_test_type(const char *type)
 {
-       if (argc < 4)
-               fprintf(stderr, "Usage: <MiB> <bounces> <hugetlbfs_file>\n"),
-                               exit(1);
-       nr_cpus = sysconf(_SC_NPROCESSORS_ONLN);
-       page_size = default_huge_page_size();
+       if (!strcmp(type, "anon")) {
+               test_type = TEST_ANON;
+               uffd_test_ops = &anon_uffd_test_ops;
+       } else if (!strcmp(type, "hugetlb")) {
+               test_type = TEST_HUGETLB;
+               uffd_test_ops = &hugetlb_uffd_test_ops;
+       } else if (!strcmp(type, "shmem")) {
+               test_type = TEST_SHMEM;
+               uffd_test_ops = &shmem_uffd_test_ops;
+       } else {
+               fprintf(stderr, "Unknown test type: %s\n", type), exit(1);
+       }
+
+       if (test_type == TEST_HUGETLB)
+               page_size = default_huge_page_size();
+       else
+               page_size = sysconf(_SC_PAGE_SIZE);
+
        if (!page_size)
-               fprintf(stderr, "Unable to determine huge page size\n"),
+               fprintf(stderr, "Unable to determine page size\n"),
                                exit(2);
        if ((unsigned long) area_count(NULL, 0) + sizeof(unsigned long long) * 2
            > page_size)
                fprintf(stderr, "Impossible to run this test\n"), exit(2);
-       nr_pages_per_cpu = atol(argv[1]) * 1024*1024 / page_size /
+}
+
+int main(int argc, char **argv)
+{
+       if (argc < 4)
+               fprintf(stderr, "Usage: <test type> <MiB> <bounces> [hugetlbfs_file]\n"),
+                               exit(1);
+
+       set_test_type(argv[1]);
+
+       nr_cpus = sysconf(_SC_NPROCESSORS_ONLN);
+       nr_pages_per_cpu = atol(argv[2]) * 1024*1024 / page_size /
                nr_cpus;
        if (!nr_pages_per_cpu) {
                fprintf(stderr, "invalid MiB\n");
                fprintf(stderr, "Usage: <MiB> <bounces>\n"), exit(1);
        }
-       bounces = atoi(argv[2]);
+
+       bounces = atoi(argv[3]);
        if (bounces <= 0) {
                fprintf(stderr, "invalid bounces\n");
                fprintf(stderr, "Usage: <MiB> <bounces>\n"), exit(1);
        }
        nr_pages = nr_pages_per_cpu * nr_cpus;
-       huge_fd = open(argv[3], O_CREAT | O_RDWR, 0755);
-       if (huge_fd < 0) {
-               fprintf(stderr, "Open of %s failed", argv[3]);
-               perror("open");
-               exit(1);
-       }
-       if (ftruncate(huge_fd, 0)) {
-               fprintf(stderr, "ftruncate %s to size 0 failed", argv[3]);
-               perror("ftruncate");
-               exit(1);
+
+       if (test_type == TEST_HUGETLB) {
+               if (argc < 5)
+                       fprintf(stderr, "Usage: hugetlb <MiB> <bounces> <hugetlbfs_file>\n"),
+                               exit(1);
+               huge_fd = open(argv[4], O_CREAT | O_RDWR, 0755);
+               if (huge_fd < 0) {
+                       fprintf(stderr, "Open of %s failed", argv[3]);
+                       perror("open");
+                       exit(1);
+               }
+               if (ftruncate(huge_fd, 0)) {
+                       fprintf(stderr, "ftruncate %s to size 0 failed", argv[3]);
+                       perror("ftruncate");
+                       exit(1);
+               }
        }
        printf("nr_pages: %lu, nr_pages_per_cpu: %lu\n",
               nr_pages, nr_pages_per_cpu);
        return userfaultfd_stress();
 }
 
-#endif
 #else /* __NR_userfaultfd */
 
 #warning "missing __NR_userfaultfd definition"