 #include <linux/sched.h>
 #include <linux/vmstat.h>
 
+struct mmu_gather;
+
 #ifdef CONFIG_KSM
 int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
                unsigned long end, int advice, unsigned long *vm_flags);
 int __ksm_enter(struct mm_struct *mm);
-void __ksm_exit(struct mm_struct *mm);
+void __ksm_exit(struct mm_struct *mm,
+               struct mmu_gather **tlbp, unsigned long end);
 
 static inline int ksm_fork(struct mm_struct *mm, struct mm_struct *oldmm)
 {
        return 0;
 }
 
-static inline void ksm_exit(struct mm_struct *mm)
+/*
+ * For KSM to handle OOM without deadlock when it's breaking COW in a
+ * likely victim of the OOM killer, exit_mmap() has to serialize with
+ * ksm_exit() after freeing mm's pages but before freeing its page tables.
+ * That leaves a window in which KSM might refault pages which have just
+ * been finally unmapped: guard against that with ksm_test_exit(), and
+ * use it after getting mmap_sem in ksm.c, to check if mm is exiting.
+ */
+static inline bool ksm_test_exit(struct mm_struct *mm)
+{
+       return atomic_read(&mm->mm_users) == 0;
+}
+
+static inline void ksm_exit(struct mm_struct *mm,
+                           struct mmu_gather **tlbp, unsigned long end)
 {
        if (test_bit(MMF_VM_MERGEABLE, &mm->flags))
-               __ksm_exit(mm);
+               __ksm_exit(mm, tlbp, end);
 }
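
The comment above pins down an ordering in exit_mmap() that this excerpt does not show: unmap the mm's pages, then call ksm_exit() (so any break_cow still running under mmap_sem can finish against live page tables), and only then free the page tables. A rough sketch of what such a call site could look like, with an illustrative name (exit_mmap_sketch) and the unmap/free steps only indicated in comments, assuming the tlb_gather_mmu()/tlb_finish_mmu() interface used later in this patch:

	/* Illustrative sketch only: the real exit_mmap() hunk is not part
	 * of this excerpt; unmap/free steps are merely indicated here.
	 */
	static void exit_mmap_sketch(struct mm_struct *mm, unsigned long end)
	{
		struct mmu_gather *tlb;

		tlb = tlb_gather_mmu(mm, 1);	/* full-mm flush */
		/* ... unmap_vmas(): mm's pages are freed into tlb ... */
		ksm_exit(mm, &tlb, end);	/* may finish and restart tlb */
		/* ... free_pgtables(), then tlb_finish_mmu(tlb, 0, end) ... */
	}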
 
 /*
        return 0;
 }
 
-static inline void ksm_exit(struct mm_struct *mm)
+static inline bool ksm_test_exit(struct mm_struct *mm)
+{
+       return 0;
+}
+
+static inline void ksm_exit(struct mm_struct *mm,
+                           struct mmu_gather **tlbp, unsigned long end)
 {
 }
 
 
 #include <linux/mmu_notifier.h>
 #include <linux/ksm.h>
 
+#include <asm/tlb.h>
 #include <asm/tlbflush.h>
 
 /*
        struct vm_area_struct *vma;
 
        down_read(&mm->mmap_sem);
+       if (ksm_test_exit(mm))
+               goto out;
        vma = find_vma(mm, addr);
        if (!vma || vma->vm_start > addr)
                goto out;
        struct page *page;
 
        down_read(&mm->mmap_sem);
+       if (ksm_test_exit(mm))
+               goto out;
        vma = find_vma(mm, addr);
        if (!vma || vma->vm_start > addr)
                goto out;
        } else if (rmap_item->address & NODE_FLAG) {
                unsigned char age;
                /*
-                * ksm_thread can and must skip the rb_erase, because
+                * Usually ksmd can and must skip the rb_erase, because
                 * root_unstable_tree was already reset to RB_ROOT.
-                * But __ksm_exit has to be careful: do the rb_erase
-                * if it's interrupting a scan, and this rmap_item was
-                * inserted by this scan rather than left from before.
+                * But be careful when an mm is exiting: do the rb_erase
+                * if this rmap_item was inserted by this scan, rather
+                * than left over from before.
                 */
                age = (unsigned char)(ksm_scan.seqnr - rmap_item->address);
                BUG_ON(age > 1);
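
The unsigned char arithmetic above works because, when an rmap_item is linked into the unstable tree, the low bits of rmap_item->address carry ksm_scan.seqnr from insertion time (page-aligned addresses leave those bits free). A standalone toy illustration, not ksm.c code, of why the result is 0 for items inserted by the current scan and 1 for leftovers from the previous scan, even across counter wraparound:

	/* Toy illustration, not kernel code: the subtraction is done in
	 * unsigned char, so only the low 8 bits matter and wraparound of
	 * the sequence counter is harmless.
	 */
	static unsigned char rmap_age(unsigned char cur_seqnr,
				      unsigned char recorded_seqnr)
	{
		return (unsigned char)(cur_seqnr - recorded_seqnr);
	}
	/* rmap_age(7, 7)   == 0: inserted by this scan, safe to rb_erase  */
	/* rmap_age(0, 255) == 1: left over from the scan before the wrap  */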
        int err = 0;
 
        for (addr = start; addr < end && !err; addr += PAGE_SIZE) {
+               if (ksm_test_exit(vma->vm_mm))
+                       break;
                if (signal_pending(current))
                        err = -ERESTARTSYS;
                else
        int err = 0;
 
        spin_lock(&ksm_mmlist_lock);
-       mm_slot = list_entry(ksm_mm_head.mm_list.next,
+       ksm_scan.mm_slot = list_entry(ksm_mm_head.mm_list.next,
                                                struct mm_slot, mm_list);
        spin_unlock(&ksm_mmlist_lock);
 
-       while (mm_slot != &ksm_mm_head) {
+       for (mm_slot = ksm_scan.mm_slot;
+                       mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) {
                mm = mm_slot->mm;
                down_read(&mm->mmap_sem);
                for (vma = mm->mmap; vma; vma = vma->vm_next) {
+                       if (ksm_test_exit(mm))
+                               break;
                        if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
                                continue;
                        err = unmerge_ksm_pages(vma,
                                                vma->vm_start, vma->vm_end);
-                       if (err) {
-                               up_read(&mm->mmap_sem);
-                               goto out;
-                       }
+                       if (err)
+                               goto error;
                }
+
                remove_trailing_rmap_items(mm_slot, mm_slot->rmap_list.next);
-               up_read(&mm->mmap_sem);
 
                spin_lock(&ksm_mmlist_lock);
-               mm_slot = list_entry(mm_slot->mm_list.next,
+               ksm_scan.mm_slot = list_entry(mm_slot->mm_list.next,
                                                struct mm_slot, mm_list);
-               spin_unlock(&ksm_mmlist_lock);
+               if (ksm_test_exit(mm)) {
+                       hlist_del(&mm_slot->link);
+                       list_del(&mm_slot->mm_list);
+                       spin_unlock(&ksm_mmlist_lock);
+
+                       free_mm_slot(mm_slot);
+                       clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+                       up_read(&mm->mmap_sem);
+                       mmdrop(mm);
+               } else {
+                       spin_unlock(&ksm_mmlist_lock);
+                       up_read(&mm->mmap_sem);
+               }
        }
 
        ksm_scan.seqnr = 0;
-out:
+       return 0;
+
+error:
+       up_read(&mm->mmap_sem);
        spin_lock(&ksm_mmlist_lock);
        ksm_scan.mm_slot = &ksm_mm_head;
        spin_unlock(&ksm_mmlist_lock);
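
The reap sequence introduced here (hlist_del, list_del, free_mm_slot, clear MMF_VM_MERGEABLE, mmdrop) recurs in scan_get_next_rmap_item() and in __ksm_exit()'s easy_to_free path below. A hypothetical helper, not part of this patch, just to make the shared steps explicit; it assumes ksm_mmlist_lock is held on entry and released inside, as in the open-coded sequences (callers that hold mmap_sem still drop it themselves):

	/* Hypothetical consolidation, not in this patch: reap the mm_slot
	 * of an mm that has been seen exiting.  Caller holds
	 * ksm_mmlist_lock; it is dropped here, matching the patch.
	 */
	static void reap_exited_mm_slot(struct mm_slot *mm_slot,
					struct mm_struct *mm)
	{
		hlist_del(&mm_slot->link);		/* out of the hash */
		list_del(&mm_slot->mm_list);		/* off the mm list */
		spin_unlock(&ksm_mmlist_lock);

		free_mm_slot(mm_slot);
		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
		mmdrop(mm);		/* drop __ksm_enter's mm_count pin */
	}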
        int err = -EFAULT;
 
        down_read(&mm1->mmap_sem);
+       if (ksm_test_exit(mm1))
+               goto out;
+
        vma = find_vma(mm1, addr1);
        if (!vma || vma->vm_start > addr1)
                goto out;
                return err;
 
        down_read(&mm1->mmap_sem);
+       if (ksm_test_exit(mm1)) {
+               up_read(&mm1->mmap_sem);
+               goto out;
+       }
        vma = find_vma(mm1, addr1);
        if (!vma || vma->vm_start > addr1) {
                up_read(&mm1->mmap_sem);
 
        mm = slot->mm;
        down_read(&mm->mmap_sem);
-       for (vma = find_vma(mm, ksm_scan.address); vma; vma = vma->vm_next) {
+       if (ksm_test_exit(mm))
+               vma = NULL;
+       else
+               vma = find_vma(mm, ksm_scan.address);
+
+       for (; vma; vma = vma->vm_next) {
                if (!(vma->vm_flags & VM_MERGEABLE))
                        continue;
                if (ksm_scan.address < vma->vm_start)
                        ksm_scan.address = vma->vm_end;
 
                while (ksm_scan.address < vma->vm_end) {
+                       if (ksm_test_exit(mm))
+                               break;
                        *page = follow_page(vma, ksm_scan.address, FOLL_GET);
                        if (*page && PageAnon(*page)) {
                                flush_anon_page(vma, *page, ksm_scan.address);
                }
        }
 
+       if (ksm_test_exit(mm)) {
+               ksm_scan.address = 0;
+               ksm_scan.rmap_item = list_entry(&slot->rmap_list,
+                                               struct rmap_item, link);
+       }
        /*
         * Nuke all the rmap_items that are above this current rmap:
         * because there were no VM_MERGEABLE vmas with such addresses.
                 * We've completed a full scan of all vmas, holding mmap_sem
                 * throughout, and found no VM_MERGEABLE: so do the same as
                 * __ksm_exit does to remove this mm from all our lists now.
+                * This applies either when cleaning up after __ksm_exit
+                * (but beware: we can reach here even before __ksm_exit),
+                * or when all VM_MERGEABLE areas have been unmapped (and
+                * mmap_sem then protects against race with MADV_MERGEABLE).
                 */
                hlist_del(&slot->link);
                list_del(&slot->mm_list);
+               spin_unlock(&ksm_mmlist_lock);
+
                free_mm_slot(slot);
                clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+               up_read(&mm->mmap_sem);
+               mmdrop(mm);
+       } else {
+               spin_unlock(&ksm_mmlist_lock);
+               up_read(&mm->mmap_sem);
        }
-       spin_unlock(&ksm_mmlist_lock);
-       up_read(&mm->mmap_sem);
 
        /* Repeat until we've completed scanning the whole list */
        slot = ksm_scan.mm_slot;
        if (slot != &ksm_mm_head)
                goto next_mm;
 
-       /*
-        * Bump seqnr here rather than at top, so that __ksm_exit
-        * can skip rb_erase on unstable tree until we run again.
-        */
        ksm_scan.seqnr++;
        return NULL;
 }
        spin_unlock(&ksm_mmlist_lock);
 
        set_bit(MMF_VM_MERGEABLE, &mm->flags);
+       atomic_inc(&mm->mm_count);
 
        if (needs_wakeup)
                wake_up_interruptible(&ksm_thread_wait);
        return 0;
 }
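
__ksm_enter() now takes a reference on mm->mm_count, so the mm_struct itself (though not its pages or page tables, which mm_users guards) stays allocated while the mm sits on KSM's lists; every path that unhooks the mm_slot pairs this with mmdrop(). A minimal sketch of the intended pairing, with illustrative helper names:

	/* Illustrative pairing only: mm_count keeps the mm_struct
	 * allocated, while mm_users (tested by ksm_test_exit) says
	 * whether the address space is still live.  KSM needs only the
	 * former once the mm is on its lists.
	 */
	static void ksm_pin_mm(struct mm_struct *mm)
	{
		atomic_inc(&mm->mm_count);	/* as in __ksm_enter() */
	}

	static void ksm_unpin_mm(struct mm_struct *mm)
	{
		mmdrop(mm);		/* as in each mm_slot reaping path */
	}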
 
-void __ksm_exit(struct mm_struct *mm)
+void __ksm_exit(struct mm_struct *mm,
+               struct mmu_gather **tlbp, unsigned long end)
 {
        struct mm_slot *mm_slot;
+       int easy_to_free = 0;
 
        /*
-        * This process is exiting: doesn't hold and doesn't need mmap_sem;
-        * but we do need to exclude ksmd and other exiters while we modify
-        * the various lists and trees.
+        * This process is exiting: if it's straightforward (as is the
+        * case when ksmd was never running), free mm_slot immediately.
+        * But if it's at the cursor or has rmap_items linked to it, use
+        * mmap_sem to synchronize with any break_cows before pagetables
+        * are freed, and leave the mm_slot on the list for ksmd to free.
+        * Beware: ksm may already have noticed it exiting and freed the slot.
         */
-       mutex_lock(&ksm_thread_mutex);
+
        spin_lock(&ksm_mmlist_lock);
        mm_slot = get_mm_slot(mm);
-       if (!list_empty(&mm_slot->rmap_list)) {
-               spin_unlock(&ksm_mmlist_lock);
-               remove_trailing_rmap_items(mm_slot, mm_slot->rmap_list.next);
-               spin_lock(&ksm_mmlist_lock);
-       }
-
-       if (ksm_scan.mm_slot == mm_slot) {
-               ksm_scan.mm_slot = list_entry(
-                       mm_slot->mm_list.next, struct mm_slot, mm_list);
-               ksm_scan.address = 0;
-               ksm_scan.rmap_item = list_entry(
-                       &ksm_scan.mm_slot->rmap_list, struct rmap_item, link);
-               if (ksm_scan.mm_slot == &ksm_mm_head)
-                       ksm_scan.seqnr++;
+       if (mm_slot && ksm_scan.mm_slot != mm_slot) {
+               if (list_empty(&mm_slot->rmap_list)) {
+                       hlist_del(&mm_slot->link);
+                       list_del(&mm_slot->mm_list);
+                       easy_to_free = 1;
+               } else {
+                       list_move(&mm_slot->mm_list,
+                                 &ksm_scan.mm_slot->mm_list);
+               }
        }
-
-       hlist_del(&mm_slot->link);
-       list_del(&mm_slot->mm_list);
        spin_unlock(&ksm_mmlist_lock);
 
-       free_mm_slot(mm_slot);
-       clear_bit(MMF_VM_MERGEABLE, &mm->flags);
-       mutex_unlock(&ksm_thread_mutex);
+       if (easy_to_free) {
+               free_mm_slot(mm_slot);
+               clear_bit(MMF_VM_MERGEABLE, &mm->flags);
+               mmdrop(mm);
+       } else if (mm_slot) {
+               tlb_finish_mmu(*tlbp, 0, end);
+               down_write(&mm->mmap_sem);
+               up_write(&mm->mmap_sem);
+               *tlbp = tlb_gather_mmu(mm, 1);
+       }
 }
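
The else-if branch above uses a classic rwsem barrier: taking mmap_sem for write and dropping it immediately cannot return until every current reader (ksmd, or a break_cow in flight) has let go, after which the caller can free the page tables safely. The mmu_gather is finished beforehand and restarted afterwards because it must not be held across the sleep in down_write(). A small generic sketch of the barrier idiom:

	/* Generic sketch of the rwsem "barrier" idiom used above: we want
	 * no exclusion of our own, only a guarantee that every reader who
	 * could have raced with the exit has finished.
	 */
	static void wait_for_current_readers(struct rw_semaphore *sem)
	{
		down_write(sem);	/* blocks until all readers are gone */
		up_write(sem);		/* nothing to do while holding it   */
	}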
 
 #define KSM_ATTR_RO(_name) \