!kvm_is_zone_device_pfn(pfn) && level == PT_PAGE_TABLE_LEVEL &&
            PageTransCompoundMap(pfn_to_page(pfn))) {
                unsigned long mask;
+
                /*
-                * mmu_notifier_retry was successful and we hold the
-                * mmu_lock here, so the pmd can't become splitting
-                * from under us, and in turn
-                * __split_huge_page_refcount() can't run from under
-                * us and we can safely transfer the refcount from
-                * PG_tail to PG_head as we switch the pfn to tail to
-                * head.
+                * mmu_notifier_retry() was successful and mmu_lock is held, so
+                * the pmd can't be split from under us.
                 */
                *levelp = level = PT_DIRECTORY_LEVEL;
                mask = KVM_PAGES_PER_HPAGE(level) - 1;
                VM_BUG_ON((gfn & mask) != (pfn & mask));
-               if (pfn & mask) {
-                       kvm_release_pfn_clean(pfn);
-                       pfn &= ~mask;
-                       kvm_get_pfn(pfn);
-                       *pfnp = pfn;
-               }
+               *pfnp = pfn & ~mask;
        }
 }
 
 }
 
 static int __direct_map(struct kvm_vcpu *vcpu, gpa_t gpa, int write,
-                       int map_writable, int level, kvm_pfn_t pfn,
-                       bool prefault, bool account_disallowed_nx_lpage)
+                       int map_writable, int level, int max_level,
+                       kvm_pfn_t pfn, bool prefault,
+                       bool account_disallowed_nx_lpage)
 {
        struct kvm_shadow_walk_iterator it;
        struct kvm_mmu_page *sp;
        if (!VALID_PAGE(vcpu->arch.mmu->root_hpa))
                return RET_PF_RETRY;
 
+       if (likely(max_level > PT_PAGE_TABLE_LEVEL))
+               transparent_hugepage_adjust(vcpu, gfn, &pfn, &level);
+
        trace_kvm_mmu_spte_requested(gpa, level, pfn);
        for_each_shadow_entry(vcpu, gpa, it) {
                /*
                goto out_unlock;
        if (make_mmu_pages_available(vcpu) < 0)
                goto out_unlock;
-       if (likely(max_level > PT_PAGE_TABLE_LEVEL))
-               transparent_hugepage_adjust(vcpu, gfn, &pfn, &level);
-       r = __direct_map(vcpu, gpa, write, map_writable, level, pfn, prefault,
-                        is_tdp && lpage_disallowed);
+       r = __direct_map(vcpu, gpa, write, map_writable, level, max_level, pfn,
+                        prefault, is_tdp && lpage_disallowed);
 
 out_unlock:
        spin_unlock(&vcpu->kvm->mmu_lock);
 
  */
 static int FNAME(fetch)(struct kvm_vcpu *vcpu, gpa_t addr,
                         struct guest_walker *gw,
-                        int write_fault, int hlevel,
+                        int write_fault, int hlevel, int max_level,
                         kvm_pfn_t pfn, bool map_writable, bool prefault,
                         bool lpage_disallowed)
 {
        gfn = gw->gfn | ((addr & PT_LVL_OFFSET_MASK(gw->level)) >> PAGE_SHIFT);
        base_gfn = gfn;
 
+       if (max_level > PT_PAGE_TABLE_LEVEL)
+               transparent_hugepage_adjust(vcpu, gw->gfn, &pfn, &hlevel);
+
        trace_kvm_mmu_spte_requested(addr, gw->level, pfn);
 
        for (; shadow_walk_okay(&it); shadow_walk_next(&it)) {
        kvm_mmu_audit(vcpu, AUDIT_PRE_PAGE_FAULT);
        if (make_mmu_pages_available(vcpu) < 0)
                goto out_unlock;
-       if (max_level > PT_PAGE_TABLE_LEVEL)
-               transparent_hugepage_adjust(vcpu, walker.gfn, &pfn, &level);
-       r = FNAME(fetch)(vcpu, addr, &walker, write_fault,
-                        level, pfn, map_writable, prefault, lpage_disallowed);
+       r = FNAME(fetch)(vcpu, addr, &walker, write_fault, level, max_level,
+                        pfn, map_writable, prefault, lpage_disallowed);
        kvm_mmu_audit(vcpu, AUDIT_POST_PAGE_FAULT);
 
 out_unlock: