}
 
 /*
- * Please note that this function, unlike __get_user_pages will not
- * return 0 for nr_pages > 0 without FOLL_NOWAIT
+ * Locking: (*locked == 1) means that the mmap_lock has already been acquired by
+ * the caller. This function may drop the mmap_lock. If it does so, then it will
+ * set (*locked = 0).
+ *
+ * (*locked == 0) means that the caller expects this function to acquire and
+ * drop the mmap_lock. Therefore, the value of *locked will still be zero when
+ * the function returns, even though it may have changed temporarily during
+ * function execution.
+ *
+ * Please note that this function, unlike __get_user_pages(), will not return 0
+ * for nr_pages > 0, unless FOLL_NOWAIT is used.
  */
 static __always_inline long __get_user_pages_locked(struct mm_struct *mm,
                                                unsigned long start,
                                                unsigned long nr_pages,
                                                struct page **pages,
                                                struct vm_area_struct **vmas,
                                                int *locked,
                                                unsigned int flags)
 {
        long ret, pages_done;
-       bool lock_dropped;
+       bool must_unlock = false;
 
        if (locked) {
                /* if VM_FAULT_RETRY can be returned, vmas become invalid */
                BUG_ON(vmas);
-               /* check caller initialized locked */
-               BUG_ON(*locked != 1);
+       }
+
+       /*
+        * The internal caller expects GUP to manage the lock internally and the
+        * lock must be released when this returns.
+        */
+       if (locked && !*locked) {
+               if (mmap_read_lock_killable(mm))
+                       return -EAGAIN;
+               must_unlock = true;
+               *locked = 1;
        }
 
        if (flags & FOLL_PIN)
                flags |= FOLL_GET;
 
        pages_done = 0;
-       lock_dropped = false;
        for (;;) {
                ret = __get_user_pages(mm, start, nr_pages, flags, pages,
                                       vmas, locked);
                if (likely(pages))
                        pages += ret;
                start += ret << PAGE_SHIFT;
-               lock_dropped = true;
+
+               /* The lock was temporarily dropped, so we must unlock later */
+               must_unlock = true;
 
 retry:
                        pages++;
                start += PAGE_SIZE;
        }
-       if (lock_dropped && *locked) {
+       if (must_unlock && *locked) {
                /*
-                * We must let the caller know we temporarily dropped the lock
-                * and so the critical section protected by it was lost.
+                * We either temporarily dropped the lock, or the caller
+                * requested that we both acquire and drop the lock. Either way,
+                * we must now unlock, and notify the caller of that state.
                 */
                mmap_read_unlock(mm);
                *locked = 0;
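
To make the locking contract described in the new comment concrete, here is a
minimal caller sketch. gup_walk_range() is a hypothetical wrapper, not part of
this patch; it shows both conventions an internal caller can pick:

	/* Hypothetical wrapper illustrating both *locked conventions. */
	static long gup_walk_range(struct mm_struct *mm, unsigned long start,
				   unsigned long nr_pages, struct page **pages,
				   unsigned int gup_flags, bool caller_locks)
	{
		int locked;
		long ret;

		if (caller_locks) {
			/* Caller holds mmap_lock; *locked == 1 on entry. */
			mmap_read_lock(mm);
			locked = 1;
			ret = __get_user_pages_locked(mm, start, nr_pages,
						      pages, NULL, &locked,
						      gup_flags);
			/* GUP may have dropped the lock; unlock only if held. */
			if (locked)
				mmap_read_unlock(mm);
		} else {
			/* GUP acquires and drops mmap_lock internally. */
			locked = 0;
			ret = __get_user_pages_locked(mm, start, nr_pages,
						      pages, NULL, &locked,
						      gup_flags);
		}
		return ret;
	}
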
                unsigned int foll_flags)
 {
        struct vm_area_struct *vma;
+       bool must_unlock = false;
        unsigned long vm_flags;
        long i;
 
+       if (!nr_pages)
+               return 0;
+
+       /*
+        * The internal caller expects GUP to manage the lock internally and the
+        * lock must be released when this returns.
+        */
+       if (locked && !*locked) {
+               if (mmap_read_lock_killable(mm))
+                       return -EAGAIN;
+               must_unlock = true;
+               *locked = 1;
+       }
+
        /* calculate required read or write permissions.
         * If FOLL_FORCE is set, we only require the "MAY" flags.
         */
        for (i = 0; i < nr_pages; i++) {
                vma = find_vma(mm, start);
                if (!vma)
-                       goto finish_or_fault;
+                       break;
 
                /* protect what we can, including chardevs */
                if ((vma->vm_flags & (VM_IO | VM_PFNMAP)) ||
                    !(vm_flags & vma->vm_flags))
-                       goto finish_or_fault;
+                       break;
 
                if (pages) {
                        pages[i] = virt_to_page((void *)start);
                start = (start + PAGE_SIZE) & PAGE_MASK;
        }
 
-       return i;
+       if (must_unlock && *locked) {
+               mmap_read_unlock(mm);
+               *locked = 0;
+       }
 
-finish_or_fault:
        return i ? : -EFAULT;
 }
 #endif /* !CONFIG_MMU */
 #ifdef CONFIG_ELF_CORE
 struct page *get_dump_page(unsigned long addr)
 {
-       struct mm_struct *mm = current->mm;
        struct page *page;
-       int locked = 1;
+       int locked = 0;
        int ret;
 
-       if (mmap_read_lock_killable(mm))
-               return NULL;
-       ret = __get_user_pages_locked(mm, addr, 1, &page, NULL, &locked,
+       ret = __get_user_pages_locked(current->mm, addr, 1, &page, NULL,
+                                     &locked,
                                      FOLL_FORCE | FOLL_DUMP | FOLL_GET);
-       if (locked)
-               mmap_read_unlock(mm);
        return (ret == 1) ? page : NULL;
 }
 #endif /* CONFIG_ELF_CORE */
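
For context, a sketch of how a coredump-style consumer uses get_dump_page()
after this change; emit_page_to_core() is a hypothetical helper, not the
actual fs/coredump.c code:

	page = get_dump_page(addr);
	if (page) {
		/* FOLL_GET was used, so we hold a reference; drop it. */
		emit_page_to_core(cprm, page);	/* hypothetical */
		put_page(page);
	} else {
		/* Unmapped or suppressed page: dump zeros instead. */
		dump_skip(cprm, PAGE_SIZE);
	}
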
                                  int *locked,
                                  unsigned int gup_flags)
 {
-       bool must_unlock = false;
        unsigned int flags;
        long rc, nr_pinned_pages;
 
-       if (locked && WARN_ON_ONCE(!*locked))
-               return -EINVAL;
-
        if (!(gup_flags & FOLL_LONGTERM))
                return __get_user_pages_locked(mm, start, nr_pages, pages, vmas,
                                               locked, gup_flags);
                return -EINVAL;
        flags = memalloc_pin_save();
        do {
-               if (locked && !*locked) {
-                       mmap_read_lock(mm);
-                       must_unlock = true;
-                       *locked = 1;
-               }
                nr_pinned_pages = __get_user_pages_locked(mm, start, nr_pages,
                                                          pages, vmas, locked,
                                                          gup_flags);
                rc = check_and_migrate_movable_pages(nr_pinned_pages, pages);
        } while (rc == -EAGAIN);
        memalloc_pin_restore(flags);
-
-       if (locked && *locked && must_unlock) {
-               mmap_read_unlock(mm);
-               *locked = 0;
-       }
        return rc ? rc : nr_pinned_pages;
 }
 
 long get_user_pages_unlocked(unsigned long start, unsigned long nr_pages,
                             struct page **pages, unsigned int gup_flags)
 {
-       struct mm_struct *mm = current->mm;
-       int locked = 1;
-       long ret;
+       int locked = 0;
 
-       mmap_read_lock(mm);
-       ret = __gup_longterm_locked(mm, start, nr_pages, pages, NULL, &locked,
-                                   gup_flags | FOLL_TOUCH);
-       if (locked)
-               mmap_read_unlock(mm);
-       return ret;
+       return __gup_longterm_locked(current->mm, start, nr_pages, pages, NULL,
+                                    &locked, gup_flags | FOLL_TOUCH);
 }
 EXPORT_SYMBOL(get_user_pages_unlocked);
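
A minimal caller sketch for the simplified helper (uaddr is an assumed user
address set up elsewhere): no mmap_lock handling is needed, and each returned
page carries a reference that must be dropped with put_page():

	struct page *pages[1];
	long got;

	got = get_user_pages_unlocked(uaddr, 1, pages, FOLL_WRITE);
	if (got == 1) {
		/* ... access the user page ... */
		put_page(pages[0]);
	}
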
 
 {
        unsigned long len, end;
        unsigned long nr_pinned;
+       int locked = 0;
        int ret;
 
        if (WARN_ON_ONCE(gup_flags & ~(FOLL_WRITE | FOLL_LONGTERM |
        /* Slow path: try to get the remaining pages with get_user_pages */
        start += nr_pinned << PAGE_SHIFT;
        pages += nr_pinned;
-       ret = get_user_pages_unlocked(start, nr_pages - nr_pinned, pages,
-                                     gup_flags);
+       ret = __gup_longterm_locked(current->mm, start, nr_pages - nr_pinned,
+                                   pages, NULL, &locked,
+                                   gup_flags | FOLL_TOUCH);
        if (ret < 0) {
                /*
                 * The caller has to unpin the pages we already pinned so
+       int locked = 0;
+
        /* FOLL_GET and FOLL_PIN are mutually exclusive. */
        if (WARN_ON_ONCE(gup_flags & FOLL_GET))
                return -EINVAL;
 
        if (WARN_ON_ONCE(!pages))
                return -EINVAL;
 
-       gup_flags |= FOLL_PIN;
-       return get_user_pages_unlocked(start, nr_pages, pages, gup_flags);
+       gup_flags |= FOLL_PIN | FOLL_TOUCH;
+       return __gup_longterm_locked(current->mm, start, nr_pages, pages, NULL,
+                                    &locked, gup_flags);
 }
 EXPORT_SYMBOL(pin_user_pages_unlocked);
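
Likewise for the pin API, a hedged driver-style sketch (uaddr, nr_pages and
pages are assumed to be set up by the caller); FOLL_PIN pages are released
with unpin_user_pages(), not put_page():

	long pinned;

	pinned = pin_user_pages_unlocked(uaddr, nr_pages, pages,
					 FOLL_WRITE | FOLL_LONGTERM);
	if (pinned < 0)
		return pinned;

	/* ... set up DMA to/from the pinned pages ... */

	unpin_user_pages(pages, pinned);
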