  * @gfn_offset: start of the BITS_PER_LONG pages we care about
  * @mask: indicates which pages we should protect
  *
- * Used when we do not need to care about huge page mappings: e.g. during dirty
- * logging we do not have any such mappings.
+ * Used when we do not need to care about huge page mappings.
  */
 static void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
                                     struct kvm_memory_slot *slot,
  * It calls kvm_mmu_write_protect_pt_masked to write protect selected pages to
  * enable dirty logging for them.
  *
- * Used when we do not need to care about huge page mappings: e.g. during dirty
- * logging we do not have any such mappings.
+ * We need to care about huge page mappings: e.g. during dirty logging we may
+ * have such mappings.
  */
 void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                                struct kvm_memory_slot *slot,
                                gfn_t gfn_offset, unsigned long mask)
 {
+       /*
+        * Huge pages are NOT write protected when we start dirty logging in
+        * initially-all-set mode; must write protect them here so that they
+        * are split to 4K on the first write.
+        *
+        * The gfn_offset is guaranteed to be aligned to 64, but the base_gfn
+        * of memslot has no such restriction, so the range can cross two large
+        * pages.
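+        *
+        * For example, with base_gfn 0x1f0 and gfn_offset 0, a full 64-bit
+        * mask covers gfns 0x1f0 through 0x22f, which straddles the 2M
+        * boundary at gfn 0x200, so both huge pages need write protection.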
+        */
+       if (kvm_dirty_log_manual_protect_and_init_set(kvm)) {
+               gfn_t start = slot->base_gfn + gfn_offset + __ffs(mask);
+               gfn_t end = slot->base_gfn + gfn_offset + __fls(mask);
+
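+               /*
+                * With PG_LEVEL_2M as the minimum level, only huge mappings
+                * of this gfn are write protected; 4K PTEs are handled below.
+                */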
+               kvm_mmu_slot_gfn_write_protect(kvm, slot, start, PG_LEVEL_2M);
+
+               /* Cross two large pages? */
+               if (ALIGN(start << PAGE_SHIFT, PMD_SIZE) !=
+                   ALIGN(end << PAGE_SHIFT, PMD_SIZE))
+                       kvm_mmu_slot_gfn_write_protect(kvm, slot, end,
+                                                      PG_LEVEL_2M);
+       }
+
+       /* Now handle 4K PTEs.  */
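+       /*
+        * With PML (cpu_dirty_log_size != 0) the CPU records 4K writes via
+        * spte dirty bits, so clearing those bits is enough; otherwise the
+        * selected pages must be write protected at 4K.
+        */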
        if (kvm_x86_ops.cpu_dirty_log_size)
                kvm_mmu_clear_dirty_pt_masked(kvm, slot, gfn_offset, mask);
        else
                kvm_mmu_write_protect_pt_masked(kvm, slot, gfn_offset, mask);
 
                 */
                kvm_mmu_zap_collapsible_sptes(kvm, new);
        } else {
-               /* By default, write-protect everything to log writes. */
-               int level = PG_LEVEL_4K;
+               /*
+                * Initially-all-set does not require write protecting any page,
+                * because they're all assumed to be dirty.
+                */
+               if (kvm_dirty_log_manual_protect_and_init_set(kvm))
+                       return;
 
                if (kvm_x86_ops.cpu_dirty_log_size) {
-                       /*
-                        * Clear all dirty bits, unless pages are treated as
-                        * dirty from the get-go.
-                        */
-                       if (!kvm_dirty_log_manual_protect_and_init_set(kvm))
-                               kvm_mmu_slot_leaf_clear_dirty(kvm, new);
-
-                       /*
-                        * Write-protect large pages on write so that dirty
-                        * logging happens at 4k granularity.  No need to
-                        * write-protect small SPTEs since write accesses are
-                        * logged by the CPU via dirty bits.
-                        */
-                       level = PG_LEVEL_2M;
-               } else if (kvm_dirty_log_manual_protect_and_init_set(kvm)) {
-                       /*
-                        * If we're with initial-all-set, we don't need
-                        * to write protect any small page because
-                        * they're reported as dirty already.  However
-                        * we still need to write-protect huge pages
-                        * so that the page split can happen lazily on
-                        * the first write to the huge page.
-                        */
-                       level = PG_LEVEL_2M;
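+                       /*
+                        * PML logs writes via dirty bits: clear them now
+                        * and write-protect only huge pages so the first
+                        * write splits them to 4K.  Small SPTEs need no
+                        * write protection.
+                        */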
+                       kvm_mmu_slot_leaf_clear_dirty(kvm, new);
+                       kvm_mmu_slot_remove_write_access(kvm, new, PG_LEVEL_2M);
+               } else {
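+                       /*
+                        * No PML: log writes by write protecting everything
+                        * down to 4K pages.
+                        */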
+                       kvm_mmu_slot_remove_write_access(kvm, new, PG_LEVEL_4K);
                }
-               kvm_mmu_slot_remove_write_access(kvm, new, level);
        }
 }