*
  * Example of remapping huge page memory in a user application using the
  * mremap system call.  Code assumes a hugetlbfs filesystem is mounted
- * at './huge'.  The code will use 10MB worth of huge pages.
+ * at './huge'.  The amount of memory used by this test is decided by a command
+ * line argument in MBs. If missing, the default amount is 10MB.
+ *
+ * To make sure the test triggers pmd sharing and goes through the 'unshare'
+ * path in the mremap code use 1GB (1024) or more.
  */
 
 #define _GNU_SOURCE
 #include <linux/userfaultfd.h>
 #include <sys/ioctl.h>
 
-#define LENGTH (1UL * 1024 * 1024 * 1024)
+#define DEFAULT_LENGTH_MB 10UL
+#define MB_TO_BYTES(x) (x * 1024 * 1024)
 
+#define FILE_NAME "huge/hugepagefile"
 #define PROTECTION (PROT_READ | PROT_WRITE | PROT_EXEC)
 #define FLAGS (MAP_SHARED | MAP_ANONYMOUS)
 
        printf("First hex is %x\n", *((unsigned int *)addr));
 }
 
-static void write_bytes(char *addr)
+static void write_bytes(char *addr, size_t len)
 {
        unsigned long i;
 
-       for (i = 0; i < LENGTH; i++)
+       for (i = 0; i < len; i++)
                *(addr + i) = (char)i;
 }
 
-static int read_bytes(char *addr)
+static int read_bytes(char *addr, size_t len)
 {
        unsigned long i;
 
        check_bytes(addr);
-       for (i = 0; i < LENGTH; i++)
+       for (i = 0; i < len; i++)
                if (*(addr + i) != (char)i) {
                        printf("Mismatch at %lu\n", i);
                        return 1;
        }
 }
 
-int main(void)
+int main(int argc, char *argv[])
 {
+       /* Read memory length as the first arg if valid, otherwise fallback to
+        * the default length. Any additional args are ignored.
+        */
+       size_t length = argc > 1 ? (size_t)atoi(argv[1]) : 0UL;
+
+       length = length > 0 ? length : DEFAULT_LENGTH_MB;
+       length = MB_TO_BYTES(length);
+
        int ret = 0;
 
-       int fd = open("/huge/test", O_CREAT | O_RDWR, 0755);
+       int fd = open(FILE_NAME, O_CREAT | O_RDWR, 0755);
 
        if (fd < 0) {
                perror("Open failed");
 
        /* mmap to a PUD aligned address to hopefully trigger pmd sharing. */
        unsigned long suggested_addr = 0x7eaa40000000;
-       void *haddr = mmap((void *)suggested_addr, LENGTH, PROTECTION,
+       void *haddr = mmap((void *)suggested_addr, length, PROTECTION,
                           MAP_HUGETLB | MAP_SHARED | MAP_POPULATE, fd, 0);
        printf("Map haddr: Returned address is %p\n", haddr);
        if (haddr == MAP_FAILED) {
 
        /* mmap again to a dummy address to hopefully trigger pmd sharing. */
        suggested_addr = 0x7daa40000000;
-       void *daddr = mmap((void *)suggested_addr, LENGTH, PROTECTION,
+       void *daddr = mmap((void *)suggested_addr, length, PROTECTION,
                           MAP_HUGETLB | MAP_SHARED | MAP_POPULATE, fd, 0);
        printf("Map daddr: Returned address is %p\n", daddr);
        if (daddr == MAP_FAILED) {
 
        suggested_addr = 0x7faa40000000;
        void *vaddr =
-               mmap((void *)suggested_addr, LENGTH, PROTECTION, FLAGS, -1, 0);
+               mmap((void *)suggested_addr, length, PROTECTION, FLAGS, -1, 0);
        printf("Map vaddr: Returned address is %p\n", vaddr);
        if (vaddr == MAP_FAILED) {
                perror("mmap2");
                exit(1);
        }
 
-       register_region_with_uffd(haddr, LENGTH);
+       register_region_with_uffd(haddr, length);
 
-       void *addr = mremap(haddr, LENGTH, LENGTH,
+       void *addr = mremap(haddr, length, length,
                            MREMAP_MAYMOVE | MREMAP_FIXED, vaddr);
        if (addr == MAP_FAILED) {
                perror("mremap");
 
        printf("Mremap: Returned address is %p\n", addr);
        check_bytes(addr);
-       write_bytes(addr);
-       ret = read_bytes(addr);
+       write_bytes(addr, length);
+       ret = read_bytes(addr, length);
 
-       munmap(addr, LENGTH);
+       munmap(addr, length);
 
        return ret;
 }