From 2e9976edcf393471932fb222668a48afc107b810 Mon Sep 17 00:00:00 2001 From: "Liam R. Howlett" Date: Tue, 16 Mar 2021 15:58:18 -0400 Subject: [PATCH] userfaultfd rcu fix Signed-off-by: Liam R. Howlett --- fs/userfaultfd.c | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 1b2d1e3ed902..f22b6133f31e 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -609,12 +609,14 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, /* the various vma->vm_userfaultfd_ctx still points to it */ mmap_write_lock(mm); + mas_lock(&mas); mas_for_each(&mas, vma, ULONG_MAX) { if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; vma->vm_flags &= ~(VM_UFFD_WP | VM_UFFD_MISSING); } } + mas_unlock(&mas); mmap_write_unlock(mm); userfaultfd_ctx_put(release_new_ctx); @@ -801,6 +803,7 @@ int userfaultfd_unmap_prep(struct vm_area_struct *vma, { MA_STATE(mas, &vma->vm_mm->mm_mt, vma->vm_start, vma->vm_start); + rcu_read_lock(); mas_for_each(&mas, vma, end) { struct userfaultfd_unmap_ctx *unmap_ctx; struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; @@ -820,6 +823,7 @@ int userfaultfd_unmap_prep(struct vm_area_struct *vma, unmap_ctx->end = end; list_add_tail(&unmap_ctx->list, unmaps); } + rcu_read_unlock(); return 0; } @@ -867,9 +871,14 @@ static int userfaultfd_release(struct inode *inode, struct file *file) * taking the mmap_lock for writing. */ mmap_write_lock(mm); + mas_lock(&mas); prev = NULL; mas_for_each(&mas, vma, ULONG_MAX) { + mas_unlock(&mas); + mas_pause(&mas); cond_resched(); + mas_lock(&mas); + BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ !!(vma->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP))); if (vma->vm_userfaultfd_ctx.ctx != ctx) { @@ -889,6 +898,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) vma->vm_flags = new_flags; vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; } + mas_unlock(&mas); mmap_write_unlock(mm); mmput(mm); wakeup: @@ -1316,6 +1326,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, goto out; mmap_write_lock(mm); + rcu_read_lock(); vma = find_vma_prev(mm, start, &prev); if (!vma) goto out_unlock; @@ -1343,7 +1354,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, basic_ioctls = false; mas_set(&mas, vma->vm_start); mas_for_each(&mas, cur, end) { + rcu_read_unlock(); + mas_pause(&mas); cond_resched(); + rcu_read_lock(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ !!(cur->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP))); @@ -1461,6 +1475,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, vma = vma_next(mm, vma); } while (vma && vma->vm_start < end); out_unlock: + rcu_read_unlock(); mmap_write_unlock(mm); mmput(mm); if (!ret) { @@ -1543,10 +1558,13 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, */ found = false; ret = -EINVAL; + rcu_read_lock(); mas_set(&mas, vma->vm_start); - mas_for_each(&mas, cur, end) { + rcu_read_unlock(); + mas_pause(&mas); cond_resched(); + rcu_read_lock(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ !!(cur->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP))); @@ -1563,6 +1581,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, found = true; } + rcu_read_unlock(); BUG_ON(!found); if (vma->vm_start < start) -- 2.50.1