cond_resched();
                stable_node = rb_entry(*new, struct stable_node, node);
                tree_page = get_ksm_page(stable_node, false);
-               if (!tree_page)
-                       return NULL;
+               if (!tree_page) {
+                       /*
+                        * If we walked over a stale stable_node,
+                        * get_ksm_page() will call rb_erase() and it
+                        * may rebalance the tree from under us. So
+                        * restart the search from scratch. Returning
+                        * NULL would be safe too, but we'd generate
+                        * false negative insertions just because some
+                        * stable_node was stale.
+                        */
+                       goto again;
+               }
 
                ret = memcmp_pages(page, tree_page);
                put_page(tree_page);
        unsigned long kpfn;
        struct rb_root *root;
        struct rb_node **new;
-       struct rb_node *parent = NULL;
+       struct rb_node *parent;
        struct stable_node *stable_node;
 
        kpfn = page_to_pfn(kpage);
        nid = get_kpfn_nid(kpfn);
        root = root_stable_tree + nid;
+again:
+       parent = NULL;
        new = &root->rb_node;
 
        while (*new) {
                cond_resched();
                stable_node = rb_entry(*new, struct stable_node, node);
                tree_page = get_ksm_page(stable_node, false);
-               if (!tree_page)
-                       return NULL;
+               if (!tree_page) {
+                       /*
+                        * If we walked over a stale stable_node,
+                        * get_ksm_page() will call rb_erase() and it
+                        * may rebalance the tree from under us. So
+                        * restart the search from scratch. Returning
+                        * NULL would be safe too, but we'd generate
+                        * false negative insertions just because some
+                        * stable_node was stale.
+                        */
+                       goto again;
+               }
 
                ret = memcmp_pages(kpage, tree_page);
                put_page(tree_page);