*/
        }
 
+       if (use_ptemod && map->notifier_init)
+               mmu_interval_notifier_remove(&map->notifier);
+
        if (map->notify.flags & UNMAP_NOTIFY_SEND_EVENT) {
                notify_remote_via_evtchn(map->notify.event);
                evtchn_put(map->notify.event);
 static int find_grant_ptes(pte_t *pte, unsigned long addr, void *data)
 {
        struct gntdev_grant_map *map = data;
-       unsigned int pgnr = (addr - map->vma->vm_start) >> PAGE_SHIFT;
+       unsigned int pgnr = (addr - map->pages_vm_start) >> PAGE_SHIFT;
        int flags = map->flags | GNTMAP_application_map | GNTMAP_contains_pte |
                    (1 << _GNTMAP_guest_avail0);
        u64 pte_maddr;
        struct gntdev_priv *priv = file->private_data;
 
        pr_debug("gntdev_vma_close %p\n", vma);
-       if (use_ptemod) {
-               WARN_ON(map->vma != vma);
-               mmu_interval_notifier_remove(&map->notifier);
-               map->vma = NULL;
-       }
+
        vma->vm_private_data = NULL;
        gntdev_put_map(priv, map);
 }
        struct gntdev_grant_map *map =
                container_of(mn, struct gntdev_grant_map, notifier);
        unsigned long mstart, mend;
+       unsigned long map_start, map_end;
 
        if (!mmu_notifier_range_blockable(range))
                return false;
 
+       map_start = map->pages_vm_start;
+       map_end = map->pages_vm_start + (map->count << PAGE_SHIFT);
+
        /*
         * If the VMA is split or otherwise changed the notifier is not
         * updated, but we don't want to process VA's outside the modified
         * VMA. FIXME: It would be much more understandable to just prevent
         * modifying the VMA in the first place.
         */
-       if (map->vma->vm_start >= range->end ||
-           map->vma->vm_end <= range->start)
+       if (map_start >= range->end || map_end <= range->start)
                return true;
 
-       mstart = max(range->start, map->vma->vm_start);
-       mend = min(range->end, map->vma->vm_end);
+       mstart = max(range->start, map_start);
+       mend = min(range->end, map_end);
        pr_debug("map %d+%d (%lx %lx), range %lx %lx, mrange %lx %lx\n",
-                       map->index, map->count,
-                       map->vma->vm_start, map->vma->vm_end,
-                       range->start, range->end, mstart, mend);
-       unmap_grant_pages(map,
-                               (mstart - map->vma->vm_start) >> PAGE_SHIFT,
-                               (mend - mstart) >> PAGE_SHIFT);
+                map->index, map->count, map_start, map_end,
+                range->start, range->end, mstart, mend);
+       unmap_grant_pages(map, (mstart - map_start) >> PAGE_SHIFT,
+                         (mend - mstart) >> PAGE_SHIFT);
 
        return true;
 }
                return -EINVAL;
 
        pr_debug("map %d+%d at %lx (pgoff %lx)\n",
-                       index, count, vma->vm_start, vma->vm_pgoff);
+                index, count, vma->vm_start, vma->vm_pgoff);
 
        mutex_lock(&priv->lock);
        map = gntdev_find_map_index(priv, index, count);
        if (!map)
                goto unlock_out;
-       if (use_ptemod && map->vma)
-               goto unlock_out;
-       if (atomic_read(&map->live_grants)) {
-               err = -EAGAIN;
+       if (!atomic_add_unless(&map->in_use, 1, 1))
                goto unlock_out;
-       }
+
        refcount_inc(&map->users);
 
        vma->vm_ops = &gntdev_vmops;
                        map->flags |= GNTMAP_readonly;
        }
 
+       map->pages_vm_start = vma->vm_start;
+
        if (use_ptemod) {
-               map->vma = vma;
                err = mmu_interval_notifier_insert_locked(
                        &map->notifier, vma->vm_mm, vma->vm_start,
                        vma->vm_end - vma->vm_start, &gntdev_mmu_ops);
-               if (err) {
-                       map->vma = NULL;
+               if (err)
                        goto out_unlock_put;
-               }
+
+               map->notifier_init = true;
        }
        mutex_unlock(&priv->lock);
 
                 */
                mmu_interval_read_begin(&map->notifier);
 
-               map->pages_vm_start = vma->vm_start;
                err = apply_to_page_range(vma->vm_mm, vma->vm_start,
                                          vma->vm_end - vma->vm_start,
                                          find_grant_ptes, map);
 out_unlock_put:
        mutex_unlock(&priv->lock);
 out_put_map:
-       if (use_ptemod) {
+       if (use_ptemod)
                unmap_grant_pages(map, 0, map->count);
-               if (map->vma) {
-                       mmu_interval_notifier_remove(&map->notifier);
-                       map->vma = NULL;
-               }
-       }
        gntdev_put_map(priv, map);
        return err;
 }