return stage2_ptep_test_and_clear_young((pte_t *)pmd);
 }
 
+static int stage2_pudp_test_and_clear_young(pud_t *pud)
+{
+       return stage2_ptep_test_and_clear_young((pte_t *)pud);
+}
+
 /**
  * kvm_phys_addr_ioremap - map a device range to guest IPA
  *
 
 static int kvm_age_hva_handler(struct kvm *kvm, gpa_t gpa, u64 size, void *data)
 {
+       pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
 
-       WARN_ON(size != PAGE_SIZE && size != PMD_SIZE);
-       pmd = stage2_get_pmd(kvm, NULL, gpa);
-       if (!pmd || pmd_none(*pmd))     /* Nothing there */
+       WARN_ON(size != PAGE_SIZE && size != PMD_SIZE && size != PUD_SIZE);
+       if (!stage2_get_leaf_entry(kvm, gpa, &pud, &pmd, &pte))
                return 0;
 
-       if (pmd_thp_or_huge(*pmd))      /* THP, HugeTLB */
+       if (pud)
+               return stage2_pudp_test_and_clear_young(pud);
+       else if (pmd)
                return stage2_pmdp_test_and_clear_young(pmd);
-
-       pte = pte_offset_kernel(pmd, gpa);
-       if (pte_none(*pte))
-               return 0;
-
-       return stage2_ptep_test_and_clear_young(pte);
+       else
+               return stage2_ptep_test_and_clear_young(pte);
 }
 
 static int kvm_test_age_hva_handler(struct kvm *kvm, gpa_t gpa, u64 size, void *data)
 {
+       pud_t *pud;
        pmd_t *pmd;
        pte_t *pte;
 
-       WARN_ON(size != PAGE_SIZE && size != PMD_SIZE);
-       pmd = stage2_get_pmd(kvm, NULL, gpa);
-       if (!pmd || pmd_none(*pmd))     /* Nothing there */
+       WARN_ON(size != PAGE_SIZE && size != PMD_SIZE && size != PUD_SIZE);
+       if (!stage2_get_leaf_entry(kvm, gpa, &pud, &pmd, &pte))
                return 0;
 
-       if (pmd_thp_or_huge(*pmd))              /* THP, HugeTLB */
+       if (pud)
+               return kvm_s2pud_young(*pud);
+       else if (pmd)
                return pmd_young(*pmd);
-
-       pte = pte_offset_kernel(pmd, gpa);
-       if (!pte_none(*pte))            /* Just a page... */
+       else
                return pte_young(*pte);
-
-       return 0;
 }
 
 int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)