userfaultfd_event_wait_completion(ctx, &ewq);
 }
 
-void userfaultfd_remove(struct vm_area_struct *vma,
-                       struct vm_area_struct **prev,
+bool userfaultfd_remove(struct vm_area_struct *vma,
                        unsigned long start, unsigned long end)
 {
        struct mm_struct *mm = vma->vm_mm;
 
        ctx = vma->vm_userfaultfd_ctx.ctx;
        if (!ctx || !(ctx->features & UFFD_FEATURE_EVENT_REMOVE))
-               return;
+               return true;
 
        userfaultfd_ctx_get(ctx);
        up_read(&mm->mmap_sem);
 
-       *prev = NULL; /* We wait for ACK w/o the mmap semaphore */
-
        msg_init(&ewq.msg);
 
        ewq.msg.event = UFFD_EVENT_REMOVE;
 
        userfaultfd_event_wait_completion(ctx, &ewq);
 
-       down_read(&mm->mmap_sem);
+       return false;
 }
 
 static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps,
 
                                        unsigned long from, unsigned long to,
                                        unsigned long len);
 
-extern void userfaultfd_remove(struct vm_area_struct *vma,
-                              struct vm_area_struct **prev,
+extern bool userfaultfd_remove(struct vm_area_struct *vma,
                               unsigned long start,
                               unsigned long end);
 
 {
 }
 
-static inline void userfaultfd_remove(struct vm_area_struct *vma,
-                                     struct vm_area_struct **prev,
+static inline bool userfaultfd_remove(struct vm_area_struct *vma,
                                      unsigned long start,
                                      unsigned long end)
 {
+       return true;
 }
 
 static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma,
 
        if (!can_madv_dontneed_vma(vma))
                return -EINVAL;
 
-       userfaultfd_remove(vma, prev, start, end);
+       if (!userfaultfd_remove(vma, start, end)) {
+               *prev = NULL; /* mmap_sem has been dropped, prev is stale */
+
+               down_read(¤t->mm->mmap_sem);
+               vma = find_vma(current->mm, start);
+               if (!vma)
+                       return -ENOMEM;
+               if (start < vma->vm_start) {
+                       /*
+                        * This "vma" under revalidation is the one
+                        * with the lowest vma->vm_start where start
+                        * is also < vma->vm_end. If start <
+                        * vma->vm_start it means an hole materialized
+                        * in the user address space within the
+                        * virtual range passed to MADV_DONTNEED.
+                        */
+                       return -ENOMEM;
+               }
+               if (!can_madv_dontneed_vma(vma))
+                       return -EINVAL;
+               if (end > vma->vm_end) {
+                       /*
+                        * Don't fail if end > vma->vm_end. If the old
+                        * vma was splitted while the mmap_sem was
+                        * released the effect of the concurrent
+                        * operation may not cause MADV_DONTNEED to
+                        * have an undefined result. There may be an
+                        * adjacent next vma that we'll walk
+                        * next. userfaultfd_remove() will generate an
+                        * UFFD_EVENT_REMOVE repetition on the
+                        * end-vma->vm_end range, but the manager can
+                        * handle a repetition fine.
+                        */
+                       end = vma->vm_end;
+               }
+               VM_WARN_ON(start >= end);
+       }
        zap_page_range(vma, start, end - start);
        return 0;
 }
         * mmap_sem.
         */
        get_file(f);
-       userfaultfd_remove(vma, prev, start, end);
-       up_read(¤t->mm->mmap_sem);
+       if (userfaultfd_remove(vma, start, end)) {
+               /* mmap_sem was not released by userfaultfd_remove() */
+               up_read(¤t->mm->mmap_sem);
+       }
        error = vfs_fallocate(f,
                                FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE,
                                offset, end - start);