struct vfio_dma *split;
        int ret;
 
+       if (!*size)
+               return 0;
+
        /*
         * Existing dma region is completely covered, unmap all.  This is
         * the likely case since userspace tends to map and unmap buffers
                        dma->vaddr += overlap;
                        dma->size -= overlap;
                        vfio_insert_dma(iommu, dma);
-               }
+               } else
+                       kfree(dma);
+
                *size = overlap;
                return 0;
        }
                if (ret)
                        return ret;
 
-               /*
-                * We may have unmapped the entire vfio_dma if the user is
-                * trying to unmap a sub-region of what was originally
-                * mapped.  If anything left, we can resize in place since
-                * iova is unchanged.
-                */
-               if (overlap < dma->size)
-                       dma->size -= overlap;
-               else
-                       vfio_remove_dma(iommu, dma);
-
+               dma->size -= overlap;
                *size = overlap;
                return 0;
        }
 
        /* Split existing */
+       split = kzalloc(sizeof(*split), GFP_KERNEL);
+       if (!split)
+               return -ENOMEM;
+
        offset = start - dma->iova;
 
        ret = vfio_unmap_unpin(iommu, dma, start, size);
        if (ret)
                return ret;
 
-       WARN_ON(!*size);
+       if (!*size) {
+               kfree(split);
+               return -EINVAL;
+       }
+
        tmp = dma->size;
 
-       /*
-        * Resize the lower vfio_dma in place, insert new for remaining
-        * upper segment.
-        */
+       /* Resize the lower vfio_dma in place, before the below insert */
        dma->size = offset;
 
-       if (offset + *size < tmp) {
-               split = kzalloc(sizeof(*split), GFP_KERNEL);
-               if (!split)
-                       return -ENOMEM;
-
+       /* Insert new for remainder, assuming it didn't all get unmapped */
+       if (likely(offset + *size < tmp)) {
                split->size = tmp - offset - *size;
                split->iova = dma->iova + offset + *size;
                split->vaddr = dma->vaddr + offset + *size;
                split->prot = dma->prot;
                vfio_insert_dma(iommu, split);
-       }
+       } else
+               kfree(split);
 
        return 0;
 }
 
        if (unmap->iova & mask)
                return -EINVAL;
-       if (unmap->size & mask)
+       if (!unmap->size || unmap->size & mask)
                return -EINVAL;
 
        WARN_ON(mask & PAGE_MASK);
        while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
                size = unmap->size;
                ret = vfio_remove_dma_overlap(iommu, unmap->iova, &size, dma);
-               if (ret)
+               if (ret || !size)
                        break;
                unmapped += size;
        }
                        if (tmp && tmp->prot == prot &&
                            tmp->vaddr + tmp->size == vaddr) {
                                tmp->size += size;
-
                                iova = tmp->iova;
                                size = tmp->size;
                                vaddr = tmp->vaddr;
                        }
                }
 
-               /* Check if we abut a region above - nothing above ~0 + 1 */
+               /*
+                * Check if we abut a region above - nothing above ~0 + 1.
+                * If we abut above and below, remove and free.  If only
+                * abut above, remove, modify, reinsert.
+                */
                if (likely(iova + size)) {
                        struct vfio_dma *tmp;
-
                        tmp = vfio_find_dma(iommu, iova + size, 1);
                        if (tmp && tmp->prot == prot &&
                            tmp->vaddr == vaddr + size) {
                                vfio_remove_dma(iommu, tmp);
-                               if (dma)
+                               if (dma) {
                                        dma->size += tmp->size;
-                               else
+                                       kfree(tmp);
+                               } else {
                                        size += tmp->size;
-                               kfree(tmp);
+                                       tmp->size = size;
+                                       tmp->iova = iova;
+                                       tmp->vaddr = vaddr;
+                                       vfio_insert_dma(iommu, tmp);
+                                       dma = tmp;
+                               }
                        }
                }
 
                iova = map->iova;
                size = map->size;
                while ((tmp = vfio_find_dma(iommu, iova, size))) {
-                       if (vfio_remove_dma_overlap(iommu, iova, &size, tmp)) {
-                               pr_warn("%s: Error rolling back failed map\n",
-                                       __func__);
+                       int r = vfio_remove_dma_overlap(iommu, iova,
+                                                       &size, tmp);
+                       if (WARN_ON(r || !size))
                                break;
-                       }
                }
        }
 
                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
                size_t size = dma->size;
                vfio_remove_dma_overlap(iommu, dma->iova, &size, dma);
+               if (WARN_ON(!size))
+                       break;
        }
 
        iommu_domain_free(iommu->domain);