return direct_page_fault(vcpu, fault);
 }
 
+static int kvm_tdp_map_page(struct kvm_vcpu *vcpu, gpa_t gpa, u64 error_code,
+                           u8 *level)
+{
+       int r;
+
+       /*
+        * Restrict to TDP page fault, since that's the only case where the MMU
+        * is indexed by GPA.
+        */
+       if (vcpu->arch.mmu->page_fault != kvm_tdp_page_fault)
+               return -EOPNOTSUPP;
+
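+       /*
+        * Retry until the fault is resolved, bailing out if a signal is
+        * pending so userspace can interrupt a long prefault, and
+        * rescheduling between attempts to avoid hogging the CPU.
+        */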
+       do {
+               if (signal_pending(current))
+                       return -EINTR;
+               cond_resched();
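+               /*
+                * Map the GPA as a prefetch, i.e. without emulating a real
+                * guest access; the installed mapping level is returned in
+                * *level.
+                */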
+               r = kvm_mmu_do_page_fault(vcpu, gpa, error_code, true, NULL, level);
+       } while (r == RET_PF_RETRY);
+
+       if (r < 0)
+               return r;
+
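+       /*
+        * Translate the fault handler's result into an errno for the
+        * prefault ioctl: a fixed or already-present mapping is success,
+        * while a GPA that requires emulation (e.g. MMIO) has no memory
+        * to prefault.
+        */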
+       switch (r) {
+       case RET_PF_FIXED:
+       case RET_PF_SPURIOUS:
+               return 0;
+
+       case RET_PF_EMULATE:
+               return -ENOENT;
+
+       case RET_PF_RETRY:
+       case RET_PF_CONTINUE:
+       case RET_PF_INVALID:
+       default:
+               WARN_ONCE(1, "could not fix page fault during prefault");
+               return -EIO;
+       }
+}
+
+long kvm_arch_vcpu_pre_fault_memory(struct kvm_vcpu *vcpu,
+                                   struct kvm_pre_fault_memory *range)
+{
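+       /*
+        * Synthesize a fault on the final GPA translation, as opposed to a
+        * fault taken while walking the guest's own page tables.
+        */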
+       u64 error_code = PFERR_GUEST_FINAL_MASK;
+       u8 level = PG_LEVEL_4K;
+       u64 end;
+       int r;
+
+       /*
+        * Reload is efficient when called repeatedly, so it is safe to do
+        * on every iteration of the caller's prefault loop.
+        */
+       r = kvm_mmu_reload(vcpu);
+       if (r)
+               return r;
+
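+       /*
+        * If guest memory attributes mark the start of the range as
+        * private (e.g. for a VM backed by guest_memfd), the synthetic
+        * fault must be a private access so the private backing is mapped.
+        */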
+       if (kvm_arch_has_private_mem(vcpu->kvm) &&
+           kvm_mem_is_private(vcpu->kvm, gpa_to_gfn(range->gpa)))
+               error_code |= PFERR_PRIVATE_ACCESS;
+
+       /*
+        * Shadow paging uses GVA for KVM page faults, so restrict
+        * prefaulting to two-dimensional paging.
+        */
+       r = kvm_tdp_map_page(vcpu, range->gpa, error_code, &level);
+       if (r < 0)
+               return r;
+
+       /*
+        * If the mapping that covers range->gpa can use a huge page, the
+        * mapping may start below range->gpa or end beyond
+        * range->gpa + range->size.
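+        *
+        * For example, with a 2MiB mapping and range->gpa = 0x2100000, the
+        * mapping spans 0x2000000-0x2200000, so end = 0x2200000 and at most
+        * 0x100000 bytes of the range can be reported as prefaulted.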
+        */
+       end = (range->gpa & KVM_HPAGE_MASK(level)) + KVM_HPAGE_SIZE(level);
+       return min(range->size, end - range->gpa);
+}
+
 static void nonpaging_init_context(struct kvm_mmu *context)
 {
        context->page_fault = nonpaging_page_fault;