 #include <linux/module.h>
 #include <linux/mm.h>
 #include <linux/pci.h>         /* pci_bus_type */
+#include <linux/rbtree.h>
 #include <linux/sched.h>
 #include <linux/slab.h>
 #include <linux/uaccess.h>
 struct vfio_iommu {
        struct iommu_domain     *domain;
        struct mutex            lock;
-       struct list_head        dma_list;
+       struct rb_root          dma_list;
        struct list_head        group_list;
        bool                    cache;
 };
 
 struct vfio_dma {
-       struct list_head        next;
+       struct rb_node          node;
        dma_addr_t              iova;           /* Device address */
        unsigned long           vaddr;          /* Process virtual addr */
        long                    npage;          /* Number of pages */
        int                     prot;           /* IOMMU_READ/WRITE */
 };
 
 #define NPAGE_TO_SIZE(npage)   ((size_t)(npage) << PAGE_SHIFT)
 
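+/*
+ * Ranges in the tree never overlap one another, so an ordinary binary
+ * search finds an entry overlapping [start, start + size).  A request
+ * spanning several entries returns whichever the search lands on;
+ * vfio_dma_do_unmap() simply loops until the lookup comes back empty.
+ */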
+static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
+                                     dma_addr_t start, size_t size)
+{
+       struct rb_node *node = iommu->dma_list.rb_node;
+
+       while (node) {
+               struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
+
+               if (start + size <= dma->iova)
+                       node = node->rb_left;
+               else if (start >= dma->iova + NPAGE_TO_SIZE(dma->npage))
+                       node = node->rb_right;
+               else
+                       return dma;
+       }
+
+       return NULL;
+}
+
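+/*
+ * Insert a new range keyed on IOVA.  Callers guarantee the range
+ * overlaps nothing already in the tree, so each step only needs to
+ * test which side of the current node the new range falls on.
+ */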
+static void vfio_insert_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
+{
+       struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
+       struct vfio_dma *dma;
+
+       while (*link) {
+               parent = *link;
+               dma = rb_entry(parent, struct vfio_dma, node);
+
+               if (new->iova + NPAGE_TO_SIZE(new->npage) <= dma->iova)
+                       link = &(*link)->rb_left;
+               else
+                       link = &(*link)->rb_right;
+       }
+
+       rb_link_node(&new->node, parent, link);
+       rb_insert_color(&new->node, &iommu->dma_list);
+}
+
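+/* Unlink a range from the tree; freeing the vfio_dma is up to the caller */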
+static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
+{
+       rb_erase(&old->node, &iommu->dma_list);
+}
+
 struct vwork {
        struct mm_struct        *mm;
        long                    npage;
        struct work_struct      work;
 };
 
        return 0;
 }
 
-static inline bool ranges_overlap(dma_addr_t start1, size_t size1,
-                                 dma_addr_t start2, size_t size2)
-{
-       if (start1 < start2)
-               return (start2 - start1 < size1);
-       else if (start2 < start1)
-               return (start1 - start2 < size2);
-       return (size1 > 0 && size2 > 0);
-}
-
-static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
-                                               dma_addr_t start, size_t size)
-{
-       struct vfio_dma *dma;
-
-       list_for_each_entry(dma, &iommu->dma_list, next) {
-               if (ranges_overlap(dma->iova, NPAGE_TO_SIZE(dma->npage),
-                                  start, size))
-                       return dma;
-       }
-       return NULL;
-}
-
-static long vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start,
-                                   size_t size, struct vfio_dma *dma)
+static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start,
+                                  size_t size, struct vfio_dma *dma)
 {
        struct vfio_dma *split;
        long npage_lo, npage_hi;
 
        /* Existing dma region is completely covered, unmap all */
        if (start <= dma->iova &&
            start + size >= dma->iova + NPAGE_TO_SIZE(dma->npage)) {
                vfio_dma_unmap(iommu, dma->iova, dma->npage, dma->prot);
-               list_del(&dma->next);
-               npage_lo = dma->npage;
+               vfio_remove_dma(iommu, dma);
                kfree(dma);
-               return npage_lo;
+               return 0;
        }
 
        /* Overlap low address of existing range */
                dma->iova += overlap;
                dma->vaddr += overlap;
                dma->npage -= npage_lo;
-               return npage_lo;
+               return 0;
        }
 
        /* Overlap high address of existing range */
 
                vfio_dma_unmap(iommu, start, npage_hi, dma->prot);
                dma->npage -= npage_hi;
-               return npage_hi;
+               return 0;
        }
 
        /* Split existing */
        split->iova = start + size;
        split->vaddr = dma->vaddr + NPAGE_TO_SIZE(npage_lo) + size;
        split->prot = dma->prot;
-       list_add(&split->next, &iommu->dma_list);
-       return size >> PAGE_SHIFT;
+       vfio_insert_dma(iommu, split);
+       return 0;
 }
 
 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
                             struct vfio_iommu_type1_dma_unmap *unmap)
 {
-       long ret = 0, npage = unmap->size >> PAGE_SHIFT;
-       struct vfio_dma *dma, *tmp;
        uint64_t mask;
+       struct vfio_dma *dma;
+       int ret = 0;
 
        mask = ((uint64_t)1 << __ffs(iommu->domain->ops->pgsize_bitmap)) - 1;
 
        mutex_lock(&iommu->lock);
 
-       list_for_each_entry_safe(dma, tmp, &iommu->dma_list, next) {
-               if (ranges_overlap(dma->iova, NPAGE_TO_SIZE(dma->npage),
-                                  unmap->iova, unmap->size)) {
-                       ret = vfio_remove_dma_overlap(iommu, unmap->iova,
-                                                     unmap->size, dma);
-                       if (ret > 0)
-                               npage -= ret;
-                       if (ret < 0 || npage == 0)
-                               break;
-               }
-       }
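+       /*
+        * With the rb-tree there is no need to walk every mapping:
+        * repeatedly look up an overlapping entry and trim, split, or
+        * remove it until nothing in the requested span remains mapped.
+        */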
+       while (!ret && (dma = vfio_find_dma(iommu,
+                                           unmap->iova, unmap->size)))
+               ret = vfio_remove_dma_overlap(iommu, unmap->iova,
+                                             unmap->size, dma);
+
        mutex_unlock(&iommu->lock);
-       return ret > 0 ? 0 : (int)ret;
+       return ret;
 }
 
 static int vfio_dma_do_map(struct vfio_iommu *iommu,
                           struct vfio_iommu_type1_dma_map *map)
 {
-       struct vfio_dma *dma, *pdma = NULL;
+       struct vfio_dma *dma;
        dma_addr_t iova = map->iova;
        unsigned long locked, lock_limit, vaddr = map->vaddr;
        size_t size = map->size;
        if (!npage)
                return -EINVAL;
 
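+       /*
+        * Allocate the tracking structure before taking the lock; on
+        * any failure it is freed after the lock is dropped (out_lock).
+        */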
+       dma = kzalloc(sizeof *dma, GFP_KERNEL);
+       if (!dma)
+               return -ENOMEM;
+
        mutex_lock(&iommu->lock);
 
        if (vfio_find_dma(iommu, iova, size)) {
        if (ret)
                goto out_lock;
 
+       dma->npage = npage;
+       dma->iova = iova;
+       dma->vaddr = vaddr;
+       dma->prot = prot;
+
        /* Check if we abut a region below - nothing below 0 */
        if (iova) {
-               dma = vfio_find_dma(iommu, iova - 1, 1);
-               if (dma && dma->prot == prot &&
-                   dma->vaddr + NPAGE_TO_SIZE(dma->npage) == vaddr) {
-
-                       dma->npage += npage;
-                       iova = dma->iova;
-                       vaddr = dma->vaddr;
+               struct vfio_dma *tmp = vfio_find_dma(iommu, iova - 1, 1);
+               if (tmp && tmp->prot == prot &&
+                   tmp->vaddr + NPAGE_TO_SIZE(tmp->npage) == vaddr) {
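+                       /* Fold the lower neighbor into the new entry */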
+                       vfio_remove_dma(iommu, tmp);
+                       dma->npage += tmp->npage;
+                       dma->iova = iova = tmp->iova;
+                       dma->vaddr = vaddr = tmp->vaddr;
+                       kfree(tmp);
                        npage = dma->npage;
                        size = NPAGE_TO_SIZE(npage);
-
-                       pdma = dma;
                }
        }
 
        /* Check if we abut a region above - nothing above ~0 + 1 */
        if (iova + size) {
-               dma = vfio_find_dma(iommu, iova + size, 1);
-               if (dma && dma->prot == prot &&
-                   dma->vaddr == vaddr + size) {
-
-                       dma->npage += npage;
-                       dma->iova = iova;
-                       dma->vaddr = vaddr;
-
-                       /*
-                        * If merged above and below, remove previously
-                        * merged entry.  New entry covers it.
-                        */
-                       if (pdma) {
-                               list_del(&pdma->next);
-                               kfree(pdma);
-                       }
-                       pdma = dma;
+               struct vfio_dma *tmp = vfio_find_dma(iommu, iova + size, 1);
+               if (tmp && tmp->prot == prot &&
+                   tmp->vaddr == vaddr + size) {
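+                       /* Absorb the upper neighbor; iova and vaddr are unchanged */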
+                       vfio_remove_dma(iommu, tmp);
+                       dma->npage += tmp->npage;
+                       kfree(tmp);
+                       npage = dma->npage;
+                       size = NPAGE_TO_SIZE(npage);
                }
        }
 
-       /* Isolated, new region */
-       if (!pdma) {
-               dma = kzalloc(sizeof *dma, GFP_KERNEL);
-               if (!dma) {
-                       ret = -ENOMEM;
-                       vfio_dma_unmap(iommu, iova, npage, prot);
-                       goto out_lock;
-               }
-
-               dma->npage = npage;
-               dma->iova = iova;
-               dma->vaddr = vaddr;
-               dma->prot = prot;
-               list_add(&dma->next, &iommu->dma_list);
-       }
+       vfio_insert_dma(iommu, dma);
 
 out_lock:
        mutex_unlock(&iommu->lock);
+       if (ret)
+               kfree(dma);
        return ret;
 }
 
                return ERR_PTR(-ENOMEM);
 
        INIT_LIST_HEAD(&iommu->group_list);
-       INIT_LIST_HEAD(&iommu->dma_list);
+       iommu->dma_list = RB_ROOT;
        mutex_init(&iommu->lock);
 
 {
        struct vfio_iommu *iommu = iommu_data;
        struct vfio_group *group, *group_tmp;
-       struct vfio_dma *dma, *dma_tmp;
+       struct rb_node *node;
 
        list_for_each_entry_safe(group, group_tmp, &iommu->group_list, next) {
                iommu_detach_group(iommu->domain, group->iommu_group);
                kfree(group);
        }
 
-       list_for_each_entry_safe(dma, dma_tmp, &iommu->dma_list, next) {
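+       /* Tear down the tree leftmost-first until it is empty */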
+       while ((node = rb_first(&iommu->dma_list))) {
+               struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
                vfio_dma_unmap(iommu, dma->iova, dma->npage, dma->prot);
-               list_del(&dma->next);
+               vfio_remove_dma(iommu, dma);
                kfree(dma);
        }