#include <linux/khugepaged.h>
 #include <linux/rculist_nulls.h>
 #include <linux/random.h>
+#include <linux/mmu_notifier.h>
 
 #include <asm/tlbflush.h>
 #include <asm/div64.h>
        return false;
 }
 
-static unsigned long get_pte_pfn(pte_t pte, struct vm_area_struct *vma, unsigned long addr)
+static unsigned long get_pte_pfn(pte_t pte, struct vm_area_struct *vma, unsigned long addr,
+                                struct pglist_data *pgdat)
 {
        unsigned long pfn = pte_pfn(pte);

        if (WARN_ON_ONCE(pte_devmap(pte) || pte_special(pte)))
                return -1;

+       /*
+        * Skip this PTE if it was not accessed through the primary MMU and
+        * no MMU notifier is registered, i.e. no secondary MMU (e.g. a KVM
+        * guest) could have accessed the page behind our back.
+        */
+       if (!pte_young(pte) && !mm_has_notifiers(vma->vm_mm))
+               return -1;
+
        if (WARN_ON_ONCE(!pfn_valid(pfn)))
                return -1;

+       /* try to avoid unnecessary memory loads: only scan this node's pfns */
+       if (pfn < pgdat->node_start_pfn || pfn >= pgdat_end_pfn(pgdat))
+               return -1;
+
        return pfn;
 }
 
-static unsigned long get_pmd_pfn(pmd_t pmd, struct vm_area_struct *vma, unsigned long addr)
+static unsigned long get_pmd_pfn(pmd_t pmd, struct vm_area_struct *vma, unsigned long addr,
+                                struct pglist_data *pgdat)
 {
        unsigned long pfn = pmd_pfn(pmd);

        if (WARN_ON_ONCE(pmd_devmap(pmd)))
                return -1;

+       /*
+        * Skip this PMD if it was not accessed through the primary MMU and
+        * no MMU notifier is registered, i.e. no secondary MMU (e.g. a KVM
+        * guest) could have accessed the page behind our back.
+        */
+       if (!pmd_young(pmd) && !mm_has_notifiers(vma->vm_mm))
+               return -1;
+
        if (WARN_ON_ONCE(!pfn_valid(pfn)))
                return -1;

+       /* try to avoid unnecessary memory loads: only scan this node's pfns */
+       if (pfn < pgdat->node_start_pfn || pfn >= pgdat_end_pfn(pgdat))
+               return -1;
+
        return pfn;
 }
 
 {
        struct folio *folio;
 
-       /* try to avoid unnecessary memory loads */
-       if (pfn < pgdat->node_start_pfn || pfn >= pgdat_end_pfn(pgdat))
-               return NULL;
-
        folio = pfn_folio(pfn);
        if (folio_nid(folio) != pgdat->node_id)
                return NULL;
                total++;
                walk->mm_stats[MM_LEAF_TOTAL]++;
 
-               pfn = get_pte_pfn(ptent, args->vma, addr);
+               pfn = get_pte_pfn(ptent, args->vma, addr, pgdat);
                if (pfn == -1)
                        continue;
 
-               if (!pte_young(ptent)) {
-                       continue;
-               }
-
                folio = get_pfn_folio(pfn, memcg, pgdat, walk->can_swap);
                if (!folio)
                        continue;
 
-               if (!ptep_test_and_clear_young(args->vma, addr, pte + i))
-                       VM_WARN_ON_ONCE(true);
+               if (!ptep_clear_young_notify(args->vma, addr, pte + i))
+                       continue;
 
                young++;
                walk->mm_stats[MM_LEAF_YOUNG]++;
                /* don't round down the first address */
                addr = i ? (*first & PMD_MASK) + i * PMD_SIZE : *first;
 
-               pfn = get_pmd_pfn(pmd[i], vma, addr);
-               if (pfn == -1)
+               if (!pmd_present(pmd[i]))
                        goto next;
 
                if (!pmd_trans_huge(pmd[i])) {
-                       if (!walk->force_scan && should_clear_pmd_young())
+                       if (!walk->force_scan && should_clear_pmd_young() &&
+                           !mm_has_notifiers(args->mm))
                                pmdp_test_and_clear_young(vma, addr, pmd + i);
                        goto next;
                }
 
+               pfn = get_pmd_pfn(pmd[i], vma, addr, pgdat);
+               if (pfn == -1)
+                       goto next;
+
                folio = get_pfn_folio(pfn, memcg, pgdat, walk->can_swap);
                if (!folio)
                        goto next;
 
-               if (!pmdp_test_and_clear_young(vma, addr, pmd + i))
+               if (!pmdp_clear_young_notify(vma, addr, pmd + i))
                        goto next;
 
                walk->mm_stats[MM_LEAF_YOUNG]++;
                }
 
                if (pmd_trans_huge(val)) {
-                       unsigned long pfn = pmd_pfn(val);
                        struct pglist_data *pgdat = lruvec_pgdat(walk->lruvec);
+                       unsigned long pfn = get_pmd_pfn(val, vma, addr, pgdat);
 
                        walk->mm_stats[MM_LEAF_TOTAL]++;
 
-                       if (!pmd_young(val)) {
-                               continue;
-                       }
-
-                       /* try to avoid unnecessary memory loads */
-                       if (pfn < pgdat->node_start_pfn || pfn >= pgdat_end_pfn(pgdat))
-                               continue;
-
-                       walk_pmd_range_locked(pud, addr, vma, args, bitmap, &first);
+                       if (pfn != -1)
+                               walk_pmd_range_locked(pud, addr, vma, args, bitmap, &first);
                        continue;
                }
 
-               if (!walk->force_scan && should_clear_pmd_young()) {
+               if (!walk->force_scan && should_clear_pmd_young() &&
+                   !mm_has_notifiers(args->mm)) {
                        if (!pmd_young(val))
                                continue;
 
  * the PTE table to the Bloom filter. This forms a feedback loop between the
  * eviction and the aging.
  */
-void lru_gen_look_around(struct page_vma_mapped_walk *pvmw)
+bool lru_gen_look_around(struct page_vma_mapped_walk *pvmw)
 {
        int i;
        unsigned long start;
        unsigned long end;
        struct lru_gen_mm_walk *walk;
-       int young = 0;
+       int young = 1;
        pte_t *pte = pvmw->pte;
        unsigned long addr = pvmw->address;
        struct vm_area_struct *vma = pvmw->vma;
        lockdep_assert_held(pvmw->ptl);
        VM_WARN_ON_ONCE_FOLIO(folio_test_lru(folio), folio);
 
+       if (!ptep_clear_young_notify(vma, addr, pte))
+               return false;
+
        if (spin_is_contended(pvmw->ptl))
-               return;
+               return true;
 
        /* exclude special VMAs containing anon pages from COW */
        if (vma->vm_flags & VM_SPECIAL)
-               return;
+               return true;
 
        /* avoid taking the LRU lock under the PTL when possible */
        walk = current->reclaim_state ? current->reclaim_state->mm_walk : NULL;
        start = max(addr & PMD_MASK, vma->vm_start);
        end = min(addr | ~PMD_MASK, vma->vm_end - 1) + 1;
 
+       if (end - start == PAGE_SIZE)
+               return true;
+
        if (end - start > MIN_LRU_BATCH * PAGE_SIZE) {
                if (addr - start < MIN_LRU_BATCH * PAGE_SIZE / 2)
                        end = start + MIN_LRU_BATCH * PAGE_SIZE;
 
        /* folio_update_gen() requires stable folio_memcg() */
        if (!mem_cgroup_trylock_pages(memcg))
-               return;
+               return true;
 
        arch_enter_lazy_mmu_mode();
 
                unsigned long pfn;
                pte_t ptent = ptep_get(pte + i);
 
-               pfn = get_pte_pfn(ptent, vma, addr);
+               pfn = get_pte_pfn(ptent, vma, addr, pgdat);
                if (pfn == -1)
                        continue;
 
-               if (!pte_young(ptent))
-                       continue;
-
                folio = get_pfn_folio(pfn, memcg, pgdat, can_swap);
                if (!folio)
                        continue;
 
-               if (!ptep_test_and_clear_young(vma, addr, pte + i))
-                       VM_WARN_ON_ONCE(true);
+               if (!ptep_clear_young_notify(vma, addr, pte + i))
+                       continue;
 
                young++;
 
        /* feedback from rmap walkers to page table walkers */
        if (mm_state && suitable_to_scan(i, young))
                update_bloom_filter(mm_state, max_seq, pvmw->pmd);
+
+       return true;
 }
 
 /******************************************************************************