tdp_mmu_free_sp(root);
 }
 
-static inline bool tdp_mmu_next_root_valid(struct kvm *kvm,
-                                          struct kvm_mmu_page *root)
+/*
+ * Finds the next valid root after root (or the first valid root if root
+ * is NULL), takes a reference on it, and returns that next root. If root
+ * is not NULL, this thread should have already taken a reference on it, and
+ * that reference will be dropped. If no valid root is found, this
+ * function will return NULL.
+ */
+static struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
+                                             struct kvm_mmu_page *prev_root)
 {
-       lockdep_assert_held_write(&kvm->mmu_lock);
+       struct kvm_mmu_page *next_root;
 
-       if (list_entry_is_head(root, &kvm->arch.tdp_mmu_roots, link))
-               return false;
+       lockdep_assert_held_write(&kvm->mmu_lock);
 
-       kvm_tdp_mmu_get_root(kvm, root);
-       return true;
+       if (prev_root)
+               next_root = list_next_entry(prev_root, link);
+       else
+               next_root = list_first_entry(&kvm->arch.tdp_mmu_roots,
+                                            typeof(*next_root), link);
 
-}
+       if (list_entry_is_head(next_root, &kvm->arch.tdp_mmu_roots, link))
+               next_root = NULL;
+       else
+               kvm_tdp_mmu_get_root(kvm, next_root);
 
-static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
-                                                    struct kvm_mmu_page *root)
-{
-       struct kvm_mmu_page *next_root;
+       if (prev_root)
+               kvm_tdp_mmu_put_root(kvm, prev_root);
 
-       next_root = list_next_entry(root, link);
-       kvm_tdp_mmu_put_root(kvm, root);
        return next_root;
 }
 
  * recent root. (Unless keeping a live reference is desirable.)
  */
 #define for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id)          \
-       for (_root = list_first_entry(&_kvm->arch.tdp_mmu_roots,        \
-                                     typeof(*_root), link);            \
-            tdp_mmu_next_root_valid(_kvm, _root);                      \
-            _root = tdp_mmu_next_root(_kvm, _root))                    \
-               if (kvm_mmu_page_as_id(_root) != _as_id) {              \
+       for (_root = tdp_mmu_next_root(_kvm, NULL);             \
+            _root;                                             \
+            _root = tdp_mmu_next_root(_kvm, _root))            \
+               if (kvm_mmu_page_as_id(_root) != _as_id) {      \
                } else
 
 #define for_each_tdp_mmu_root(_kvm, _root, _as_id)                     \