vm_fault_t do_huge_pmd_numa_page(struct vm_fault *vmf)
 {
        struct vm_area_struct *vma = vmf->vma;
-       pmd_t oldpmd = vmf->orig_pmd;
-       pmd_t pmd;
        struct folio *folio;
        unsigned long haddr = vmf->address & HPAGE_PMD_MASK;
        int nid = NUMA_NO_NODE;
-       int target_nid, last_cpupid = (-1 & LAST_CPUPID_MASK);
+       int target_nid, last_cpupid;
+       pmd_t pmd, old_pmd;
        bool writable = false;
        int flags = 0;
 
        vmf->ptl = pmd_lock(vma->vm_mm, vmf->pmd);
-       if (unlikely(!pmd_same(oldpmd, *vmf->pmd))) {
+       old_pmd = pmdp_get(vmf->pmd);
+
+       if (unlikely(!pmd_same(old_pmd, vmf->orig_pmd))) {
                spin_unlock(vmf->ptl);
                return 0;
        }
 
-       pmd = pmd_modify(oldpmd, vma->vm_page_prot);
+       pmd = pmd_modify(old_pmd, vma->vm_page_prot);
 
        /*
         * Detect now whether the PMD could be writable; this information
        if (!folio)
                goto out_map;
 
-       /* See similar comment in do_numa_page for explanation */
-       if (!writable)
-               flags |= TNF_NO_GROUP;
-
        nid = folio_nid(folio);
-       /*
-        * For memory tiering mode, cpupid of slow memory page is used
-        * to record page access time.  So use default value.
-        */
-       if (!folio_use_access_time(folio))
-               last_cpupid = folio_last_cpupid(folio);
-       target_nid = numa_migrate_prep(folio, vmf, haddr, nid, &flags);
+
+       target_nid = numa_migrate_check(folio, vmf, haddr, &flags, writable,
+                                       &last_cpupid);
        if (target_nid == NUMA_NO_NODE)
                goto out_map;
        if (migrate_misplaced_folio_prepare(folio, vma, target_nid)) {
 
        flags |= TNF_MIGRATE_FAIL;
        vmf->ptl = pmd_lock(vma->vm_mm, vmf->pmd);
-       if (unlikely(!pmd_same(oldpmd, *vmf->pmd))) {
+       if (unlikely(!pmd_same(pmdp_get(vmf->pmd), vmf->orig_pmd))) {
                spin_unlock(vmf->ptl);
                return 0;
        }
 out_map:
        /* Restore the PMD */
-       pmd = pmd_modify(oldpmd, vma->vm_page_prot);
+       pmd = pmd_modify(pmdp_get(vmf->pmd), vma->vm_page_prot);
        pmd = pmd_mkyoung(pmd);
        if (writable)
                pmd = pmd_mkwrite(pmd, vma);
 
 
 void __vunmap_range_noflush(unsigned long start, unsigned long end);
 
-int numa_migrate_prep(struct folio *folio, struct vm_fault *vmf,
-                     unsigned long addr, int page_nid, int *flags);
+int numa_migrate_check(struct folio *folio, struct vm_fault *vmf,
+                     unsigned long addr, int *flags, bool writable,
+                     int *last_cpupid);
 
 void free_zone_device_folio(struct folio *folio);
 int migrate_device_coherent_page(struct page *page);
 
        return ret;
 }
 
-int numa_migrate_prep(struct folio *folio, struct vm_fault *vmf,
-                     unsigned long addr, int page_nid, int *flags)
+int numa_migrate_check(struct folio *folio, struct vm_fault *vmf,
+                     unsigned long addr, int *flags,
+                     bool writable, int *last_cpupid)
 {
        struct vm_area_struct *vma = vmf->vma;
 
+       /*
+        * Avoid grouping on RO pages in general. RO pages shouldn't hurt as
+        * much anyway since they can be in shared cache state. This misses
+        * the case where a mapping is writable but the process never writes
+        * to it but pte_write gets cleared during protection updates and
+        * pte_dirty has unpredictable behaviour between PTE scan updates,
+        * background writeback, dirty balancing and application behaviour.
+        */
+       if (!writable)
+               *flags |= TNF_NO_GROUP;
+
+       /*
+        * Flag if the folio is shared between multiple address spaces. This
+        * is later used when determining whether to group tasks together
+        */
+       if (folio_likely_mapped_shared(folio) && (vma->vm_flags & VM_SHARED))
+               *flags |= TNF_SHARED;
+       /*
+        * For memory tiering mode, cpupid of slow memory page is used
+        * to record page access time.  So use default value.
+        */
+       if (folio_use_access_time(folio))
+               *last_cpupid = (-1 & LAST_CPUPID_MASK);
+       else
+               *last_cpupid = folio_last_cpupid(folio);
+
        /* Record the current PID acceesing VMA */
        vma_set_access_pid_bit(vma);
 
        count_vm_numa_event(NUMA_HINT_FAULTS);
-       if (page_nid == numa_node_id()) {
+       if (folio_nid(folio) == numa_node_id()) {
                count_vm_numa_event(NUMA_HINT_FAULTS_LOCAL);
                *flags |= TNF_FAULT_LOCAL;
        }
        if (!folio || folio_is_zone_device(folio))
                goto out_map;
 
-       /*
-        * Avoid grouping on RO pages in general. RO pages shouldn't hurt as
-        * much anyway since they can be in shared cache state. This misses
-        * the case where a mapping is writable but the process never writes
-        * to it but pte_write gets cleared during protection updates and
-        * pte_dirty has unpredictable behaviour between PTE scan updates,
-        * background writeback, dirty balancing and application behaviour.
-        */
-       if (!writable)
-               flags |= TNF_NO_GROUP;
-
-       /*
-        * Flag if the folio is shared between multiple address spaces. This
-        * is later used when determining whether to group tasks together
-        */
-       if (folio_likely_mapped_shared(folio) && (vma->vm_flags & VM_SHARED))
-               flags |= TNF_SHARED;
-
        nid = folio_nid(folio);
        nr_pages = folio_nr_pages(folio);
-       /*
-        * For memory tiering mode, cpupid of slow memory page is used
-        * to record page access time.  So use default value.
-        */
-       if (folio_use_access_time(folio))
-               last_cpupid = (-1 & LAST_CPUPID_MASK);
-       else
-               last_cpupid = folio_last_cpupid(folio);
-       target_nid = numa_migrate_prep(folio, vmf, vmf->address, nid, &flags);
+
+       target_nid = numa_migrate_check(folio, vmf, vmf->address, &flags,
+                                       writable, &last_cpupid);
        if (target_nid == NUMA_NO_NODE)
                goto out_map;
        if (migrate_misplaced_folio_prepare(folio, vma, target_nid)) {