*/
 
 #include <linux/init.h>
+#include <linux/interval_tree_generic.h>
 #include <linux/kmemleak.h>
 #include <linux/kvm_host.h>
 #include <asm/kvm_mmu.h>
 }
 device_initcall_sync(finalize_pkvm);
 
-static int cmp_mappings(struct rb_node *node, const struct rb_node *parent)
+static u64 __pkvm_mapping_start(struct pkvm_mapping *m)
 {
-       struct pkvm_mapping *a = rb_entry(node, struct pkvm_mapping, node);
-       struct pkvm_mapping *b = rb_entry(parent, struct pkvm_mapping, node);
-
-       if (a->gfn < b->gfn)
-               return -1;
-       if (a->gfn > b->gfn)
-               return 1;
-       return 0;
+       return m->gfn * PAGE_SIZE;
 }
 
-static struct rb_node *find_first_mapping_node(struct rb_root *root, u64 gfn)
+static u64 __pkvm_mapping_end(struct pkvm_mapping *m)
 {
-       struct rb_node *node = root->rb_node, *prev = NULL;
-       struct pkvm_mapping *mapping;
-
-       while (node) {
-               mapping = rb_entry(node, struct pkvm_mapping, node);
-               if (mapping->gfn == gfn)
-                       return node;
-               prev = node;
-               node = (gfn < mapping->gfn) ? node->rb_left : node->rb_right;
-       }
-
-       return prev;
+       return (m->gfn + 1) * PAGE_SIZE - 1;
 }
 
+INTERVAL_TREE_DEFINE(struct pkvm_mapping, node, u64, __subtree_last,
+                    __pkvm_mapping_start, __pkvm_mapping_end, static,
+                    pkvm_mapping);
+
 /*
- * __tmp is updated to rb_next(__tmp) *before* entering the body of the loop to allow freeing
- * of __map inline.
+ * __tmp is updated to iter_first(pkvm_mappings) *before* entering the body of the loop to allow
+ * freeing of __map inline.
  */
 #define for_each_mapping_in_range_safe(__pgt, __start, __end, __map)                           \
-       for (struct rb_node *__tmp = find_first_mapping_node(&(__pgt)->pkvm_mappings,           \
-                                                            ((__start) >> PAGE_SHIFT));        \
+       for (struct pkvm_mapping *__tmp = pkvm_mapping_iter_first(&(__pgt)->pkvm_mappings,      \
+                                                                 __start, __end - 1);          \
             __tmp && ({                                                                        \
-                               __map = rb_entry(__tmp, struct pkvm_mapping, node);             \
-                               __tmp = rb_next(__tmp);                                         \
+                               __map = __tmp;                                                  \
+                               __tmp = pkvm_mapping_iter_next(__map, __start, __end - 1);      \
                                true;                                                           \
                       });                                                                      \
-           )                                                                                   \
-               if (__map->gfn < ((__start) >> PAGE_SHIFT))                                     \
-                       continue;                                                               \
-               else if (__map->gfn >= ((__end) >> PAGE_SHIFT))                                 \
-                       break;                                                                  \
-               else
+           )
 
 int pkvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm_s2_mmu *mmu,
                             struct kvm_pgtable_mm_ops *mm_ops)
 {
-       pgt->pkvm_mappings      = RB_ROOT;
+       pgt->pkvm_mappings      = RB_ROOT_CACHED;
        pgt->mmu                = mmu;
 
        return 0;
 }
 
-void pkvm_pgtable_stage2_destroy(struct kvm_pgtable *pgt)
+static int __pkvm_pgtable_stage2_unmap(struct kvm_pgtable *pgt, u64 start, u64 end)
 {
        struct kvm *kvm = kvm_s2_mmu_to_kvm(pgt->mmu);
        pkvm_handle_t handle = kvm->arch.pkvm.handle;
        struct pkvm_mapping *mapping;
-       struct rb_node *node;
+       int ret;
 
        if (!handle)
-               return;
+               return 0;
 
-       node = rb_first(&pgt->pkvm_mappings);
-       while (node) {
-               mapping = rb_entry(node, struct pkvm_mapping, node);
-               kvm_call_hyp_nvhe(__pkvm_host_unshare_guest, handle, mapping->gfn);
-               node = rb_next(node);
-               rb_erase(&mapping->node, &pgt->pkvm_mappings);
+       for_each_mapping_in_range_safe(pgt, start, end, mapping) {
+               ret = kvm_call_hyp_nvhe(__pkvm_host_unshare_guest, handle, mapping->gfn, 1);
+               if (WARN_ON(ret))
+                       return ret;
+               pkvm_mapping_remove(mapping, &pgt->pkvm_mappings);
                kfree(mapping);
        }
+
+       return 0;
+}
+
+void pkvm_pgtable_stage2_destroy(struct kvm_pgtable *pgt)
+{
+       __pkvm_pgtable_stage2_unmap(pgt, 0, ~(0ULL));
 }
 
 int pkvm_pgtable_stage2_map(struct kvm_pgtable *pgt, u64 addr, u64 size,
        swap(mapping, cache->mapping);
        mapping->gfn = gfn;
        mapping->pfn = pfn;
-       WARN_ON(rb_find_add(&mapping->node, &pgt->pkvm_mappings, cmp_mappings));
+       pkvm_mapping_insert(mapping, &pgt->pkvm_mappings);
 
        return ret;
 }
 
 int pkvm_pgtable_stage2_unmap(struct kvm_pgtable *pgt, u64 addr, u64 size)
 {
-       struct kvm *kvm = kvm_s2_mmu_to_kvm(pgt->mmu);
-       pkvm_handle_t handle = kvm->arch.pkvm.handle;
-       struct pkvm_mapping *mapping;
-       int ret = 0;
-
-       lockdep_assert_held_write(&kvm->mmu_lock);
-       for_each_mapping_in_range_safe(pgt, addr, addr + size, mapping) {
-               ret = kvm_call_hyp_nvhe(__pkvm_host_unshare_guest, handle, mapping->gfn, 1);
-               if (WARN_ON(ret))
-                       break;
-               rb_erase(&mapping->node, &pgt->pkvm_mappings);
-               kfree(mapping);
-       }
+       lockdep_assert_held_write(&kvm_s2_mmu_to_kvm(pgt->mmu)->mmu_lock);
 
-       return ret;
+       return __pkvm_pgtable_stage2_unmap(pgt, addr, addr + size);
 }
 
 int pkvm_pgtable_stage2_wrprotect(struct kvm_pgtable *pgt, u64 addr, u64 size)