*/
 int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-       struct mmu_notifier_mm *mmu_notifier_mm;
+       struct mmu_notifier_mm *mmu_notifier_mm = NULL;
        int ret;
 
        lockdep_assert_held_write(&mm->mmap_sem);
        BUG_ON(atomic_read(&mm->mm_users) <= 0);
 
-       mmu_notifier_mm = kmalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
-       if (unlikely(!mmu_notifier_mm))
-               return -ENOMEM;
+       if (!mm->mmu_notifier_mm) {
+               /*
+                * kmalloc cannot be called under mm_take_all_locks(), but we
+                * know that mm->mmu_notifier_mm can't change while we hold
+                * the write side of the mmap_sem.
+                */
+               mmu_notifier_mm =
+                       kmalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL);
+               if (!mmu_notifier_mm)
+                       return -ENOMEM;
+
+               INIT_HLIST_HEAD(&mmu_notifier_mm->list);
+               spin_lock_init(&mmu_notifier_mm->lock);
+       }
 
        ret = mm_take_all_locks(mm);
        if (unlikely(ret))
                goto out_clean;
 
-       if (!mm_has_notifiers(mm)) {
-               INIT_HLIST_HEAD(&mmu_notifier_mm->list);
-               spin_lock_init(&mmu_notifier_mm->lock);
-
-               mm->mmu_notifier_mm = mmu_notifier_mm;
-               mmu_notifier_mm = NULL;
-       }
+       /* Pairs with the mmdrop in mmu_notifier_unregister_* */
        mmgrab(mm);
 
        /*
         * We can't race against any other mmu notifier method either
         * thanks to mm_take_all_locks().
         */
+       if (mmu_notifier_mm)
+               mm->mmu_notifier_mm = mmu_notifier_mm;
+
        spin_lock(&mm->mmu_notifier_mm->lock);
        hlist_add_head_rcu(&mn->hlist, &mm->mmu_notifier_mm->list);
        spin_unlock(&mm->mmu_notifier_mm->lock);
 
        mm_drop_all_locks(mm);
+       BUG_ON(atomic_read(&mm->mm_users) <= 0);
+       return 0;
+
 out_clean:
        kfree(mmu_notifier_mm);
-       BUG_ON(atomic_read(&mm->mm_users) <= 0);
        return ret;
 }
 EXPORT_SYMBOL_GPL(__mmu_notifier_register);