diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c
--- a/drivers/infiniband/core/umem_odp.c
+++ b/drivers/infiniband/core/umem_odp.c
@@ ... @@
 #include <linux/export.h>
 #include <linux/vmalloc.h>
 #include <linux/hugetlb.h>
-#include <linux/interval_tree_generic.h>
+#include <linux/interval_tree.h>
 #include <linux/pagemap.h>
 
 #include <rdma/ib_verbs.h>
 #include <rdma/ib_umem.h>
 #include <rdma/ib_umem_odp.h>
 
-/*
- * The ib_umem list keeps track of memory regions for which the HW
- * device request to receive notification when the related memory
- * mapping is changed.
- *
- * ib_umem_lock protects the list.
- */
-
-static u64 node_start(struct umem_odp_node *n)
-{
-       struct ib_umem_odp *umem_odp =
-                       container_of(n, struct ib_umem_odp, interval_tree);
-
-       return ib_umem_start(umem_odp);
-}
-
-/* Note that the representation of the intervals in the interval tree
- * considers the ending point as contained in the interval, while the
- * function ib_umem_end returns the first address which is not contained
- * in the umem.
- */
-static u64 node_last(struct umem_odp_node *n)
-{
-       struct ib_umem_odp *umem_odp =
-                       container_of(n, struct ib_umem_odp, interval_tree);
-
-       return ib_umem_end(umem_odp) - 1;
-}
-
-INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
-                    node_start, node_last, static, rbt_ib_umem)
-
 static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
 {
        mutex_lock(&umem_odp->umem_mutex);
@@ ... @@ static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 
        down_write(&per_mm->umem_rwsem);
-       if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
-               rbt_ib_umem_insert(&umem_odp->interval_tree,
-                                  &per_mm->umem_tree);
+       if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp))) {
+               /*
+                * Note that the representation of the intervals in the
+                * interval tree considers the ending point as contained in
+                * the interval, while the function ib_umem_end returns the
+                * first address which is not contained in the umem.
+                */
+               umem_odp->interval_tree.start = ib_umem_start(umem_odp);
+               umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
+               interval_tree_insert(&umem_odp->interval_tree,
+                                    &per_mm->umem_tree);
+       }
        up_write(&per_mm->umem_rwsem);
 }
 
 static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
 {
        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 
        down_write(&per_mm->umem_rwsem);
        if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
-               rbt_ib_umem_remove(&umem_odp->interval_tree,
-                                  &per_mm->umem_tree);
+               interval_tree_remove(&umem_odp->interval_tree,
+                                    &per_mm->umem_tree);
        complete_all(&umem_odp->notifier_completion);
 
        up_write(&per_mm->umem_rwsem);
@@ ... @@ int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
                                  void *cookie)
 {
        int ret_val = 0;
-       struct umem_odp_node *node, *next;
+       struct interval_tree_node *node, *next;
        struct ib_umem_odp *umem;
 
        if (unlikely(start == last))
                return ret_val;
 
-       for (node = rbt_ib_umem_iter_first(root, start, last - 1);
+       for (node = interval_tree_iter_first(root, start, last - 1);
                        node; node = next) {
                /* TODO move the blockable decision up to the callback */
                if (!blockable)
                        return -EAGAIN;
-               next = rbt_ib_umem_iter_next(node, start, last - 1);
+               next = interval_tree_iter_next(node, start, last - 1);
                umem = container_of(node, struct ib_umem_odp, interval_tree);
                ret_val = cb(umem, start, last, cookie) || ret_val;
        }
        return ret_val;
 }
 EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
-
-struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
-                                      u64 addr, u64 length)
-{
-       struct umem_odp_node *node;
-
-       node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
-       if (node)
-               return container_of(node, struct ib_umem_odp, interval_tree);
-       return NULL;
-
-}
-EXPORT_SYMBOL(rbt_ib_umem_lookup);
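
For reference, a minimal sketch of the generic interval tree API from
<linux/interval_tree.h> that this patch converts to. This is an
illustration only, not part of the patch: the my_region type, the
my_tree root, and both helpers are made-up names. The tree stores
closed intervals [start, last], so a range of len bytes at addr is
inserted as [addr, addr + len - 1], which is exactly the conversion
the patch performs with ib_umem_start()/ib_umem_end():

#include <linux/interval_tree.h>
#include <linux/kernel.h>

struct my_region {
	struct interval_tree_node node;	/* provides start, last, rb */
	/* ... driver-private state ... */
};

static struct rb_root_cached my_tree = RB_ROOT_CACHED;

static void my_region_insert(struct my_region *r, unsigned long addr,
			     unsigned long len)
{
	r->node.start = addr;
	r->node.last = addr + len - 1;	/* last byte, inclusive */
	interval_tree_insert(&r->node, &my_tree);
}

/* Visit every region overlapping [start, end), as the loop above does. */
static void my_visit_range(unsigned long start, unsigned long end)
{
	struct interval_tree_node *node, *next;
	struct my_region *r;

	for (node = interval_tree_iter_first(&my_tree, start, end - 1);
	     node; node = next) {
		/* Fetch the successor first so r may drop out of the tree. */
		next = interval_tree_iter_next(node, start, end - 1);
		r = container_of(node, struct my_region, node);
		/* ... process r ... */
	}
}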
 
diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h
--- a/include/rdma/ib_umem_odp.h
+++ b/include/rdma/ib_umem_odp.h
@@ ... @@
 #include <rdma/ib_verbs.h>
 #include <linux/interval_tree.h>
 
-struct umem_odp_node {
-       u64 __subtree_last;
-       struct rb_node rb;
-};
-
 struct ib_umem_odp {
        struct ib_umem umem;
        struct ib_ucontext_per_mm *per_mm;
        int npages;
 
        /* Tree tracking */
-       struct umem_odp_node    interval_tree;
+       struct interval_tree_node interval_tree;
 
        struct completion       notifier_completion;
        int                     dying;
@@ ... @@
 /*
  * Find first region intersecting with address range.
  * Return NULL if not found
  */
-struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
-                                      u64 addr, u64 length);
+static inline struct ib_umem_odp *
+rbt_ib_umem_lookup(struct rb_root_cached *root, u64 addr, u64 length)
+{
+       struct interval_tree_node *node;
+
+       node = interval_tree_iter_first(root, addr, addr + length - 1);
+       if (!node)
+               return NULL;
+       return container_of(node, struct ib_umem_odp, interval_tree);
+}
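
A worked example of the closed-interval conversion the new inline
lookup relies on (illustrative numbers, not from the patch): an umem
covering a single 4 KiB page at 0x1000 has ib_umem_start() == 0x1000
and ib_umem_end() == 0x2000, so it sits in the interval tree as
[0x1000, 0x1fff]. rbt_ib_umem_lookup() queries the closed range
[addr, addr + length - 1], hence, assuming a valid per_mm pointer in
a hypothetical caller:

	struct ib_umem_odp *odp;

	/* queries [0x1fff, 0x1fff]: overlaps, returns the umem */
	odp = rbt_ib_umem_lookup(&per_mm->umem_tree, 0x1fff, 1);

	/* queries [0x2000, 0x2000]: past the end, returns NULL */
	odp = rbt_ib_umem_lookup(&per_mm->umem_tree, 0x2000, 1);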
 
 static inline int ib_umem_mmu_notifier_retry(struct ib_umem_odp *umem_odp,
                                             unsigned long mmu_seq)