_mmu->shadow_root_level, _start, _end)
 
 /*
- * Flush the TLB if the process should drop kvm->mmu_lock.
- * Return whether the caller still needs to flush the tlb.
+ * Flush the TLB and yield if the MMU lock is contended or this thread needs to
+ * return control to the scheduler.
+ *
+ * If this function yields, it will also reset the tdp_iter's walk over the
+ * paging structure and the calling function should allow the iterator to
+ * continue its traversal from the paging structure root.
+ *
+ * Return true if this function yielded, the TLBs were flushed, and the
+ * iterator's traversal was reset. Return false if a yield was not needed.
  */
 static bool tdp_mmu_iter_flush_cond_resched(struct kvm *kvm, struct tdp_iter *iter)
 {
                kvm_flush_remote_tlbs(kvm);
                cond_resched_lock(&kvm->mmu_lock);
                tdp_iter_refresh_walk(iter);
-               return false;
-       } else {
                return true;
        }
+
+       return false;
 }
 
-static void tdp_mmu_iter_cond_resched(struct kvm *kvm, struct tdp_iter *iter)
+/*
+ * Yield if the MMU lock is contended or this thread needs to return control
+ * to the scheduler.
+ *
+ * If this function yields, it will also reset the tdp_iter's walk over the
+ * paging structure and the calling function should allow the iterator to
+ * continue its traversal from the paging structure root.
+ *
+ * Return true if this function yielded and the iterator's traversal was reset.
+ * Return false if a yield was not needed.
+ */
+static bool tdp_mmu_iter_cond_resched(struct kvm *kvm, struct tdp_iter *iter)
 {
        if (need_resched() || spin_needbreak(&kvm->mmu_lock)) {
                cond_resched_lock(&kvm->mmu_lock);
                tdp_iter_refresh_walk(iter);
+               return true;
        }
+
+       return false;
 }
 
 /*
 
                tdp_mmu_set_spte(kvm, &iter, 0);
 
-               if (can_yield)
-                       flush_needed = tdp_mmu_iter_flush_cond_resched(kvm, &iter);
-               else
-                       flush_needed = true;
+               flush_needed = !can_yield ||
+                              !tdp_mmu_iter_flush_cond_resched(kvm, &iter);
        }
        return flush_needed;
 }
 
                tdp_mmu_set_spte(kvm, &iter, 0);
 
-               spte_set = tdp_mmu_iter_flush_cond_resched(kvm, &iter);
+               spte_set = !tdp_mmu_iter_flush_cond_resched(kvm, &iter);
        }
 
        if (spte_set)