__vhost_vq_meta_reset(d->vqs[i]);
 }
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
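+/*
+ * Metadata acceleration: the descriptor, avail and used rings can be
+ * accessed directly through kernel virtual addresses instead of the
+ * uaccess helpers.  vhost_map_prefetch() translates each ring once and
+ * publishes the result in vq->maps[]; readers use RCU, updaters hold
+ * vq->mmu_lock, and the MMU notifier callbacks below tear a map down
+ * whenever its backing userspace range is invalidated.
+ */
+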
+static void vhost_map_unprefetch(struct vhost_map *map)
+{
+       kfree(map->pages);
+       kfree(map);
+}
+
+static void vhost_uninit_vq_maps(struct vhost_virtqueue *vq)
+{
+       struct vhost_map *map[VHOST_NUM_ADDRS];
+       int i;
+
+       spin_lock(&vq->mmu_lock);
+       for (i = 0; i < VHOST_NUM_ADDRS; i++) {
+               map[i] = rcu_dereference_protected(vq->maps[i],
+                                 lockdep_is_held(&vq->mmu_lock));
+               if (map[i])
+                       rcu_assign_pointer(vq->maps[i], NULL);
+       }
+       spin_unlock(&vq->mmu_lock);
+
+       synchronize_rcu();
+
+       for (i = 0; i < VHOST_NUM_ADDRS; i++)
+               if (map[i])
+                       vhost_map_unprefetch(map[i]);
+}
+
+static void vhost_reset_vq_maps(struct vhost_virtqueue *vq)
+{
+       int i;
+
+       vhost_uninit_vq_maps(vq);
+       for (i = 0; i < VHOST_NUM_ADDRS; i++)
+               vq->uaddrs[i].size = 0;
+}
+
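+/* True if the invalidated range [start, end] intersects this ring's
+ * recorded userspace address range; an entry with zero size never
+ * overlaps.
+ */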
+static bool vhost_map_range_overlap(struct vhost_uaddr *uaddr,
+                                    unsigned long start,
+                                    unsigned long end)
+{
+       if (unlikely(!uaddr->size))
+               return false;
+
+       return !(end < uaddr->uaddr || start > uaddr->uaddr - 1 + uaddr->size);
+}
+
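+/* Called from the invalidate_range_start notifier: dirty the pages of
+ * host-written rings, unpublish the map and free it once all RCU
+ * readers are done.  invalidate_count keeps vhost_map_prefetch() from
+ * re-creating the map until invalidate_range_end runs.
+ */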
+static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
+                                     int index,
+                                     unsigned long start,
+                                     unsigned long end)
+{
+       struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+       struct vhost_map *map;
+       int i;
+
+       if (!vhost_map_range_overlap(uaddr, start, end))
+               return;
+
+       spin_lock(&vq->mmu_lock);
+       ++vq->invalidate_count;
+
+       map = rcu_dereference_protected(vq->maps[index],
+                                       lockdep_is_held(&vq->mmu_lock));
+       if (map) {
+               if (uaddr->write) {
+                       for (i = 0; i < map->npages; i++)
+                               set_page_dirty(map->pages[i]);
+               }
+               rcu_assign_pointer(vq->maps[index], NULL);
+       }
+       spin_unlock(&vq->mmu_lock);
+
+       if (map) {
+               synchronize_rcu();
+               vhost_map_unprefetch(map);
+       }
+}
+
+static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
+                                   int index,
+                                   unsigned long start,
+                                   unsigned long end)
+{
+       if (!vhost_map_range_overlap(&vq->uaddrs[index], start, end))
+               return;
+
+       spin_lock(&vq->mmu_lock);
+       --vq->invalidate_count;
+       spin_unlock(&vq->mmu_lock);
+}
+
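+/* The start callback may block in synchronize_rcu(), so non-blockable
+ * invalidations have to be refused with -EAGAIN.
+ */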
+static int vhost_invalidate_range_start(struct mmu_notifier *mn,
+                                       const struct mmu_notifier_range *range)
+{
+       struct vhost_dev *dev = container_of(mn, struct vhost_dev,
+                                            mmu_notifier);
+       int i, j;
+
+       if (!mmu_notifier_range_blockable(range))
+               return -EAGAIN;
+
+       for (i = 0; i < dev->nvqs; i++) {
+               struct vhost_virtqueue *vq = dev->vqs[i];
+
+               for (j = 0; j < VHOST_NUM_ADDRS; j++)
+                       vhost_invalidate_vq_start(vq, j,
+                                                 range->start,
+                                                 range->end);
+       }
+
+       return 0;
+}
+
+static void vhost_invalidate_range_end(struct mmu_notifier *mn,
+                                      const struct mmu_notifier_range *range)
+{
+       struct vhost_dev *dev = container_of(mn, struct vhost_dev,
+                                            mmu_notifier);
+       int i, j;
+
+       for (i = 0; i < dev->nvqs; i++) {
+               struct vhost_virtqueue *vq = dev->vqs[i];
+
+               for (j = 0; j < VHOST_NUM_ADDRS; j++)
+                       vhost_invalidate_vq_end(vq, j,
+                                               range->start,
+                                               range->end);
+       }
+}
+
+static const struct mmu_notifier_ops vhost_mmu_notifier_ops = {
+       .invalidate_range_start = vhost_invalidate_range_start,
+       .invalidate_range_end = vhost_invalidate_range_end,
+};
+
+static void vhost_init_maps(struct vhost_dev *dev)
+{
+       struct vhost_virtqueue *vq;
+       int i, j;
+
+       dev->mmu_notifier.ops = &vhost_mmu_notifier_ops;
+
+       for (i = 0; i < dev->nvqs; ++i) {
+               vq = dev->vqs[i];
+               for (j = 0; j < VHOST_NUM_ADDRS; j++)
+                       RCU_INIT_POINTER(vq->maps[j], NULL);
+       }
+}
+#endif
+
 static void vhost_vq_reset(struct vhost_dev *dev,
                           struct vhost_virtqueue *vq)
 {
        vq->busyloop_timeout = 0;
        vq->umem = NULL;
        vq->iotlb = NULL;
+       vq->invalidate_count = 0;
        __vhost_vq_meta_reset(vq);
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       vhost_reset_vq_maps(vq);
+#endif
 }
 
 static int vhost_worker(void *data)
        INIT_LIST_HEAD(&dev->read_list);
        INIT_LIST_HEAD(&dev->pending_list);
        spin_lock_init(&dev->iotlb_lock);
-
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       vhost_init_maps(dev);
+#endif
 
        for (i = 0; i < dev->nvqs; ++i) {
                vq = dev->vqs[i];
                vq->heads = NULL;
                vq->dev = dev;
                mutex_init(&vq->mutex);
+               spin_lock_init(&vq->mmu_lock);
                vhost_vq_reset(dev, vq);
                if (vq->handle_kick)
                        vhost_poll_init(&vq->poll, vq->handle_kick,
        if (err)
                goto err_cgroup;
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       err = mmu_notifier_register(&dev->mmu_notifier, dev->mm);
+       if (err)
+               goto err_mmu_notifier;
+#endif
+
        return 0;
+
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+err_mmu_notifier:
+       vhost_dev_free_iovecs(dev);
+#endif
 err_cgroup:
        kthread_stop(worker);
        dev->worker = NULL;
        spin_unlock(&dev->iotlb_lock);
 }
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
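+/* Record the userspace address and size of one ring so the MMU
+ * notifier callbacks can check it for overlap with an invalidated
+ * range; write marks rings the host writes to.
+ */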
+static void vhost_setup_uaddr(struct vhost_virtqueue *vq,
+                             int index, unsigned long uaddr,
+                             size_t size, bool write)
+{
+       struct vhost_uaddr *addr = &vq->uaddrs[index];
+
+       addr->uaddr = uaddr;
+       addr->size = size;
+       addr->write = write;
+}
+
+static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
+{
+       vhost_setup_uaddr(vq, VHOST_ADDR_DESC,
+                         (unsigned long)vq->desc,
+                         vhost_get_desc_size(vq, vq->num),
+                         false);
+       vhost_setup_uaddr(vq, VHOST_ADDR_AVAIL,
+                         (unsigned long)vq->avail,
+                         vhost_get_avail_size(vq, vq->num),
+                         false);
+       vhost_setup_uaddr(vq, VHOST_ADDR_USED,
+                         (unsigned long)vq->used,
+                         vhost_get_used_size(vq, vq->num),
+                         true);
+}
+
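+/* Translate one ring into a kernel virtual address.  The pages are
+ * looked up with __get_user_pages_fast() under mmu_lock and must be
+ * in lowmem and virtually contiguous, otherwise the accessors keep
+ * falling back to plain uaccess.  A pending invalidation
+ * (invalidate_count != 0) fails the prefetch so a stale map is never
+ * published.
+ */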
+static int vhost_map_prefetch(struct vhost_virtqueue *vq,
+                              int index)
+{
+       struct vhost_map *map;
+       struct vhost_uaddr *uaddr = &vq->uaddrs[index];
+       struct page **pages;
+       int npages = DIV_ROUND_UP((uaddr->uaddr & (PAGE_SIZE - 1)) +
+                                 uaddr->size, PAGE_SIZE);
+       int npinned;
+       void *vaddr, *v;
+       int err;
+       int i;
+
+       spin_lock(&vq->mmu_lock);
+
+       err = -EFAULT;
+       if (vq->invalidate_count)
+               goto err;
+
+       err = -ENOMEM;
+       map = kmalloc(sizeof(*map), GFP_ATOMIC);
+       if (!map)
+               goto err;
+
+       pages = kmalloc_array(npages, sizeof(struct page *), GFP_ATOMIC);
+       if (!pages)
+               goto err_pages;
+
+       err = -EFAULT;
+       npinned = __get_user_pages_fast(uaddr->uaddr, npages,
+                                       uaddr->write, pages);
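+       /* The page references are dropped right away; the MMU notifier
+        * callbacks above are relied on to invalidate this map before
+        * the backing pages can change, so no long-term pin is held.
+        */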
+       if (npinned > 0)
+               release_pages(pages, npinned);
+       if (npinned != npages)
+               goto err_gup;
+
+       for (i = 0; i < npinned; i++)
+               if (PageHighMem(pages[i]))
+                       goto err_gup;
+
+       vaddr = v = page_address(pages[0]);
+
+       /* For simplicity, fall back to the userspace address if the
+        * kernel VAs are not contiguous.
+        */
+       for (i = 1; i < npinned; i++) {
+               v += PAGE_SIZE;
+               if (v != page_address(pages[i]))
+                       goto err_gup;
+       }
+
+       map->addr = vaddr + (uaddr->uaddr & (PAGE_SIZE - 1));
+       map->npages = npages;
+       map->pages = pages;
+
+       rcu_assign_pointer(vq->maps[index], map);
+       /* No need for a synchronize_rcu() here: this is called from the
+        * vhost worker, so we are already serialized with all readers
+        * of vq->maps[].
+        */
+       spin_unlock(&vq->mmu_lock);
+
+       return 0;
+
+err_gup:
+       kfree(pages);
+err_pages:
+       kfree(map);
+err:
+       spin_unlock(&vq->mmu_lock);
+       return err;
+}
+#endif
+
 void vhost_dev_cleanup(struct vhost_dev *dev)
 {
        int i;
                kthread_stop(dev->worker);
                dev->worker = NULL;
        }
-       if (dev->mm)
+       if (dev->mm) {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
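+               /* Unregister before mmput() so no further invalidation
+                * callbacks can run; leftover maps are dropped below.
+                */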
+               mmu_notifier_unregister(&dev->mmu_notifier, dev->mm);
+#endif
                mmput(dev->mm);
+       }
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       for (i = 0; i < dev->nvqs; i++)
+               vhost_uninit_vq_maps(dev->vqs[i]);
+#endif
        dev->mm = NULL;
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
 
 static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
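+       /* Fast path: if a kernel mapping of the used ring is published,
+        * write through it under the RCU read lock; otherwise fall
+        * through to the uaccess-based path below.  The other metadata
+        * accessors follow the same pattern.
+        */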
+       struct vhost_map *map;
+       struct vring_used *used;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+               if (likely(map)) {
+                       used = map->addr;
+                       *((__virtio16 *)&used->ring[vq->num]) =
+                               cpu_to_vhost16(vq, vq->avail_idx);
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
                              vhost_avail_event(vq));
 }
                                 struct vring_used_elem *head, int idx,
                                 int count)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_used *used;
+       size_t size;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+               if (likely(map)) {
+                       used = map->addr;
+                       size = count * sizeof(*head);
+                       memcpy(used->ring + idx, head, size);
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_copy_to_user(vq, vq->used->ring + idx, head,
                                  count * sizeof(*head));
 }
 static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
 
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_used *used;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+               if (likely(map)) {
+                       used = map->addr;
+                       used->flags = cpu_to_vhost16(vq, vq->used_flags);
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
                              &vq->used->flags);
 }
 static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
 
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_used *used;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+               if (likely(map)) {
+                       used = map->addr;
+                       used->idx = cpu_to_vhost16(vq, vq->last_used_idx);
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
                              &vq->used->idx);
 }
 static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
                                      __virtio16 *idx)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_avail *avail;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+               if (likely(map)) {
+                       avail = map->addr;
+                       *idx = avail->idx;
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_get_avail(vq, *idx, &vq->avail->idx);
 }
 
 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
                                       __virtio16 *head, int idx)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_avail *avail;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+               if (likely(map)) {
+                       avail = map->addr;
+                       *head = avail->ring[idx & (vq->num - 1)];
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_get_avail(vq, *head,
                               &vq->avail->ring[idx & (vq->num - 1)]);
 }
 static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
                                        __virtio16 *flags)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_avail *avail;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+               if (likely(map)) {
+                       avail = map->addr;
+                       *flags = avail->flags;
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_get_avail(vq, *flags, &vq->avail->flags);
 }
 
 static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
                                       __virtio16 *event)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_avail *avail;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+               map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
+               if (likely(map)) {
+                       avail = map->addr;
+                       *event = (__virtio16)avail->ring[vq->num];
+                       rcu_read_unlock();
+                       return 0;
+               }
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_get_avail(vq, *event, vhost_used_event(vq));
 }
 
 static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
                                     __virtio16 *idx)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_used *used;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
+               if (likely(map)) {
+                       used = map->addr;
+                       *idx = used->idx;
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_get_used(vq, *idx, &vq->used->idx);
 }
 
 static inline int vhost_get_desc(struct vhost_virtqueue *vq,
                                 struct vring_desc *desc, int idx)
 {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       struct vhost_map *map;
+       struct vring_desc *d;
+
+       if (!vq->iotlb) {
+               rcu_read_lock();
+
+               map = rcu_dereference(vq->maps[VHOST_ADDR_DESC]);
+               if (likely(map)) {
+                       d = map->addr;
+                       *desc = *(d + idx);
+                       rcu_read_unlock();
+                       return 0;
+               }
+
+               rcu_read_unlock();
+       }
+#endif
+
        return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
 }
 
        return true;
 }
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
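+/* Re-create any missing map from the vhost worker so the fast paths
+ * above find vq->maps[] populated.
+ */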
+static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
+{
+       struct vhost_map *map;
+       int i;
+
+       for (i = 0; i < VHOST_NUM_ADDRS; i++) {
+               rcu_read_lock();
+               map = rcu_dereference(vq->maps[i]);
+               rcu_read_unlock();
+               if (unlikely(!map))
+                       vhost_map_prefetch(vq, i);
+       }
+}
+#endif
+
 int vq_meta_prefetch(struct vhost_virtqueue *vq)
 {
        unsigned int num = vq->num;
 
-       if (!vq->iotlb)
+       if (!vq->iotlb) {
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+               vhost_vq_map_prefetch(vq);
+#endif
                return 1;
+       }
 
        return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
                               vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
 
        mutex_lock(&vq->mutex);
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
+       /* Unregister the MMU notifier before updating vq->uaddrs[]:
+        * since the uaddrs only change while no notifier is registered,
+        * the invalidation callbacks can read them without holding a
+        * lock.
+        */
+       if (d->mm)
+               mmu_notifier_unregister(&d->mmu_notifier, d->mm);
+
+       vhost_uninit_vq_maps(vq);
+#endif
+
        switch (ioctl) {
        case VHOST_SET_VRING_NUM:
                r = vhost_vring_set_num(d, vq, argp);
                BUG();
        }
 
+#if VHOST_ARCH_CAN_ACCEL_UACCESS
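+       /* The ring addresses may have changed: record the new uaddrs
+        * and re-register the notifier before any map is prefetched
+        * again.
+        */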
+       vhost_setup_vq_uaddr(vq);
+
+       if (d->mm)
+               mmu_notifier_register(&d->mmu_notifier, d->mm);
+#endif
+
        mutex_unlock(&vq->mutex);
 
        return r;