static u64 __read_mostly shadow_present_mask;
 
 /*
- * The mask/value to distinguish a PTE that has been marked not-present for
- * access tracking purposes.
- * The mask would be either 0 if access tracking is disabled, or
- * SPTE_SPECIAL_MASK|VMX_EPT_RWX_MASK if access tracking is enabled.
+ * SPTEs used by MMUs without A/D bits are marked with shadow_acc_track_value.
+ * Non-present SPTEs with shadow_acc_track_value set are in place for access
+ * tracking.
  */
 static u64 __read_mostly shadow_acc_track_mask;
 static const u64 shadow_acc_track_value = SPTE_SPECIAL_MASK;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mmio_spte_mask);
 
+static inline bool sp_ad_disabled(struct kvm_mmu_page *sp)
+{
+       return sp->role.ad_disabled;
+}
+
+static inline bool spte_ad_enabled(u64 spte)
+{
+       MMU_WARN_ON((spte & shadow_mmio_mask) == shadow_mmio_value);
+       return !(spte & shadow_acc_track_value);
+}
+
+static inline u64 spte_shadow_accessed_mask(u64 spte)
+{
+       MMU_WARN_ON((spte & shadow_mmio_mask) == shadow_mmio_value);
+       return spte_ad_enabled(spte) ? shadow_accessed_mask : 0;
+}
+
+static inline u64 spte_shadow_dirty_mask(u64 spte)
+{
+       MMU_WARN_ON((spte & shadow_mmio_mask) == shadow_mmio_value);
+       return spte_ad_enabled(spte) ? shadow_dirty_mask : 0;
+}
+
 static inline bool is_access_track_spte(u64 spte)
 {
-       /* Always false if shadow_acc_track_mask is zero.  */
-       return (spte & shadow_acc_track_mask) == shadow_acc_track_value;
+       return !spte_ad_enabled(spte) && (spte & shadow_acc_track_mask) == 0;
 }
 
 /*
                u64 dirty_mask, u64 nx_mask, u64 x_mask, u64 p_mask,
                u64 acc_track_mask)
 {
-       if (acc_track_mask != 0)
-               acc_track_mask |= SPTE_SPECIAL_MASK;
        BUG_ON(!dirty_mask != !accessed_mask);
        BUG_ON(!accessed_mask && !acc_track_mask);
+       BUG_ON(acc_track_mask & shadow_acc_track_value);
 
        shadow_user_mask = user_mask;
        shadow_accessed_mask = accessed_mask;
        shadow_x_mask = x_mask;
        shadow_present_mask = p_mask;
        shadow_acc_track_mask = acc_track_mask;
-       WARN_ON(shadow_accessed_mask != 0 && shadow_acc_track_mask != 0);
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mask_ptes);
 
            is_access_track_spte(spte))
                return true;
 
-       if (shadow_accessed_mask) {
+       if (spte_ad_enabled(spte)) {
                if ((spte & shadow_accessed_mask) == 0 ||
                    (is_writable_pte(spte) && (spte & shadow_dirty_mask) == 0))
                        return true;
 
 static bool is_accessed_spte(u64 spte)
 {
-       return shadow_accessed_mask ? spte & shadow_accessed_mask
-                                   : !is_access_track_spte(spte);
+       u64 accessed_mask = spte_shadow_accessed_mask(spte);
+
+       return accessed_mask ? spte & accessed_mask
+                            : !is_access_track_spte(spte);
 }
 
 static bool is_dirty_spte(u64 spte)
 {
-       return shadow_dirty_mask ? spte & shadow_dirty_mask
-                                : spte & PT_WRITABLE_MASK;
+       u64 dirty_mask = spte_shadow_dirty_mask(spte);
+
+       return dirty_mask ? spte & dirty_mask : spte & PT_WRITABLE_MASK;
 }
 
 /* Rules for using mmu_spte_set:
 
 static u64 mark_spte_for_access_track(u64 spte)
 {
-       if (shadow_accessed_mask != 0)
+       if (spte_ad_enabled(spte))
                return spte & ~shadow_accessed_mask;
 
-       if (shadow_acc_track_mask == 0 || is_access_track_spte(spte))
+       if (is_access_track_spte(spte))
                return spte;
 
        /*
        spte |= (spte & shadow_acc_track_saved_bits_mask) <<
                shadow_acc_track_saved_bits_shift;
        spte &= ~shadow_acc_track_mask;
-       spte |= shadow_acc_track_value;
 
        return spte;
 }
        u64 saved_bits = (spte >> shadow_acc_track_saved_bits_shift)
                         & shadow_acc_track_saved_bits_mask;
 
+       WARN_ON_ONCE(spte_ad_enabled(spte));
        WARN_ON_ONCE(!is_access_track_spte(spte));
 
        new_spte &= ~shadow_acc_track_mask;
        if (!is_accessed_spte(spte))
                return false;
 
-       if (shadow_accessed_mask) {
+       if (spte_ad_enabled(spte)) {
                clear_bit((ffs(shadow_accessed_mask) - 1),
                          (unsigned long *)sptep);
        } else {
        return mmu_spte_update(sptep, spte);
 }
 
+static bool wrprot_ad_disabled_spte(u64 *sptep)
+{
+       bool was_writable = test_and_clear_bit(PT_WRITABLE_SHIFT,
+                                              (unsigned long *)sptep);
+       if (was_writable)
+               kvm_set_pfn_dirty(spte_to_pfn(*sptep));
+
+       return was_writable;
+}
+
+/*
+ * Gets the GFN ready for another round of dirty logging by clearing the
+ *     - D bit on ad-enabled SPTEs, and
+ *     - W bit on ad-disabled SPTEs.
+ * Returns true iff any D or W bits were cleared.
+ */
 static bool __rmap_clear_dirty(struct kvm *kvm, struct kvm_rmap_head *rmap_head)
 {
        u64 *sptep;
        bool flush = false;
 
        for_each_rmap_spte(rmap_head, &iter, sptep)
-               flush |= spte_clear_dirty(sptep);
+               if (spte_ad_enabled(*sptep))
+                       flush |= spte_clear_dirty(sptep);
+               else
+                       flush |= wrprot_ad_disabled_spte(sptep);
 
        return flush;
 }
        bool flush = false;
 
        for_each_rmap_spte(rmap_head, &iter, sptep)
-               flush |= spte_set_dirty(sptep);
+               if (spte_ad_enabled(*sptep))
+                       flush |= spte_set_dirty(sptep);
 
        return flush;
 }
 }
 
 /**
- * kvm_mmu_clear_dirty_pt_masked - clear MMU D-bit for PT level pages
+ * kvm_mmu_clear_dirty_pt_masked - clear MMU D-bit for PT level pages, or write
+ * protect the page if the D-bit isn't supported.
  * @kvm: kvm instance
  * @slot: slot to clear D-bit
  * @gfn_offset: start of the BITS_PER_LONG pages we care about
        BUILD_BUG_ON(VMX_EPT_WRITABLE_MASK != PT_WRITABLE_MASK);
 
        spte = __pa(sp->spt) | shadow_present_mask | PT_WRITABLE_MASK |
-              shadow_user_mask | shadow_x_mask | shadow_accessed_mask;
+              shadow_user_mask | shadow_x_mask;
+
+       if (sp_ad_disabled(sp))
+               spte |= shadow_acc_track_value;
+       else
+               spte |= shadow_accessed_mask;
 
        mmu_spte_set(sptep, spte);
 
 {
        u64 spte = 0;
        int ret = 0;
+       struct kvm_mmu_page *sp;
 
        if (set_mmio_spte(vcpu, sptep, gfn, pfn, pte_access))
                return 0;
 
+       sp = page_header(__pa(sptep));
+       if (sp_ad_disabled(sp))
+               spte |= shadow_acc_track_value;
+
        /*
         * For the EPT case, shadow_present_mask is 0 if hardware
         * supports exec-only page table entries.  In that case,
         */
        spte |= shadow_present_mask;
        if (!speculative)
-               spte |= shadow_accessed_mask;
+               spte |= spte_shadow_accessed_mask(spte);
 
        if (pte_access & ACC_EXEC_MASK)
                spte |= shadow_x_mask;
 
        if (pte_access & ACC_WRITE_MASK) {
                kvm_vcpu_mark_page_dirty(vcpu, gfn);
-               spte |= shadow_dirty_mask;
+               spte |= spte_shadow_dirty_mask(spte);
        }
 
        if (speculative)
 {
        struct kvm_mmu_page *sp;
 
+       sp = page_header(__pa(sptep));
+
        /*
-        * Since it's no accessed bit on EPT, it's no way to
-        * distinguish between actually accessed translations
-        * and prefetched, so disable pte prefetch if EPT is
-        * enabled.
+        * Without accessed bits, there's no way to distinguish between
+        * actually accessed translations and prefetched, so disable pte
+        * prefetch if accessed bits aren't available.
         */
-       if (!shadow_accessed_mask)
+       if (sp_ad_disabled(sp))
                return;
 
-       sp = page_header(__pa(sptep));
        if (sp->role.level > PT_PAGE_TABLE_LEVEL)
                return;
 
 
        context->base_role.word = 0;
        context->base_role.smm = is_smm(vcpu);
+       context->base_role.ad_disabled = (shadow_accessed_mask == 0);
        context->page_fault = tdp_page_fault;
        context->sync_page = nonpaging_sync_page;
        context->invlpg = nonpaging_invlpg;
        mask.smep_andnot_wp = 1;
        mask.smap_andnot_wp = 1;
        mask.smm = 1;
+       mask.ad_disabled = 1;
 
        /*
         * If we don't have indirect shadow pages, it means no page is