}
 
 /*
- * Zap leafs SPTEs for the range of gfns, [start, end). Returns true if SPTEs
- * have been cleared and a TLB flush is needed before releasing the MMU lock.
+ * Tears down the mappings for the range of gfns, [start, end), and frees the
+ * non-root pages mapping GFNs strictly within that range. Returns true if
+ * SPTEs have been cleared and a TLB flush is needed before releasing the
+ * MMU lock.
  *
  * If can_yield is true, will release the MMU lock and reschedule if the
  * scheduler needs the CPU or there is contention on the MMU lock. If this
  * the caller must ensure it does not supply too large a GFN range, or the
  * operation can cause a soft lockup.
  */
-static bool tdp_mmu_zap_leafs(struct kvm *kvm, struct kvm_mmu_page *root,
-                             gfn_t start, gfn_t end, bool can_yield, bool flush)
+static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+                         gfn_t start, gfn_t end, bool can_yield, bool flush)
 {
+       bool zap_all = (start == 0 && end >= tdp_mmu_max_gfn_host());
        struct tdp_iter iter;
 
+       /*
+        * No need to try to step down in the iterator when zapping all SPTEs,
+        * zapping the top-level non-leaf SPTEs will recurse on their children.
+        */
+       int min_level = zap_all ? root->role.level : PG_LEVEL_4K;
+
        end = min(end, tdp_mmu_max_gfn_host());
 
        lockdep_assert_held_write(&kvm->mmu_lock);
 
        rcu_read_lock();
 
-       for_each_tdp_pte_min_level(iter, root, PG_LEVEL_4K, start, end) {
+       for_each_tdp_pte_min_level(iter, root, min_level, start, end) {
                if (can_yield &&
                    tdp_mmu_iter_cond_resched(kvm, &iter, flush, false)) {
                        flush = false;
                        continue;
                }
 
-               if (!is_shadow_present_pte(iter.old_spte) ||
+               if (!is_shadow_present_pte(iter.old_spte))
+                       continue;
+
+               /*
+                * If this is a non-last-level SPTE that covers a larger range
+                * than should be zapped, continue, and zap the mappings at a
+                * lower level, except when zapping all SPTEs.
+                */
+               if (!zap_all &&
+                   (iter.gfn < start ||
+                    iter.gfn + KVM_PAGES_PER_HPAGE(iter.level) > end) &&
                    !is_last_spte(iter.old_spte, iter.level))
                        continue;
 
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
  */
-bool kvm_tdp_mmu_zap_leafs(struct kvm *kvm, int as_id, gfn_t start, gfn_t end,
-                          bool can_yield, bool flush)
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
+                                gfn_t end, bool can_yield, bool flush)
 {
        struct kvm_mmu_page *root;
 
        for_each_tdp_mmu_root_yield_safe(kvm, root, as_id)
-               flush = tdp_mmu_zap_leafs(kvm, root, start, end, can_yield, false);
+               flush = zap_gfn_range(kvm, root, start, end, can_yield, flush);
 
        return flush;
 }
 bool kvm_tdp_mmu_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range,
                                 bool flush)
 {
-       return kvm_tdp_mmu_zap_leafs(kvm, range->slot->as_id, range->start,
-                                    range->end, range->may_block, flush);
+       return __kvm_tdp_mmu_zap_gfn_range(kvm, range->slot->as_id, range->start,
+                                          range->end, range->may_block, flush);
 }
 
 typedef bool (*tdp_handler_t)(struct kvm *kvm, struct tdp_iter *iter,
 
 void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
                          bool shared);
 
-bool kvm_tdp_mmu_zap_leafs(struct kvm *kvm, int as_id, gfn_t start,
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
                                 gfn_t end, bool can_yield, bool flush);
+static inline bool kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id,
+                                            gfn_t start, gfn_t end, bool flush)
+{
+       return __kvm_tdp_mmu_zap_gfn_range(kvm, as_id, start, end, true, flush);
+}
+
 bool kvm_tdp_mmu_zap_sp(struct kvm *kvm, struct kvm_mmu_page *sp);
 void kvm_tdp_mmu_zap_all(struct kvm *kvm);
 void kvm_tdp_mmu_invalidate_all_roots(struct kvm *kvm);