*/
 struct guest_walker {
        int level;
+       unsigned max_level;
        gfn_t table_gfn[PT_MAX_FULL_LEVELS];
        pt_element_t ptes[PT_MAX_FULL_LEVELS];
        pt_element_t prefetch_ptes[PTE_PREFETCH_NUM];
        gpa_t pte_gpa[PT_MAX_FULL_LEVELS];
+       pt_element_t __user *ptep_user[PT_MAX_FULL_LEVELS];
        unsigned pt_access;
        unsigned pte_access;
        gfn_t gfn;
        return false;
 }
 
+/*
+ * Set the accessed (and, on a write fault, dirty) bits in the guest page
+ * table entries cached by a prior walk, from the root level
+ * (walker->max_level) down to the final translation level (walker->level).
+ * The dirty bit is only set on the last-level entry.
+ *
+ * Returns 0 on success; otherwise propagates the non-zero result of
+ * FNAME(cmpxchg_gpte) — the caller treats a negative value as a fault and
+ * a positive value as "lost a race with another updater, retry the walk".
+ */
+static int FNAME(update_accessed_dirty_bits)(struct kvm_vcpu *vcpu,
+                                            struct kvm_mmu *mmu,
+                                            struct guest_walker *walker,
+                                            int write_fault)
+{
+       unsigned level, index;
+       pt_element_t pte, orig_pte;
+       pt_element_t __user *ptep_user;
+       gfn_t table_gfn;
+       int ret;
+
+       for (level = walker->max_level; level >= walker->level; --level) {
+               pte = orig_pte = walker->ptes[level - 1];
+               table_gfn = walker->table_gfn[level - 1];
+               ptep_user = walker->ptep_user[level - 1];
+               index = offset_in_page(ptep_user) / sizeof(pt_element_t);
+               if (!(pte & PT_ACCESSED_MASK)) {
+                       trace_kvm_mmu_set_accessed_bit(table_gfn, index, sizeof(pte));
+                       pte |= PT_ACCESSED_MASK;
+               }
+               /* Only the final-level entry carries the dirty bit. */
+               if (level == walker->level && write_fault && !is_dirty_gpte(pte)) {
+                       trace_kvm_mmu_set_dirty_bit(table_gfn, index, sizeof(pte));
+                       pte |= PT_DIRTY_MASK;
+               }
+               if (pte == orig_pte)
+                       continue;
+
+               ret = FNAME(cmpxchg_gpte)(vcpu, mmu, ptep_user, index, orig_pte, pte);
+               if (ret)
+                       return ret;
+
+               mark_page_dirty(vcpu->kvm, table_gfn);
+               /*
+                * Write back to the same slot we read above: entries are
+                * indexed by (level - 1) throughout this loop.  Using
+                * walker->ptes[level] would skip the entry just updated and
+                * overrun the array when level == PT_MAX_FULL_LEVELS.
+                */
+               walker->ptes[level - 1] = pte;
+       }
+       return 0;
+}
+
 /*
  * Fetch a guest pte for a guest virtual address
  */
                                    struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
                                    gva_t addr, u32 access)
 {
+       int ret;
        pt_element_t pte;
        pt_element_t __user *uninitialized_var(ptep_user);
        gfn_t table_gfn;
                --walker->level;
        }
 #endif
+       walker->max_level = walker->level;
        ASSERT((!is_long_mode(vcpu) && is_pae(vcpu)) ||
               (mmu->get_cr3(vcpu) & CR3_NONPAE_RESERVED_BITS) == 0);
 
                ptep_user = (pt_element_t __user *)((void *)host_addr + offset);
                if (unlikely(__copy_from_user(&pte, ptep_user, sizeof(pte))))
                        goto error;
+               walker->ptep_user[walker->level - 1] = ptep_user;
 
                trace_kvm_mmu_paging_element(pte, walker->level);
 
                                        eperm = true;
                }
 
-               if (!eperm && unlikely(!(pte & PT_ACCESSED_MASK))) {
-                       int ret;
-                       trace_kvm_mmu_set_accessed_bit(table_gfn, index,
-                                                      sizeof(pte));
-                       ret = FNAME(cmpxchg_gpte)(vcpu, mmu, ptep_user, index,
-                                                 pte, pte|PT_ACCESSED_MASK);
-                       if (unlikely(ret < 0))
-                               goto error;
-                       else if (ret)
-                               goto retry_walk;
-
-                       mark_page_dirty(vcpu->kvm, table_gfn);
-                       pte |= PT_ACCESSED_MASK;
-               }
-
                walker->ptes[walker->level - 1] = pte;
 
                if (last_gpte) {
 
        if (!write_fault)
                protect_clean_gpte(&pte_access, pte);
-       else if (unlikely(!is_dirty_gpte(pte))) {
-               int ret;
 
-               trace_kvm_mmu_set_dirty_bit(table_gfn, index, sizeof(pte));
-               ret = FNAME(cmpxchg_gpte)(vcpu, mmu, ptep_user, index,
-                                         pte, pte|PT_DIRTY_MASK);
-               if (unlikely(ret < 0))
-                       goto error;
-               else if (ret)
-                       goto retry_walk;
-
-               mark_page_dirty(vcpu->kvm, table_gfn);
-               pte |= PT_DIRTY_MASK;
-               walker->ptes[walker->level - 1] = pte;
-       }
+       ret = FNAME(update_accessed_dirty_bits)(vcpu, mmu, walker, write_fault);
+       if (unlikely(ret < 0))
+               goto error;
+       else if (ret)
+               goto retry_walk;
 
        walker->pt_access = pt_access;
        walker->pte_access = pte_access;