// SPDX-License-Identifier: GPL-2.0
 
 #include <sys/mman.h>
+#include <sys/prctl.h>
+#include <sys/wait.h>
 #include <stdbool.h>
 #include <time.h>
 #include <string.h>
 #define KSM_PROT_STR_DEFAULT "rw"
 #define KSM_USE_ZERO_PAGES_DEFAULT false
 #define KSM_MERGE_ACROSS_NODES_DEFAULT true
+#define KSM_MERGE_TYPE_DEFAULT 0
 #define MB (1ul << 20)
 
 struct ksm_sysfs {
        unsigned long use_zero_pages;
 };
 
+enum ksm_merge_type {
+       KSM_MERGE_MADVISE,
+       KSM_MERGE_PRCTL,
+       KSM_MERGE_LAST = KSM_MERGE_PRCTL
+};
+
 enum ksm_test_name {
        CHECK_KSM_MERGE,
        CHECK_KSM_UNMERGE,
+       CHECK_KSM_GET_MERGE_TYPE,
        CHECK_KSM_ZERO_PAGE_MERGE,
        CHECK_KSM_NUMA_MERGE,
        KSM_MERGE_TIME,
        KSM_COW_TIME
 };
 
+int debug;
+
 static int ksm_write_sysfs(const char *file_path, unsigned long val)
 {
        FILE *f = fopen(file_path, "w");
        return 0;
 }
 
+static void ksm_print_sysfs(void)
+{
+       unsigned long max_page_sharing, pages_sharing, pages_shared;
+       unsigned long full_scans, pages_unshared, pages_volatile;
+       unsigned long stable_node_chains, stable_node_dups;
+       long general_profit;
+
+       if (ksm_read_sysfs(KSM_FP("pages_shared"), &pages_shared) ||
+           ksm_read_sysfs(KSM_FP("pages_sharing"), &pages_sharing) ||
+           ksm_read_sysfs(KSM_FP("max_page_sharing"), &max_page_sharing) ||
+           ksm_read_sysfs(KSM_FP("full_scans"), &full_scans) ||
+           ksm_read_sysfs(KSM_FP("pages_unshared"), &pages_unshared) ||
+           ksm_read_sysfs(KSM_FP("pages_volatile"), &pages_volatile) ||
+           ksm_read_sysfs(KSM_FP("stable_node_chains"), &stable_node_chains) ||
+           ksm_read_sysfs(KSM_FP("stable_node_dups"), &stable_node_dups) ||
+           ksm_read_sysfs(KSM_FP("general_profit"), (unsigned long *)&general_profit))
+               return;
+
+       printf("pages_shared      : %lu\n", pages_shared);
+       printf("pages_sharing     : %lu\n", pages_sharing);
+       printf("max_page_sharing  : %lu\n", max_page_sharing);
+       printf("full_scans        : %lu\n", full_scans);
+       printf("pages_unshared    : %lu\n", pages_unshared);
+       printf("pages_volatile    : %lu\n", pages_volatile);
+       printf("stable_node_chains: %lu\n", stable_node_chains);
+       printf("stable_node_dups  : %lu\n", stable_node_dups);
+       printf("general_profit    : %ld\n", general_profit);
+}
+
+static void ksm_print_procfs(void)
+{
+       const char *file_name = "/proc/self/ksm_stat";
+       char buffer[512];
+       FILE *f = fopen(file_name, "r");
+
+       if (!f) {
+               fprintf(stderr, "f %s\n", file_name);
+               perror("fopen");
+               return;
+       }
+
+       while (fgets(buffer, sizeof(buffer), f))
+               printf("%s", buffer);
+
+       fclose(f);
+}
+
 static int str_to_prot(char *prot_str)
 {
        int prot = 0;
               "     Default: %d\n", KSM_USE_ZERO_PAGES_DEFAULT);
        printf(" -m: change merge_across_nodes tunable\n"
               "     Default: %d\n", KSM_MERGE_ACROSS_NODES_DEFAULT);
+       printf(" -d: turn debugging output on\n");
        printf(" -s: the size of duplicated memory area (in MiB)\n");
+       printf(" -t: KSM merge type\n"
+              "     Default: 0\n"
+              "     0: madvise merging\n"
+              "     1: prctl merging\n");
 
        exit(0);
 }
        return 0;
 }
 
-static int ksm_merge_pages(void *addr, size_t size, struct timespec start_time, int timeout)
+static int ksm_merge_pages(int merge_type, void *addr, size_t size,
+                       struct timespec start_time, int timeout)
 {
-       if (madvise(addr, size, MADV_MERGEABLE)) {
-               perror("madvise");
-               return 1;
+       if (merge_type == KSM_MERGE_MADVISE) {
+               if (madvise(addr, size, MADV_MERGEABLE)) {
+                       perror("madvise");
+                       return 1;
+               }
+       } else if (merge_type == KSM_MERGE_PRCTL) {
+               if (prctl(PR_SET_MEMORY_MERGE, 1, 0, 0, 0)) {
+                       perror("prctl");
+                       return 1;
+               }
        }
+
        if (ksm_write_sysfs(KSM_FP("run"), 1))
                return 1;
 
            ksm_read_sysfs(KSM_FP("max_page_sharing"), &max_page_sharing))
                return false;
 
+       if (debug) {
+               ksm_print_sysfs();
+               ksm_print_procfs();
+       }
+
        /*
         * Since there must be at least 2 pages for merging and 1 page can be
         * shared with the limited number of pages (max_page_sharing), sometimes
        return 0;
 }
 
-static int check_ksm_merge(int mapping, int prot, long page_count, int timeout, size_t page_size)
+static int check_ksm_merge(int merge_type, int mapping, int prot,
+                       long page_count, int timeout, size_t page_size)
 {
        void *map_ptr;
        struct timespec start_time;
        if (!map_ptr)
                return KSFT_FAIL;
 
-       if (ksm_merge_pages(map_ptr, page_size * page_count, start_time, timeout))
+       if (ksm_merge_pages(merge_type, map_ptr, page_size * page_count, start_time, timeout))
                goto err_out;
 
        /* verify that the right number of pages are merged */
        if (assert_ksm_pages_count(page_count)) {
                printf("OK\n");
                munmap(map_ptr, page_size * page_count);
+               if (merge_type == KSM_MERGE_PRCTL)
+                       prctl(PR_SET_MEMORY_MERGE, 0, 0, 0, 0);
                return KSFT_PASS;
        }
 
        return KSFT_FAIL;
 }
 
-static int check_ksm_unmerge(int mapping, int prot, int timeout, size_t page_size)
+static int check_ksm_unmerge(int merge_type, int mapping, int prot, int timeout, size_t page_size)
 {
        void *map_ptr;
        struct timespec start_time;
        if (!map_ptr)
                return KSFT_FAIL;
 
-       if (ksm_merge_pages(map_ptr, page_size * page_count, start_time, timeout))
+       if (ksm_merge_pages(merge_type, map_ptr, page_size * page_count, start_time, timeout))
                goto err_out;
 
        /* change 1 byte in each of the 2 pages -- KSM must automatically unmerge them */
        return KSFT_FAIL;
 }
 
-static int check_ksm_zero_page_merge(int mapping, int prot, long page_count, int timeout,
-                                    bool use_zero_pages, size_t page_size)
+static int check_ksm_zero_page_merge(int merge_type, int mapping, int prot, long page_count,
+                               int timeout, bool use_zero_pages, size_t page_size)
 {
        void *map_ptr;
        struct timespec start_time;
        if (!map_ptr)
                return KSFT_FAIL;
 
-       if (ksm_merge_pages(map_ptr, page_size * page_count, start_time, timeout))
+       if (ksm_merge_pages(merge_type, map_ptr, page_size * page_count, start_time, timeout))
                goto err_out;
 
        /*
        return get_next_mem_node(numa_max_node());
 }
 
-static int check_ksm_numa_merge(int mapping, int prot, int timeout, bool merge_across_nodes,
-                               size_t page_size)
+static int check_ksm_numa_merge(int merge_type, int mapping, int prot, int timeout,
+                               bool merge_across_nodes, size_t page_size)
 {
        void *numa1_map_ptr, *numa2_map_ptr;
        struct timespec start_time;
        memset(numa2_map_ptr, '*', page_size);
 
        /* try to merge the pages */
-       if (ksm_merge_pages(numa1_map_ptr, page_size, start_time, timeout) ||
-           ksm_merge_pages(numa2_map_ptr, page_size, start_time, timeout))
+       if (ksm_merge_pages(merge_type, numa1_map_ptr, page_size, start_time, timeout) ||
+           ksm_merge_pages(merge_type, numa2_map_ptr, page_size, start_time, timeout))
                goto err_out;
 
        /*
        return KSFT_FAIL;
 }
 
-static int ksm_merge_hugepages_time(int mapping, int prot, int timeout, size_t map_size)
+static int ksm_merge_hugepages_time(int merge_type, int mapping, int prot,
+                               int timeout, size_t map_size)
 {
        void *map_ptr, *map_ptr_orig;
        struct timespec start_time, end_time;
                perror("clock_gettime");
                goto err_out;
        }
-       if (ksm_merge_pages(map_ptr, map_size, start_time, timeout))
+       if (ksm_merge_pages(merge_type, map_ptr, map_size, start_time, timeout))
                goto err_out;
        if (clock_gettime(CLOCK_MONOTONIC_RAW, &end_time)) {
                perror("clock_gettime");
        return KSFT_FAIL;
 }
 
-static int ksm_merge_time(int mapping, int prot, int timeout, size_t map_size)
+static int ksm_merge_time(int merge_type, int mapping, int prot, int timeout, size_t map_size)
 {
        void *map_ptr;
        struct timespec start_time, end_time;
                perror("clock_gettime");
                goto err_out;
        }
-       if (ksm_merge_pages(map_ptr, map_size, start_time, timeout))
+       if (ksm_merge_pages(merge_type, map_ptr, map_size, start_time, timeout))
                goto err_out;
        if (clock_gettime(CLOCK_MONOTONIC_RAW, &end_time)) {
                perror("clock_gettime");
        return KSFT_FAIL;
 }
 
-static int ksm_unmerge_time(int mapping, int prot, int timeout, size_t map_size)
+static int ksm_unmerge_time(int merge_type, int mapping, int prot, int timeout, size_t map_size)
 {
        void *map_ptr;
        struct timespec start_time, end_time;
                perror("clock_gettime");
                goto err_out;
        }
-       if (ksm_merge_pages(map_ptr, map_size, start_time, timeout))
+       if (ksm_merge_pages(merge_type, map_ptr, map_size, start_time, timeout))
                goto err_out;
 
        if (clock_gettime(CLOCK_MONOTONIC_RAW, &start_time)) {
        return KSFT_FAIL;
 }
 
-static int ksm_cow_time(int mapping, int prot, int timeout, size_t page_size)
+static int ksm_cow_time(int merge_type, int mapping, int prot, int timeout, size_t page_size)
 {
        void *map_ptr;
        struct timespec start_time, end_time;
                memset(map_ptr + page_size * i, '+', i / 2 + 1);
                memset(map_ptr + page_size * (i + 1), '+', i / 2 + 1);
        }
-       if (ksm_merge_pages(map_ptr, page_size * page_count, start_time, timeout))
+       if (ksm_merge_pages(merge_type, map_ptr, page_size * page_count, start_time, timeout))
                goto err_out;
 
        if (clock_gettime(CLOCK_MONOTONIC_RAW, &start_time)) {
        int ret, opt;
        int prot = 0;
        int ksm_scan_limit_sec = KSM_SCAN_LIMIT_SEC_DEFAULT;
+       int merge_type = KSM_MERGE_TYPE_DEFAULT;
        long page_count = KSM_PAGE_COUNT_DEFAULT;
        size_t page_size = sysconf(_SC_PAGESIZE);
        struct ksm_sysfs ksm_sysfs_old;
        bool merge_across_nodes = KSM_MERGE_ACROSS_NODES_DEFAULT;
        long size_MB = 0;
 
-       while ((opt = getopt(argc, argv, "ha:p:l:z:m:s:MUZNPCHD")) != -1) {
+       while ((opt = getopt(argc, argv, "dha:p:l:z:m:s:t:MUZNPCHD")) != -1) {
                switch (opt) {
                case 'a':
                        prot = str_to_prot(optarg);
                        else
                                merge_across_nodes = 1;
                        break;
+               case 'd':
+                       debug = 1;
+                       break;
                case 's':
                        size_MB = atoi(optarg);
                        if (size_MB <= 0) {
                                printf("Size must be greater than 0\n");
                                return KSFT_FAIL;
                        }
+               case 't':
+                       {
+                               int tmp = atoi(optarg);
+
+                               if (tmp < 0 || tmp > KSM_MERGE_LAST) {
+                                       printf("Invalid merge type\n");
+                                       return KSFT_FAIL;
+                               }
+                               merge_type = tmp;
+                       }
+                       break;
                case 'M':
                        break;
                case 'U':
 
        switch (test_name) {
        case CHECK_KSM_MERGE:
-               ret = check_ksm_merge(MAP_PRIVATE | MAP_ANONYMOUS, prot, page_count,
+               ret = check_ksm_merge(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot, page_count,
                                      ksm_scan_limit_sec, page_size);
                break;
        case CHECK_KSM_UNMERGE:
-               ret = check_ksm_unmerge(MAP_PRIVATE | MAP_ANONYMOUS, prot, ksm_scan_limit_sec,
-                                       page_size);
+               ret = check_ksm_unmerge(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot,
+                                       ksm_scan_limit_sec, page_size);
                break;
        case CHECK_KSM_ZERO_PAGE_MERGE:
-               ret = check_ksm_zero_page_merge(MAP_PRIVATE | MAP_ANONYMOUS, prot, page_count,
-                                               ksm_scan_limit_sec, use_zero_pages, page_size);
+               ret = check_ksm_zero_page_merge(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot,
+                                               page_count, ksm_scan_limit_sec, use_zero_pages,
+                                               page_size);
                break;
        case CHECK_KSM_NUMA_MERGE:
-               ret = check_ksm_numa_merge(MAP_PRIVATE | MAP_ANONYMOUS, prot, ksm_scan_limit_sec,
-                                          merge_across_nodes, page_size);
+               ret = check_ksm_numa_merge(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot,
+                                       ksm_scan_limit_sec, merge_across_nodes, page_size);
                break;
        case KSM_MERGE_TIME:
                if (size_MB == 0) {
                        printf("Option '-s' is required.\n");
                        return KSFT_FAIL;
                }
-               ret = ksm_merge_time(MAP_PRIVATE | MAP_ANONYMOUS, prot, ksm_scan_limit_sec,
-                                    size_MB);
+               ret = ksm_merge_time(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot,
+                               ksm_scan_limit_sec, size_MB);
                break;
        case KSM_MERGE_TIME_HUGE_PAGES:
                if (size_MB == 0) {
                        printf("Option '-s' is required.\n");
                        return KSFT_FAIL;
                }
-               ret = ksm_merge_hugepages_time(MAP_PRIVATE | MAP_ANONYMOUS, prot,
+               ret = ksm_merge_hugepages_time(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot,
                                ksm_scan_limit_sec, size_MB);
                break;
        case KSM_UNMERGE_TIME:
                        printf("Option '-s' is required.\n");
                        return KSFT_FAIL;
                }
-               ret = ksm_unmerge_time(MAP_PRIVATE | MAP_ANONYMOUS, prot,
+               ret = ksm_unmerge_time(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot,
                                       ksm_scan_limit_sec, size_MB);
                break;
        case KSM_COW_TIME:
-               ret = ksm_cow_time(MAP_PRIVATE | MAP_ANONYMOUS, prot, ksm_scan_limit_sec,
-                                  page_size);
+               ret = ksm_cow_time(merge_type, MAP_PRIVATE | MAP_ANONYMOUS, prot,
+                               ksm_scan_limit_sec, page_size);
                break;
        }