}
 
 #ifdef CONFIG_ARCH_HAS_PTE_SPECIAL
-static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
-                        unsigned int flags, struct page **pages, int *nr)
+/*
+ * Fast-gup relies on pte change detection to avoid concurrent pgtable
+ * operations.
+ *
+ * To pin the page, fast-gup needs to do below in order:
+ * (1) pin the page (by prefetching pte), then (2) check pte not changed.
+ *
+ * For the rest of pgtable operations where pgtable updates can be racy
+ * with fast-gup, we need to do (1) clear pte, then (2) check whether page
+ * is pinned.
+ *
+ * Above will work for all pte-level operations, including THP split.
+ *
+ * For THP collapse, it's a bit more complicated because fast-gup may be
+ * walking a pgtable page that is being freed (pte is still valid but pmd
+ * can be cleared already).  To avoid race in such condition, we need to
+ * also check pmd here to make sure pmd doesn't change (corresponds to
+ * pmdp_collapse_flush() in the THP collapse code path).
+ */
+static int gup_pte_range(pmd_t pmd, pmd_t *pmdp, unsigned long addr,
+                        unsigned long end, unsigned int flags,
+                        struct page **pages, int *nr)
 {
        struct dev_pagemap *pgmap = NULL;
        int nr_start = *nr, ret = 0;
                        goto pte_unmap;
                }
 
-               if (unlikely(pte_val(pte) != pte_val(*ptep))) {
+               if (unlikely(pmd_val(pmd) != pmd_val(*pmdp)) ||
+                   unlikely(pte_val(pte) != pte_val(*ptep))) {
                        gup_put_folio(folio, 1, flags);
                        goto pte_unmap;
                }
  * get_user_pages_fast_only implementation that can pin pages. Thus it's still
  * useful to have gup_huge_pmd even if we can't operate on ptes.
  */
-static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
-                        unsigned int flags, struct page **pages, int *nr)
+static int gup_pte_range(pmd_t pmd, pmd_t *pmdp, unsigned long addr,
+                        unsigned long end, unsigned int flags,
+                        struct page **pages, int *nr)
 {
        return 0;
 }
                        if (!gup_huge_pd(__hugepd(pmd_val(pmd)), addr,
                                         PMD_SHIFT, next, flags, pages, nr))
                                return 0;
-               } else if (!gup_pte_range(pmd, addr, next, flags, pages, nr))
+               } else if (!gup_pte_range(pmd, pmdp, addr, next, flags, pages, nr))
                        return 0;
        } while (pmdp++, addr = next, addr != end);
 
 
 
        pmd_ptl = pmd_lock(mm, pmd); /* probably unnecessary */
        /*
-        * After this gup_fast can't run anymore. This also removes
-        * any huge TLB entry from the CPU so we won't allow
-        * huge and small TLB entries for the same virtual address
-        * to avoid the risk of CPU bugs in that area.
+        * This removes any huge TLB entry from the CPU so we won't allow
+        * huge and small TLB entries for the same virtual address to
+        * avoid the risk of CPU bugs in that area.
+        *
+        * Parallel fast GUP is fine since fast GUP will back off when
+        * it detects PMD is changed.
         */
        _pmd = pmdp_collapse_flush(vma, address, pmd);
        spin_unlock(pmd_ptl);