 int kvm_mmu_reset_context(struct kvm_vcpu *vcpu);
 void kvm_mmu_slot_remove_write_access(struct kvm *kvm, int slot);
-int kvm_mmu_rmap_write_protect(struct kvm *kvm, u64 gfn,
-                              struct kvm_memory_slot *slot);
+void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
+                                    struct kvm_memory_slot *slot,
+                                    gfn_t gfn_offset, unsigned long mask);
 void kvm_mmu_zap_all(struct kvm *kvm);
 unsigned int kvm_mmu_calculate_mmu_pages(struct kvm *kvm);
 void kvm_mmu_change_mmu_pages(struct kvm *kvm, unsigned int kvm_nr_mmu_pages);
 
        return write_protected;
 }
 
-int kvm_mmu_rmap_write_protect(struct kvm *kvm, u64 gfn,
-                              struct kvm_memory_slot *slot)
+/**
+ * kvm_mmu_write_protect_pt_masked - write protect selected PT level pages
+ * @kvm: kvm instance
+ * @slot: slot to protect
+ * @gfn_offset: start of the BITS_PER_LONG pages we care about
+ * @mask: indicates which pages we should protect
+ *
+ * Used when we do not need to care about huge page mappings: e.g. during dirty
+ * logging we do not have any such mappings.
+ */
+void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
+                                    struct kvm_memory_slot *slot,
+                                    gfn_t gfn_offset, unsigned long mask)
 {
        unsigned long *rmapp;
-       int i, write_protected = 0;
 
-       for (i = PT_PAGE_TABLE_LEVEL;
-            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
-               rmapp = __gfn_to_rmap(gfn, i, slot);
-               write_protected |= __rmap_write_protect(kvm, rmapp, i);
-       }
+       while (mask) {
+               rmapp = &slot->rmap[gfn_offset + __ffs(mask)];
+               __rmap_write_protect(kvm, rmapp, PT_PAGE_TABLE_LEVEL);
+
-       return write_protected;
+               /* clear the first set bit */
+               mask &= mask - 1;
+       }
 }
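
The loop above does no work for clean pages: __ffs(mask) jumps straight to the lowest set bit and mask &= mask - 1 clears it, so the cost scales with the number of dirty pages in the word rather than with BITS_PER_LONG. A minimal user-space sketch of the idiom, with the kernel's __ffs() stood in for by a GCC builtin (the names here are illustrative, not the kernel's):

#include <stdio.h>

/* Stand-in for the kernel's __ffs(): index of the lowest set bit. */
static unsigned long lowest_set_bit(unsigned long mask)
{
        return __builtin_ctzl(mask);
}

int main(void)
{
        unsigned long mask = 0x29;      /* pages 0, 3 and 5 are dirty */

        while (mask) {
                printf("protect page at gfn_offset + %lu\n",
                       lowest_set_bit(mask));
                mask &= mask - 1;       /* clear the lowest set bit */
        }
        return 0;                       /* prints offsets 0, 3 and 5 */
}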
 
 static int rmap_write_protect(struct kvm *kvm, u64 gfn)
 {
        struct kvm_memory_slot *slot;
+       unsigned long *rmapp;
+       int i;
+       int write_protected = 0;
 
        slot = gfn_to_memslot(kvm, gfn);
-       return kvm_mmu_rmap_write_protect(kvm, gfn, slot);
+
+       for (i = PT_PAGE_TABLE_LEVEL;
+            i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i) {
+               rmapp = __gfn_to_rmap(gfn, i, slot);
+               write_protected |= __rmap_write_protect(kvm, rmapp, i);
+       }
+
+       return write_protected;
 }
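
rmap_write_protect(), by contrast, keeps the walk over every mapping level: a gfn that is not under dirty logging may be reachable through a 2M or 1G huge page, so those rmap chains must be write-protected as well. A sketch of what the level loop covers, assuming the usual x86 constants (PT_PAGE_TABLE_LEVEL == 1, KVM_NR_PAGE_SIZES == 3):

#include <stdio.h>

#define PT_PAGE_TABLE_LEVEL     1       /* assumed x86 value */
#define KVM_NR_PAGE_SIZES       3       /* 4K, 2M and 1G mappings */

int main(void)
{
        int i;

        /* Same bounds as the loop in rmap_write_protect() above. */
        for (i = PT_PAGE_TABLE_LEVEL;
             i < PT_PAGE_TABLE_LEVEL + KVM_NR_PAGE_SIZES; ++i)
                printf("level %d covers %lu KiB pages\n",
                       i, 4UL << (9 * (i - PT_PAGE_TABLE_LEVEL)));
        return 0;
}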
 
 static int kvm_unmap_rmapp(struct kvm *kvm, unsigned long *rmapp,
 
 
        /* Not many dirty pages compared to # of shadow pages. */
        if (nr_dirty_pages < kvm->arch.n_used_mmu_pages) {
-               unsigned long gfn_offset;
+               gfn_t offset;
 
-               for_each_set_bit(gfn_offset, dirty_bitmap, memslot->npages) {
-                       unsigned long gfn = memslot->base_gfn + gfn_offset;
+               for_each_set_bit(offset, dirty_bitmap, memslot->npages)
+                       kvm_mmu_write_protect_pt_masked(kvm, memslot, offset, 1);
 
-                       kvm_mmu_rmap_write_protect(kvm, gfn, memslot);
-               }
                kvm_flush_remote_tlbs(kvm);
        } else
                kvm_mmu_slot_remove_write_access(kvm, memslot->id);
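
The call site above still passes mask = 1, one call per dirty page. Since @mask can describe up to BITS_PER_LONG pages starting at @gfn_offset, a natural follow-up would be to batch a whole word of the dirty bitmap per call. A hypothetical sketch of such a caller, reusing the variables in scope above and assuming the unused bitmap tail beyond npages is kept zero (not part of this patch):

        long i, nr_words = DIV_ROUND_UP(memslot->npages, BITS_PER_LONG);

        for (i = 0; i < nr_words; i++)
                if (dirty_bitmap[i])
                        kvm_mmu_write_protect_pt_masked(kvm, memslot,
                                                        i * BITS_PER_LONG,
                                                        dirty_bitmap[i]);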