/*
  * Fetch a guest pte for a guest virtual address
  */
-static int FNAME(walk_addr)(struct guest_walker *walker,
-                           struct kvm_vcpu *vcpu, gva_t addr,
-                           int write_fault, int user_fault, int fetch_fault)
+static int FNAME(walk_addr_generic)(struct guest_walker *walker,
+                                   struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
+                                   gva_t addr, int write_fault,
+                                   int user_fault, int fetch_fault)
 {
        pt_element_t pte;
        gfn_t table_gfn;
 walk:
        present = true;
        eperm = rsvd_fault = false;
-       walker->level = vcpu->arch.mmu.root_level;
-       pte = vcpu->arch.mmu.get_cr3(vcpu);
+       walker->level = mmu->root_level;
+       pte           = mmu->get_cr3(vcpu);
+
 #if PTTYPE == 64
-       if (vcpu->arch.mmu.root_level == PT32E_ROOT_LEVEL) {
+       if (walker->level == PT32E_ROOT_LEVEL) {
                pte = kvm_pdptr_read(vcpu, (addr >> 30) & 3);
                trace_kvm_mmu_paging_element(pte, walker->level);
                if (!is_present_gpte(pte)) {
        }
 #endif
        ASSERT((!is_long_mode(vcpu) && is_pae(vcpu)) ||
-              (vcpu->arch.mmu.get_cr3(vcpu) & CR3_NONPAE_RESERVED_BITS) == 0);
+              (mmu->get_cr3(vcpu) & CR3_NONPAE_RESERVED_BITS) == 0);
 
        pt_access = ACC_ALL;
 
                                (PTTYPE == 64 || is_pse(vcpu))) ||
                    ((walker->level == PT_PDPE_LEVEL) &&
                                is_large_pte(pte) &&
-                               vcpu->arch.mmu.root_level == PT64_ROOT_LEVEL)) {
+                               mmu->root_level == PT64_ROOT_LEVEL)) {
                        int lvl = walker->level;
 
                        walker->gfn = gpte_to_gfn_lvl(pte, lvl);
        return 0;
 }
 
+static int FNAME(walk_addr)(struct guest_walker *walker,
+                           struct kvm_vcpu *vcpu, gva_t addr,
+                           int write_fault, int user_fault, int fetch_fault)
+{
+       return FNAME(walk_addr_generic)(walker, vcpu, &vcpu->arch.mmu, addr,
+                                       write_fault, user_fault, fetch_fault);
+}
+
 static void FNAME(update_pte)(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                              u64 *spte, const void *pte)
 {