}
 
 /*
- * Finds the next valid root after root (or the first valid root if root
- * is NULL), takes a reference on it, and returns that next root. If root
- * is not NULL, this thread should have already taken a reference on it, and
- * that reference will be dropped. If no valid root is found, this
- * function will return NULL.
+ * Returns the next root after @prev_root (or the first root if @prev_root is
+ * NULL).  A reference to the returned root is acquired, and the reference to
+ * @prev_root is released (the caller obviously must hold a reference to
+ * @prev_root if it's non-NULL).
+ *
+ * If @only_valid is true, invalid roots are skipped.
+ *
+ * Returns NULL if the end of tdp_mmu_roots was reached.
  */
 static struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
                                              struct kvm_mmu_page *prev_root,
-                                             bool shared)
+                                             bool shared, bool only_valid)
 {
        struct kvm_mmu_page *next_root;
 
        rcu_read_lock();
 
        if (prev_root)
                next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
                                                  &prev_root->link,
                                                  typeof(*prev_root), link);
        else
                next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
                                                   typeof(*next_root), link);
 
        while (next_root) {
-               if (!next_root->role.invalid &&
+               if ((!only_valid || !next_root->role.invalid) &&
                    kvm_tdp_mmu_get_root(kvm, next_root))
                        break;
 
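For illustration, here is a minimal userspace model of the traversal above; it
is a sketch, not KVM code.  The list, refcount and invalid flag stand in for
tdp_mmu_roots, the root's reference count and role.invalid, and try_get() plus
the refcount decrement stand in for kvm_tdp_mmu_get_root() and
kvm_tdp_mmu_put_root().

#include <stdbool.h>
#include <stdio.h>

struct root {
	struct root *next;
	int refcount;		/* 0 means the root is being torn down */
	bool invalid;		/* marked stale, e.g. by a fast zap */
};

/* Stand-in for kvm_tdp_mmu_get_root(): fails once the last ref is gone. */
static bool try_get(struct root *r)
{
	if (r->refcount == 0)
		return false;
	r->refcount++;
	return true;
}

static struct root *next_root(struct root *head, struct root *prev,
			      bool only_valid)
{
	struct root *next = prev ? prev->next : head;

	/* Same shape as the while loop above: skip roots that the filter
	 * rejects or whose last reference has already been dropped. */
	while (next) {
		if ((!only_valid || !next->invalid) && try_get(next))
			break;
		next = next->next;
	}

	/* Stand-in for kvm_tdp_mmu_put_root() on the previous root. */
	if (prev)
		prev->refcount--;

	return next;
}

int main(void)
{
	struct root c = { NULL, 1, false };
	struct root b = { &c,   1, true };	/* skipped when only_valid */
	struct root a = { &b,   1, false };
	struct root *r;

	for (r = next_root(&a, NULL, true); r; r = next_root(&a, r, true))
		printf("visited a root, refcount now %d\n", r->refcount);

	return 0;
}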
  * mode. In the unlikely event that this thread must free a root, the lock
  * will be temporarily dropped and reacquired in write mode.
  */
-#define for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared) \
-       for (_root = tdp_mmu_next_root(_kvm, NULL, _shared);            \
-            _root;                                                     \
-            _root = tdp_mmu_next_root(_kvm, _root, _shared))           \
-               if (kvm_mmu_page_as_id(_root) != _as_id) {              \
+#define __for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared, _only_valid)\
+       for (_root = tdp_mmu_next_root(_kvm, NULL, _shared, _only_valid);       \
+            _root;                                                             \
+            _root = tdp_mmu_next_root(_kvm, _root, _shared, _only_valid))      \
+               if (kvm_mmu_page_as_id(_root) != _as_id) {                      \
                } else
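The empty then-branch in the address-space filter above is deliberate: because
the macro's if always carries its own else, an else written after the loop in
caller code (e.g. when the loop itself is the body of an outer if) cannot
misbind to the filter, and the caller's statement simply becomes the
else-branch.  A sketch of the idiom with an illustrative, non-KVM macro:

#include <stdio.h>

/* Same shape as the iterator above: the filter's empty then-branch leaves
 * the caller's statement as the else-branch. */
#define for_each_even(_i, _n)			\
	for (_i = 0; _i < (_n); _i++)		\
		if ((_i) & 1) {			\
		} else

int main(void)
{
	int i;

	/* break and continue still act on the underlying for loop. */
	for_each_even(i, 10) {
		if (i == 6)
			break;
		printf("%d\n", i);	/* prints 0, 2 and 4 */
	}
	return 0;
}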
 
+#define for_each_valid_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared)   \
+       __for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared, true)
+
+#define for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared)         \
+       __for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared, false)
+
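The two wrappers differ only in the hard-coded _only_valid argument.
Mechanically expanding the valid-only flavor (whitespace aside) gives the
loop below; do_something() is a hypothetical stand-in for the caller's
statement.

	for_each_valid_tdp_mmu_root_yield_safe(kvm, root, as_id, true)
		do_something(root);

expands to:

	for (root = tdp_mmu_next_root(kvm, NULL, true, true);
	     root;
	     root = tdp_mmu_next_root(kvm, root, true, true))
		if (kvm_mmu_page_as_id(root) != as_id) {
		} else
			do_something(root);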
 #define for_each_tdp_mmu_root(_kvm, _root, _as_id)                             \
        list_for_each_entry_rcu(_root, &_kvm->arch.tdp_mmu_roots, link,         \
                                lockdep_is_held_type(&kvm->mmu_lock, 0) ||     \
                                lockdep_is_held(&kvm->arch.tdp_mmu_pages_lock))\
                if (kvm_mmu_page_as_id(_root) != _as_id) {                     \
                } else
 
        lockdep_assert_held_read(&kvm->mmu_lock);
 
-       for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
+       for_each_valid_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
                spte_set |= wrprot_gfn_range(kvm, root, slot->base_gfn,
                             slot->base_gfn + slot->npages, min_level);
 
 
        lockdep_assert_held_read(&kvm->mmu_lock);
 
-       for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
+       for_each_valid_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
                spte_set |= clear_dirty_gfn_range(kvm, root, slot->base_gfn,
                                slot->base_gfn + slot->npages);
 
 
        lockdep_assert_held_read(&kvm->mmu_lock);
 
-       for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
+       for_each_valid_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
                zap_collapsible_spte_range(kvm, root, slot);
 }