return ctx->features & UFFD_FEATURE_WP_UNPOPULATED;
 }
 
-static void userfaultfd_set_vm_flags(struct vm_area_struct *vma,
-                                    vm_flags_t flags)
-{
-       const bool uffd_wp_changed = (vma->vm_flags ^ flags) & VM_UFFD_WP;
-
-       vm_flags_reset(vma, flags);
-       /*
-        * For shared mappings, we want to enable writenotify while
-        * userfaultfd-wp is enabled (see vma_wants_writenotify()). We'll simply
-        * recalculate vma->vm_page_prot whenever userfaultfd-wp changes.
-        */
-       if ((vma->vm_flags & VM_SHARED) && uffd_wp_changed)
-               vma_set_page_prot(vma);
-}
-
 static int userfaultfd_wake_function(wait_queue_entry_t *wq, unsigned mode,
                                     int wake_flags, void *key)
 {
        spin_unlock_irq(&ctx->event_wqh.lock);
 
        if (release_new_ctx) {
-               struct vm_area_struct *vma;
-               struct mm_struct *mm = release_new_ctx->mm;
-               VMA_ITERATOR(vmi, mm, 0);
-
-               /* the various vma->vm_userfaultfd_ctx still points to it */
-               mmap_write_lock(mm);
-               for_each_vma(vmi, vma) {
-                       if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) {
-                               vma_start_write(vma);
-                               vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
-                               userfaultfd_set_vm_flags(vma,
-                                                        vma->vm_flags & ~__VM_UFFD_FLAGS);
-                       }
-               }
-               mmap_write_unlock(mm);
-
+               userfaultfd_release_new(release_new_ctx);
                userfaultfd_ctx_put(release_new_ctx);
        }
 
                return 0;
 
        if (!(octx->features & UFFD_FEATURE_EVENT_FORK)) {
-               vma_start_write(vma);
-               vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
-               userfaultfd_set_vm_flags(vma, vma->vm_flags & ~__VM_UFFD_FLAGS);
+               userfaultfd_reset_ctx(vma);
                return 0;
        }
 
                up_write(&ctx->map_changing_lock);
        } else {
                /* Drop uffd context if remap feature not enabled */
-               vma_start_write(vma);
-               vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
-               userfaultfd_set_vm_flags(vma, vma->vm_flags & ~__VM_UFFD_FLAGS);
+               userfaultfd_reset_ctx(vma);
        }
 }
 
 {
        struct userfaultfd_ctx *ctx = file->private_data;
        struct mm_struct *mm = ctx->mm;
-       struct vm_area_struct *vma, *prev;
        /* len == 0 means wake all */
        struct userfaultfd_wake_range range = { .len = 0, };
-       unsigned long new_flags;
-       VMA_ITERATOR(vmi, mm, 0);
 
        WRITE_ONCE(ctx->released, true);
 
-       if (!mmget_not_zero(mm))
-               goto wakeup;
-
-       /*
-        * Flush page faults out of all CPUs. NOTE: all page faults
-        * must be retried without returning VM_FAULT_SIGBUS if
-        * userfaultfd_ctx_get() succeeds but vma->vm_userfaultfd_ctx
-        * changes while handle_userfault released the mmap_lock. So
-        * it's critical that released is set to true (above), before
-        * taking the mmap_lock for writing.
-        */
-       mmap_write_lock(mm);
-       prev = NULL;
-       for_each_vma(vmi, vma) {
-               cond_resched();
-               BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
-                      !!(vma->vm_flags & __VM_UFFD_FLAGS));
-               if (vma->vm_userfaultfd_ctx.ctx != ctx) {
-                       prev = vma;
-                       continue;
-               }
-               /* Reset ptes for the whole vma range if wr-protected */
-               if (userfaultfd_wp(vma))
-                       uffd_wp_range(vma, vma->vm_start,
-                                     vma->vm_end - vma->vm_start, false);
-               new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
-               vma = vma_modify_flags_uffd(&vmi, prev, vma, vma->vm_start,
-                                           vma->vm_end, new_flags,
-                                           NULL_VM_UFFD_CTX);
-
-               vma_start_write(vma);
-               userfaultfd_set_vm_flags(vma, new_flags);
-               vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
+       userfaultfd_release_all(mm, ctx);
 
-               prev = vma;
-       }
-       mmap_write_unlock(mm);
-       mmput(mm);
-wakeup:
        /*
         * After no new page faults can wait on this fault_*wqh, flush
         * the last page faults that may have been already waiting on
                                unsigned long arg)
 {
        struct mm_struct *mm = ctx->mm;
-       struct vm_area_struct *vma, *prev, *cur;
+       struct vm_area_struct *vma, *cur;
        int ret;
        struct uffdio_register uffdio_register;
        struct uffdio_register __user *user_uffdio_register;
-       unsigned long vm_flags, new_flags;
+       unsigned long vm_flags;
        bool found;
        bool basic_ioctls;
-       unsigned long start, end, vma_end;
+       unsigned long start, end;
        struct vma_iterator vmi;
        bool wp_async = userfaultfd_wp_async_ctx(ctx);
 
        } for_each_vma_range(vmi, cur, end);
        BUG_ON(!found);
 
-       vma_iter_set(&vmi, start);
-       prev = vma_prev(&vmi);
-       if (vma->vm_start < start)
-               prev = vma;
-
-       ret = 0;
-       for_each_vma_range(vmi, vma, end) {
-               cond_resched();
-
-               BUG_ON(!vma_can_userfault(vma, vm_flags, wp_async));
-               BUG_ON(vma->vm_userfaultfd_ctx.ctx &&
-                      vma->vm_userfaultfd_ctx.ctx != ctx);
-               WARN_ON(!(vma->vm_flags & VM_MAYWRITE));
-
-               /*
-                * Nothing to do: this vma is already registered into this
-                * userfaultfd and with the right tracking mode too.
-                */
-               if (vma->vm_userfaultfd_ctx.ctx == ctx &&
-                   (vma->vm_flags & vm_flags) == vm_flags)
-                       goto skip;
-
-               if (vma->vm_start > start)
-                       start = vma->vm_start;
-               vma_end = min(end, vma->vm_end);
-
-               new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
-               vma = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
-                                           new_flags,
-                                           (struct vm_userfaultfd_ctx){ctx});
-               if (IS_ERR(vma)) {
-                       ret = PTR_ERR(vma);
-                       break;
-               }
-
-               /*
-                * In the vma_merge() successful mprotect-like case 8:
-                * the next vma was merged into the current one and
-                * the current one has not been updated yet.
-                */
-               vma_start_write(vma);
-               userfaultfd_set_vm_flags(vma, new_flags);
-               vma->vm_userfaultfd_ctx.ctx = ctx;
-
-               if (is_vm_hugetlb_page(vma) && uffd_disable_huge_pmd_share(vma))
-                       hugetlb_unshare_all_pmds(vma);
-
-       skip:
-               prev = vma;
-               start = vma->vm_end;
-       }
+       ret = userfaultfd_register_range(ctx, vma, vm_flags, start, end,
+                                        wp_async);
 
 out_unlock:
        mmap_write_unlock(mm);
        struct vm_area_struct *vma, *prev, *cur;
        int ret;
        struct uffdio_range uffdio_unregister;
-       unsigned long new_flags;
        bool found;
        unsigned long start, end, vma_end;
        const void __user *buf = (void __user *)arg;
                        wake_userfault(vma->vm_userfaultfd_ctx.ctx, &range);
                }
 
-               /* Reset ptes for the whole vma range if wr-protected */
-               if (userfaultfd_wp(vma))
-                       uffd_wp_range(vma, start, vma_end - start, false);
-
-               new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
-               vma = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
-                                           new_flags, NULL_VM_UFFD_CTX);
+               vma = userfaultfd_clear_vma(&vmi, prev, vma,
+                                           start, vma_end);
                if (IS_ERR(vma)) {
                        ret = PTR_ERR(vma);
                        break;
                }
 
-               /*
-                * In the vma_merge() successful mprotect-like case 8:
-                * the next vma was merged into the current one and
-                * the current one has not been updated yet.
-                */
-               vma_start_write(vma);
-               userfaultfd_set_vm_flags(vma, new_flags);
-               vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
-
        skip:
                prev = vma;
                start = vma->vm_end;
 
        VM_WARN_ON(!moved && !err);
        return moved ? moved : err;
 }
+
+static void userfaultfd_set_vm_flags(struct vm_area_struct *vma,
+                                    vm_flags_t flags)
+{
+       const bool uffd_wp_changed = (vma->vm_flags ^ flags) & VM_UFFD_WP;
+
+       vm_flags_reset(vma, flags);
+       /*
+        * For shared mappings, we want to enable writenotify while
+        * userfaultfd-wp is enabled (see vma_wants_writenotify()). We'll simply
+        * recalculate vma->vm_page_prot whenever userfaultfd-wp changes.
+        */
+       if ((vma->vm_flags & VM_SHARED) && uffd_wp_changed)
+               vma_set_page_prot(vma);
+}
+
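+/*
+ * Point @vma at @ctx and set the requested uffd bits from @flags,
+ * refreshing vm_flags (and, for shared mappings whose uffd-wp state
+ * changes, vm_page_prot) under the VMA write lock.
+ */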
+static void userfaultfd_set_ctx(struct vm_area_struct *vma,
+                               struct userfaultfd_ctx *ctx,
+                               unsigned long flags)
+{
+       vma_start_write(vma);
+       vma->vm_userfaultfd_ctx = (struct vm_userfaultfd_ctx){ctx};
+       userfaultfd_set_vm_flags(vma,
+                                (vma->vm_flags & ~__VM_UFFD_FLAGS) | flags);
+}
+
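+/* Detach @vma from its userfaultfd context and clear all uffd flag bits. */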
+void userfaultfd_reset_ctx(struct vm_area_struct *vma)
+{
+       userfaultfd_set_ctx(vma, NULL, 0);
+}
+
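+/*
+ * Unregister [start, end) of @vma from userfaultfd: undo any pte
+ * write-protection, split/merge the VMA as required and reset the uffd
+ * state of the resulting VMA. The caller must hold the mmap write lock.
+ */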
+struct vm_area_struct *userfaultfd_clear_vma(struct vma_iterator *vmi,
+                                            struct vm_area_struct *prev,
+                                            struct vm_area_struct *vma,
+                                            unsigned long start,
+                                            unsigned long end)
+{
+       struct vm_area_struct *ret;
+
+       /* Reset ptes for the whole vma range if wr-protected */
+       if (userfaultfd_wp(vma))
+               uffd_wp_range(vma, start, end - start, false);
+
+       ret = vma_modify_flags_uffd(vmi, prev, vma, start, end,
+                                   vma->vm_flags & ~__VM_UFFD_FLAGS,
+                                   NULL_VM_UFFD_CTX);
+
+       /*
+        * In the vma_merge() successful mprotect-like case 8:
+        * the next vma was merged into the current one and
+        * the current one has not been updated yet.
+        */
+       if (!IS_ERR(ret))
+               userfaultfd_reset_ctx(ret);
+
+       return ret;
+}
+
+/* Assumes mmap write lock taken, and mm_struct pinned. */
+int userfaultfd_register_range(struct userfaultfd_ctx *ctx,
+                              struct vm_area_struct *vma,
+                              unsigned long vm_flags,
+                              unsigned long start, unsigned long end,
+                              bool wp_async)
+{
+       VMA_ITERATOR(vmi, ctx->mm, start);
+       struct vm_area_struct *prev = vma_prev(&vmi);
+       unsigned long vma_end;
+       unsigned long new_flags;
+
+       if (vma->vm_start < start)
+               prev = vma;
+
+       for_each_vma_range(vmi, vma, end) {
+               cond_resched();
+
+               BUG_ON(!vma_can_userfault(vma, vm_flags, wp_async));
+               BUG_ON(vma->vm_userfaultfd_ctx.ctx &&
+                      vma->vm_userfaultfd_ctx.ctx != ctx);
+               WARN_ON(!(vma->vm_flags & VM_MAYWRITE));
+
+               /*
+                * Nothing to do: this vma is already registered into this
+                * userfaultfd and with the right tracking mode too.
+                */
+               if (vma->vm_userfaultfd_ctx.ctx == ctx &&
+                   (vma->vm_flags & vm_flags) == vm_flags)
+                       goto skip;
+
+               if (vma->vm_start > start)
+                       start = vma->vm_start;
+               vma_end = min(end, vma->vm_end);
+
+               new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
+               vma = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
+                                           new_flags,
+                                           (struct vm_userfaultfd_ctx){ctx});
+               if (IS_ERR(vma))
+                       return PTR_ERR(vma);
+
+               /*
+                * In the vma_merge() successful mprotect-like case 8:
+                * the next vma was merged into the current one and
+                * the current one has not been updated yet.
+                */
+               userfaultfd_set_ctx(vma, ctx, vm_flags);
+
+               if (is_vm_hugetlb_page(vma) && uffd_disable_huge_pmd_share(vma))
+                       hugetlb_unshare_all_pmds(vma);
+
+skip:
+               prev = vma;
+               start = vma->vm_end;
+       }
+
+       return 0;
+}
+
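+/*
+ * Strip @ctx from every VMA of its mm. Used from the fork event path to
+ * dismantle a freshly created child context (release_new_ctx) when the
+ * userfaultfd is released before the event is fully handled.
+ */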
+void userfaultfd_release_new(struct userfaultfd_ctx *ctx)
+{
+       struct mm_struct *mm = ctx->mm;
+       struct vm_area_struct *vma;
+       VMA_ITERATOR(vmi, mm, 0);
+
+       /* the various vma->vm_userfaultfd_ctx fields still point to it */
+       mmap_write_lock(mm);
+       for_each_vma(vmi, vma) {
+               if (vma->vm_userfaultfd_ctx.ctx == ctx)
+                       userfaultfd_reset_ctx(vma);
+       }
+       mmap_write_unlock(mm);
+}
+
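+/*
+ * Final-release teardown: walk every VMA in @mm and unregister those still
+ * attached to @ctx. Takes the mmap write lock itself and pins @mm for the
+ * duration of the walk.
+ */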
+void userfaultfd_release_all(struct mm_struct *mm,
+                            struct userfaultfd_ctx *ctx)
+{
+       struct vm_area_struct *vma, *prev;
+       VMA_ITERATOR(vmi, mm, 0);
+
+       if (!mmget_not_zero(mm))
+               return;
+
+       /*
+        * Flush page faults out of all CPUs. NOTE: all page faults
+        * must be retried without returning VM_FAULT_SIGBUS if
+        * userfaultfd_ctx_get() succeeds but vma->vm_userfaultfd_ctx
+        * changes while handle_userfault released the mmap_lock. So
+        * it's critical that released is set to true (above), before
+        * taking the mmap_lock for writing.
+        */
+       mmap_write_lock(mm);
+       prev = NULL;
+       for_each_vma(vmi, vma) {
+               cond_resched();
+               BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
+                      !!(vma->vm_flags & __VM_UFFD_FLAGS));
+               if (vma->vm_userfaultfd_ctx.ctx != ctx) {
+                       prev = vma;
+                       continue;
+               }
+
+               vma = userfaultfd_clear_vma(&vmi, prev, vma,
+                                           vma->vm_start, vma->vm_end);
+               prev = vma;
+       }
+       mmap_write_unlock(mm);
+       mmput(mm);
+}