 #include <linux/hashtable.h>
 #include <linux/freezer.h>
 #include <linux/oom.h>
+#include <linux/numa.h>
 
 #include <asm/tlbflush.h>
 #include "internal.h"
        struct mm_struct *mm;
        unsigned long address;          /* + low bits used for flags below */
        unsigned int oldchecksum;       /* when unstable */
+#ifdef CONFIG_NUMA
+       unsigned int nid;       /* NUMA node id of the tree this item is in */
+#endif
        union {
                struct rb_node node;    /* when node of unstable tree */
                struct {                /* when listed from stable tree */
 #define STABLE_FLAG    0x200   /* is listed from the stable tree */
 
 /* The stable and unstable tree heads */
-static struct rb_root root_stable_tree = RB_ROOT;
-static struct rb_root root_unstable_tree = RB_ROOT;
+static struct rb_root root_unstable_tree[MAX_NUMNODES];
+static struct rb_root root_stable_tree[MAX_NUMNODES];
 
 #define MM_SLOTS_HASH_BITS 10
 static DEFINE_HASHTABLE(mm_slots_hash, MM_SLOTS_HASH_BITS);
 /* Milliseconds ksmd should sleep between batches */
 static unsigned int ksm_thread_sleep_millisecs = 20;
 
+/*
+ * Zeroed when merging across nodes is not allowed; the knob is exposed
+ * to userspace as /sys/kernel/mm/ksm/merge_across_nodes.
+ */
+static unsigned int ksm_merge_across_nodes = 1;
+
 #define KSM_RUN_STOP   0
 #define KSM_RUN_MERGE  1
 #define KSM_RUN_UNMERGE        2
        return page;
 }
 
+/*
+ * Helper to get the right index into the array of tree roots.
+ * When the merge_across_nodes knob is set to 1, there are only two
+ * rb-trees: stable and unstable pages from all nodes share the roots
+ * at index 0. Otherwise, every node has its own pair of stable and
+ * unstable trees.
+ */
+static inline int get_kpfn_nid(unsigned long kpfn)
+{
+       if (ksm_merge_across_nodes)
+               return 0;
+       else
+               return pfn_to_nid(kpfn);
+}
+
 static void remove_node_from_stable_tree(struct stable_node *stable_node)
 {
        struct rmap_item *rmap_item;
        struct hlist_node *hlist;
+       int nid;
 
        hlist_for_each_entry(rmap_item, hlist, &stable_node->hlist, hlist) {
                if (rmap_item->hlist.next)
                cond_resched();
        }
 
-       rb_erase(&stable_node->node, &root_stable_tree);
+       nid = get_kpfn_nid(stable_node->kpfn);
+       rb_erase(&stable_node->node, &root_stable_tree[nid]);
        free_stable_node(stable_node);
 }
 
                age = (unsigned char)(ksm_scan.seqnr - rmap_item->address);
                BUG_ON(age > 1);
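+               /*
+                * rmap_item->nid exists only when CONFIG_NUMA is set;
+                * without NUMA all rmap_items live in the one unstable
+                * tree at index 0.
+                */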
                if (!age)
-                       rb_erase(&rmap_item->node, &root_unstable_tree);
+#ifdef CONFIG_NUMA
+                       rb_erase(&rmap_item->node,
+                                       &root_unstable_tree[rmap_item->nid]);
+#else
+                       rb_erase(&rmap_item->node, &root_unstable_tree[0]);
+#endif
 
                ksm_pages_unshared--;
                rmap_item->address &= PAGE_MASK;
  */
 static struct page *stable_tree_search(struct page *page)
 {
-       struct rb_node *node = root_stable_tree.rb_node;
+       struct rb_node *node;
        struct stable_node *stable_node;
+       int nid;
 
        stable_node = page_stable_node(page);
        if (stable_node) {                      /* ksm page forked */
                return page;
        }
 
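+       /*
+        * With merge_across_nodes unset, a page can only have been merged
+        * with pages from its own node, so only that node's stable tree
+        * needs to be searched.
+        */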
+       nid = get_kpfn_nid(page_to_pfn(page));
+       node = root_stable_tree[nid].rb_node;
+
        while (node) {
                struct page *tree_page;
                int ret;
  */
 static struct stable_node *stable_tree_insert(struct page *kpage)
 {
-       struct rb_node **new = &root_stable_tree.rb_node;
+       int nid;
+       unsigned long kpfn;
+       struct rb_node **new;
        struct rb_node *parent = NULL;
        struct stable_node *stable_node;
 
+       kpfn = page_to_pfn(kpage);
+       nid = get_kpfn_nid(kpfn);
+       new = &root_stable_tree[nid].rb_node;
+
        while (*new) {
                struct page *tree_page;
                int ret;
                return NULL;
 
        rb_link_node(&stable_node->node, parent, new);
-       rb_insert_color(&stable_node->node, &root_stable_tree);
+       rb_insert_color(&stable_node->node, &root_stable_tree[nid]);
 
        INIT_HLIST_HEAD(&stable_node->hlist);
 
-       stable_node->kpfn = page_to_pfn(kpage);
+       stable_node->kpfn = kpfn;
        set_page_stable_node(kpage, stable_node);
 
        return stable_node;
 struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
                                              struct page *page,
                                              struct page **tree_pagep)
-
 {
-       struct rb_node **new = &root_unstable_tree.rb_node;
+       struct rb_node **new;
+       struct rb_root *root;
        struct rb_node *parent = NULL;
+       int nid;
+
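+       /*
+        * With merge_across_nodes unset, only pages from this page's own
+        * node are merge candidates, so use that node's unstable tree.
+        */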
+       nid = get_kpfn_nid(page_to_pfn(page));
+       root = &root_unstable_tree[nid];
+       new = &root->rb_node;
 
        while (*new) {
                struct rmap_item *tree_rmap_item;
                        return NULL;
                }
 
+               /*
+                * If tree_page has been migrated to another NUMA node, it
+                * will be flushed out and put into the right unstable tree
+                * next time: so merge with it only when merge_across_nodes
+                * is set. Note that we don't have a similar problem for
+                * PageKsm pages, because their migration is disabled for
+                * now (see commit 62b61f611e).
+                */
+               if (!ksm_merge_across_nodes && page_to_nid(tree_page) != nid) {
+                       put_page(tree_page);
+                       return NULL;
+               }
+
                ret = memcmp_pages(page, tree_page);
 
                parent = *new;
 
        rmap_item->address |= UNSTABLE_FLAG;
        rmap_item->address |= (ksm_scan.seqnr & SEQNR_MASK);
+#ifdef CONFIG_NUMA
+       rmap_item->nid = nid;
+#endif
        rb_link_node(&rmap_item->node, parent, new);
-       rb_insert_color(&rmap_item->node, &root_unstable_tree);
+       rb_insert_color(&rmap_item->node, root);
 
        ksm_pages_unshared++;
        return NULL;
 static void stable_tree_append(struct rmap_item *rmap_item,
                               struct stable_node *stable_node)
 {
+#ifdef CONFIG_NUMA
+       /*
+        * Usually rmap_item->nid is already set correctly,
+        * but it may be wrong after switching merge_across_nodes.
+        */
+       rmap_item->nid = get_kpfn_nid(stable_node->kpfn);
+#endif
        rmap_item->head = stable_node;
        rmap_item->address |= STABLE_FLAG;
        hlist_add_head(&rmap_item->hlist, &stable_node->hlist);
        struct mm_slot *slot;
        struct vm_area_struct *vma;
        struct rmap_item *rmap_item;
+       int nid;
 
        if (list_empty(&ksm_mm_head.mm_list))
                return NULL;
                 */
                lru_add_drain_all();
 
-               root_unstable_tree = RB_ROOT;
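+               /*
+                * The unstable trees are rebuilt from scratch on every
+                * full scan, so simply reset each per-node root here.
+                */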
+               for (nid = 0; nid < nr_node_ids; nid++)
+                       root_unstable_tree[nid] = RB_ROOT;
 
                spin_lock(&ksm_mmlist_lock);
                slot = list_entry(slot->mm_list.next, struct mm_slot, mm_list);
                                                 unsigned long end_pfn)
 {
        struct rb_node *node;
+       int nid;
 
-       for (node = rb_first(&root_stable_tree); node; node = rb_next(node)) {
-               struct stable_node *stable_node;
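+       /*
+        * Stable nodes in the hotremoved pfn range may sit in any of the
+        * per-node trees, depending on merge_across_nodes, so walk them all.
+        */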
+       for (nid = 0; nid < nr_node_ids; nid++)
+               for (node = rb_first(&root_stable_tree[nid]); node;
+                               node = rb_next(node)) {
+                       struct stable_node *stable_node;
+
+                       stable_node = rb_entry(node, struct stable_node, node);
+                       if (stable_node->kpfn >= start_pfn &&
+                           stable_node->kpfn < end_pfn)
+                               return stable_node;
+               }
 
-               stable_node = rb_entry(node, struct stable_node, node);
-               if (stable_node->kpfn >= start_pfn &&
-                   stable_node->kpfn < end_pfn)
-                       return stable_node;
-       }
        return NULL;
 }
 
 }
 KSM_ATTR(run);
 
+#ifdef CONFIG_NUMA
+static ssize_t merge_across_nodes_show(struct kobject *kobj,
+                               struct kobj_attribute *attr, char *buf)
+{
+       return sprintf(buf, "%u\n", ksm_merge_across_nodes);
+}
+
+static ssize_t merge_across_nodes_store(struct kobject *kobj,
+                                  struct kobj_attribute *attr,
+                                  const char *buf, size_t count)
+{
+       int err;
+       unsigned long knob;
+
+       err = kstrtoul(buf, 10, &knob);
+       if (err)
+               return err;
+       if (knob > 1)
+               return -EINVAL;
+
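+       /*
+        * Switching the tree layout while pages are merged would leave
+        * stable nodes in the wrong trees, so refuse with -EBUSY until
+        * the user has unmerged everything (echo 2 > /sys/kernel/mm/ksm/run).
+        */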
+       mutex_lock(&ksm_thread_mutex);
+       if (ksm_merge_across_nodes != knob) {
+               if (ksm_pages_shared)
+                       err = -EBUSY;
+               else
+                       ksm_merge_across_nodes = knob;
+       }
+       mutex_unlock(&ksm_thread_mutex);
+
+       return err ? err : count;
+}
+KSM_ATTR(merge_across_nodes);
+#endif
+
 static ssize_t pages_shared_show(struct kobject *kobj,
                                 struct kobj_attribute *attr, char *buf)
 {
        &pages_unshared_attr.attr,
        &pages_volatile_attr.attr,
        &full_scans_attr.attr,
+#ifdef CONFIG_NUMA
+       &merge_across_nodes_attr.attr,
+#endif
        NULL,
 };
 
 {
        struct task_struct *ksm_thread;
        int err;
+       int nid;
 
        err = ksm_slab_init();
        if (err)
                goto out;
 
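+       /*
+        * Unlike the unstable trees, which are reset on every scan, the
+        * stable trees persist for the lifetime of ksm: set them up once.
+        */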
+       for (nid = 0; nid < nr_node_ids; nid++)
+               root_stable_tree[nid] = RB_ROOT;
+
        ksm_thread = kthread_run(ksm_scan_thread, NULL, "ksmd");
        if (IS_ERR(ksm_thread)) {
                printk(KERN_ERR "ksm: creating kthread failed\n");