write_unlock(&kvm->mmu_lock);
 
+       /*
+        * Zap the invalidated TDP MMU roots, all SPTEs must be dropped before
+        * returning to the caller, e.g. if the zap is in response to a memslot
+        * deletion, mmu_notifier callbacks will be unable to reach the SPTEs
+        * associated with the deleted memslot once the update completes, and
+        * Deferring the zap until the final reference to the root is put would
+        * lead to use-after-free.
+        */
        if (is_tdp_mmu_enabled(kvm)) {
                read_lock(&kvm->mmu_lock);
                kvm_tdp_mmu_zap_invalidated_roots(kvm);
 
 }
 
 /*
- * Since kvm_tdp_mmu_zap_all_fast has acquired a reference to each
- * invalidated root, they will not be freed until this function drops the
- * reference. Before dropping that reference, tear down the paging
- * structure so that whichever thread does drop the last reference
- * only has to do a trivial amount of work. Since the roots are invalid,
- * no new SPTEs should be created under them.
+ * Zap all invalidated roots to ensure all SPTEs are dropped before the "fast
+ * zap" completes.  Since kvm_tdp_mmu_invalidate_all_roots() has acquired a
+ * reference to each invalidated root, roots will not be freed until after this
+ * function drops the gifted reference, e.g. so that vCPUs don't get stuck with
+ * tearing down paging structures.
  */
 void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm)
 {
 }
 
 /*
- * Mark each TDP MMU root as invalid so that other threads
- * will drop their references and allow the root count to
- * go to 0.
+ * Mark each TDP MMU root as invalid to prevent vCPUs from reusing a root that
+ * is about to be zapped, e.g. in response to a memslots update.  The caller is
+ * responsible for invoking kvm_tdp_mmu_zap_invalidated_roots() to do the actual
+ * zapping.
  *
- * Also take a reference on all roots so that this thread
- * can do the bulk of the work required to free the roots
- * once they are invalidated. Without this reference, a
- * vCPU thread might drop the last reference to a root and
- * get stuck with tearing down the entire paging structure.
+ * Take a reference on all roots to prevent the root from being freed before it
+ * is zapped by this thread.  Freeing a root is not a correctness issue, but if
+ * a vCPU drops the last reference to a root prior to the root being zapped, it
+ * will get stuck with tearing down the entire paging structure.
  *
- * Roots which have a zero refcount should be skipped as
- * they're already being torn down.
- * Already invalid roots should be referenced again so that
- * they aren't freed before kvm_tdp_mmu_zap_all_fast is
- * done with them.
+ * Get a reference even if the root is already invalid,
+ * kvm_tdp_mmu_zap_invalidated_roots() assumes it was gifted a reference to all
+ * invalid roots, e.g. there's no epoch to identify roots that were invalidated
+ * by a previous call.  Roots stay on the list until the last reference is
+ * dropped, so even though all invalid roots are zapped, a root may not go away
+ * for quite some time, e.g. if a vCPU blocks across multiple memslot updates.
+ *
+ * Because mmu_lock is held for write, it should be impossible to observe a
+ * root with zero refcount, i.e. the list of roots cannot be stale.
  *
  * This has essentially the same effect for the TDP MMU
  * as updating mmu_valid_gen does for the shadow MMU.
        struct kvm_mmu_page *root;
 
        lockdep_assert_held_write(&kvm->mmu_lock);
-       list_for_each_entry(root, &kvm->arch.tdp_mmu_roots, link)
-               if (refcount_inc_not_zero(&root->tdp_mmu_root_count))
+       list_for_each_entry(root, &kvm->arch.tdp_mmu_roots, link) {
+               if (!WARN_ON_ONCE(!kvm_tdp_mmu_get_root(root)))
                        root->role.invalid = true;
+       }
 }
 
 /*