VM_BUG_ON(waitqueue_active(&ctx->fault_wqh));
                VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock));
                VM_BUG_ON(waitqueue_active(&ctx->fd_wqh));
-               mmput(ctx->mm);
+               mmdrop(ctx->mm);
                kmem_cache_free(userfaultfd_ctx_cachep, ctx);
        }
 }
 
        ACCESS_ONCE(ctx->released) = true;
 
+       if (!mmget_not_zero(mm))
+               goto wakeup;
+
        /*
         * Flush page faults out of all CPUs. NOTE: all page faults
         * must be retried without returning VM_FAULT_SIGBUS if
                vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
        }
        up_write(&mm->mmap_sem);
-
+       mmput(mm);
+wakeup:
        /*
         * After no new page faults can wait on this fault_*wqh, flush
         * the last page faults that may have been already waiting on
        start = uffdio_register.range.start;
        end = start + uffdio_register.range.len;
 
+       ret = -ENOMEM;
+       if (!mmget_not_zero(mm))
+               goto out;
+
        down_write(&mm->mmap_sem);
        vma = find_vma_prev(mm, start, &prev);
-
-       ret = -ENOMEM;
        if (!vma)
                goto out_unlock;
 
        } while (vma && vma->vm_start < end);
 out_unlock:
        up_write(&mm->mmap_sem);
+       mmput(mm);
        if (!ret) {
                /*
                 * Now that we scanned all vmas we can already tell
        start = uffdio_unregister.start;
        end = start + uffdio_unregister.len;
 
+       ret = -ENOMEM;
+       if (!mmget_not_zero(mm))
+               goto out;
+
        down_write(&mm->mmap_sem);
        vma = find_vma_prev(mm, start, &prev);
-
-       ret = -ENOMEM;
        if (!vma)
                goto out_unlock;
 
        } while (vma && vma->vm_start < end);
 out_unlock:
        up_write(&mm->mmap_sem);
+       mmput(mm);
 out:
        return ret;
 }
                goto out;
        if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE)
                goto out;
-
-       ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
-                          uffdio_copy.len);
+       if (mmget_not_zero(ctx->mm)) {
+               ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
+                                  uffdio_copy.len);
+               mmput(ctx->mm);
+       }
        if (unlikely(put_user(ret, &user_uffdio_copy->copy)))
                return -EFAULT;
        if (ret < 0)
        if (uffdio_zeropage.mode & ~UFFDIO_ZEROPAGE_MODE_DONTWAKE)
                goto out;
 
-       ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start,
-                            uffdio_zeropage.range.len);
+       if (mmget_not_zero(ctx->mm)) {
+               ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start,
+                                    uffdio_zeropage.range.len);
+               mmput(ctx->mm);
+       }
        if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage)))
                return -EFAULT;
        if (ret < 0)
        ctx->released = false;
        ctx->mm = current->mm;
        /* prevent the mm struct to be freed */
-       atomic_inc(&ctx->mm->mm_users);
+       atomic_inc(&ctx->mm->mm_count);
 
        file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, ctx,
                                  O_RDWR | (flags & UFFD_SHARED_FCNTL_FLAGS));
        if (IS_ERR(file)) {
-               mmput(ctx->mm);
+               mmdrop(ctx->mm);
                kmem_cache_free(userfaultfd_ctx_cachep, ctx);
        }
 out:
 
 
 /* mmdrop drops the mm and the page tables */
 extern void __mmdrop(struct mm_struct *);
-static inline void mmdrop(struct mm_struct * mm)
+static inline void mmdrop(struct mm_struct *mm)
 {
        if (unlikely(atomic_dec_and_test(&mm->mm_count)))
                __mmdrop(mm);
 }
 
+static inline bool mmget_not_zero(struct mm_struct *mm)
+{
+       return atomic_inc_not_zero(&mm->mm_users);
+}
+
 /* mmput gets rid of the mappings and all user-space */
 extern void mmput(struct mm_struct *);
 /* same as above but performs the slow path from the async kontext. Can