atomic_sub(refs, compound_pincount_ptr(page));
 }
 
+/* Equivalent to calling put_page() @refs times. */
+static void put_page_refs(struct page *page, int refs)
+{
+#ifdef CONFIG_DEBUG_VM
+       if (VM_WARN_ON_ONCE_PAGE(page_ref_count(page) < refs, page))
+               return;
+#endif
+
+       /*
+        * Calling put_page() for each ref is unnecessarily slow. Only the last
+        * ref needs a put_page().
+        */
+       if (refs > 1)
+               page_ref_sub(page, refs - 1);
+       put_page(page);
+}
+
 /*
  * Return the compound head page with ref appropriately incremented,
  * or NULL if that failed.
                return NULL;
        if (unlikely(!page_cache_add_speculative(head, refs)))
                return NULL;
+
+       /*
+        * At this point we have a stable reference to the head page; but it
+        * could be that between the compound_head() lookup and the refcount
+        * increment, the compound page was split, in which case we'd end up
+        * holding a reference on a page that has nothing to do with the page
+        * we were given anymore.
+        * So now that the head page is stable, recheck that the pages still
+        * belong together.
+        */
+       if (unlikely(compound_head(page) != head)) {
+               put_page_refs(head, refs);
+               return NULL;
+       }
+
        return head;
 }
 
                             !is_pinnable_page(page)))
                        return NULL;
 
+               /*
+                * CAUTION: Don't use compound_head() on the page before this
+                * point, the result won't be stable.
+                */
+               page = try_get_compound_head(page, refs);
+               if (!page)
+                       return NULL;
+
                /*
                 * When pinning a compound page of order > 1 (which is what
                 * hpage_pincount_available() checks for), use an exact count to
                 * However, be sure to *also* increment the normal page refcount
                 * field at least once, so that the page really is pinned.
                 */
-               if (!hpage_pincount_available(page))
-                       refs *= GUP_PIN_COUNTING_BIAS;
-
-               page = try_get_compound_head(page, refs);
-               if (!page)
-                       return NULL;
-
                if (hpage_pincount_available(page))
                        hpage_pincount_add(page, refs);
+               else
+                       page_ref_add(page, refs * (GUP_PIN_COUNTING_BIAS - 1));
 
                mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_ACQUIRED,
                                    orig_refs);
                        refs *= GUP_PIN_COUNTING_BIAS;
        }
 
-       VM_BUG_ON_PAGE(page_ref_count(page) < refs, page);
-       /*
-        * Calling put_page() for each ref is unnecessarily slow. Only the last
-        * ref needs a put_page().
-        */
-       if (refs > 1)
-               page_ref_sub(page, refs - 1);
-       put_page(page);
+       put_page_refs(page, refs);
 }
 
 /**