#include <linux/uaccess.h>
 #include <linux/slab.h>
 #include <linux/export.h>
+#include <linux/bvec.h>
+#include <linux/highmem.h>
+#include <linux/vhost_iotlb.h>
 #include <uapi/linux/virtio_config.h>
 
 static __printf(1,2) __cold void vringh_bad(const char *fmt, ...)
 }
 
 /* Copy some bytes to/from the iovec.  Returns num copied. */
-static inline ssize_t vringh_iov_xfer(struct vringh_kiov *iov,
+static inline ssize_t vringh_iov_xfer(struct vringh *vrh,
+                                     struct vringh_kiov *iov,
                                      void *ptr, size_t len,
-                                     int (*xfer)(void *addr, void *ptr,
+                                     int (*xfer)(const struct vringh *vrh,
+                                                 void *addr, void *ptr,
                                                  size_t len))
 {
        int err, done = 0;
                size_t partlen;
 
                partlen = min(iov->iov[iov->i].iov_len, len);
-               err = xfer(iov->iov[iov->i].iov_base, ptr, partlen);
+               err = xfer(vrh, iov->iov[iov->i].iov_base, ptr, partlen);
                if (err)
                        return err;
                done += partlen;
                        /* Fix up old iov element then increment. */
                        iov->iov[iov->i].iov_len = iov->consumed;
                        iov->iov[iov->i].iov_base -= iov->consumed;
+
                        
                        iov->consumed = 0;
                        iov->i++;
                                      u64 addr,
                                      struct vringh_range *r),
                     struct vringh_range *range,
-                    int (*copy)(void *dst, const void *src, size_t len))
+                    int (*copy)(const struct vringh *vrh,
+                                void *dst, const void *src, size_t len))
 {
        size_t part, len = sizeof(struct vring_desc);
 
                if (!rcheck(vrh, addr, &part, range, getrange))
                        return -EINVAL;
 
-               err = copy(dst, src, part);
+               err = copy(vrh, dst, src, part);
                if (err)
                        return err;
 
                                             struct vringh_range *)),
             bool (*getrange)(struct vringh *, u64, struct vringh_range *),
             gfp_t gfp,
-            int (*copy)(void *dst, const void *src, size_t len))
+            int (*copy)(const struct vringh *vrh,
+                        void *dst, const void *src, size_t len))
 {
        int err, count = 0, up_next, desc_max;
        struct vring_desc desc, *descs;
                        err = slow_copy(vrh, &desc, &descs[i], rcheck, getrange,
                                        &slowrange, copy);
                else
-                       err = copy(&desc, &descs[i], sizeof(desc));
+                       err = copy(vrh, &desc, &descs[i], sizeof(desc));
                if (unlikely(err))
                        goto fail;
 
                                    unsigned int num_used,
                                    int (*putu16)(const struct vringh *vrh,
                                                  __virtio16 *p, u16 val),
-                                   int (*putused)(struct vring_used_elem *dst,
+                                   int (*putused)(const struct vringh *vrh,
+                                                  struct vring_used_elem *dst,
                                                   const struct vring_used_elem
                                                   *src, unsigned num))
 {
        /* Compiler knows num_used == 1 sometimes, hence extra check */
        if (num_used > 1 && unlikely(off + num_used >= vrh->vring.num)) {
                u16 part = vrh->vring.num - off;
-               err = putused(&used_ring->ring[off], used, part);
+               err = putused(vrh, &used_ring->ring[off], used, part);
                if (!err)
-                       err = putused(&used_ring->ring[0], used + part,
+                       err = putused(vrh, &used_ring->ring[0], used + part,
                                      num_used - part);
        } else
-               err = putused(&used_ring->ring[off], used, num_used);
+               err = putused(vrh, &used_ring->ring[off], used, num_used);
 
        if (err) {
                vringh_bad("Failed to write %u used entries %u at %p",
        return put_user(v, (__force __virtio16 __user *)p);
 }
 
-static inline int copydesc_user(void *dst, const void *src, size_t len)
+static inline int copydesc_user(const struct vringh *vrh,
+                               void *dst, const void *src, size_t len)
 {
        return copy_from_user(dst, (__force void __user *)src, len) ?
                -EFAULT : 0;
 }
 
-static inline int putused_user(struct vring_used_elem *dst,
+static inline int putused_user(const struct vringh *vrh,
+                              struct vring_used_elem *dst,
                               const struct vring_used_elem *src,
                               unsigned int num)
 {
                            sizeof(*dst) * num) ? -EFAULT : 0;
 }
 
-static inline int xfer_from_user(void *src, void *dst, size_t len)
+static inline int xfer_from_user(const struct vringh *vrh, void *src,
+                                void *dst, size_t len)
 {
        return copy_from_user(dst, (__force void __user *)src, len) ?
                -EFAULT : 0;
 }
 
-static inline int xfer_to_user(void *dst, void *src, size_t len)
+static inline int xfer_to_user(const struct vringh *vrh,
+                              void *dst, void *src, size_t len)
 {
        return copy_to_user((__force void __user *)dst, src, len) ?
                -EFAULT : 0;
  */
 ssize_t vringh_iov_pull_user(struct vringh_iov *riov, void *dst, size_t len)
 {
-       return vringh_iov_xfer((struct vringh_kiov *)riov,
+       return vringh_iov_xfer(NULL, (struct vringh_kiov *)riov,
                               dst, len, xfer_from_user);
 }
 EXPORT_SYMBOL(vringh_iov_pull_user);
 ssize_t vringh_iov_push_user(struct vringh_iov *wiov,
                             const void *src, size_t len)
 {
-       return vringh_iov_xfer((struct vringh_kiov *)wiov,
+       return vringh_iov_xfer(NULL, (struct vringh_kiov *)wiov,
                               (void *)src, len, xfer_to_user);
 }
 EXPORT_SYMBOL(vringh_iov_push_user);
        return 0;
 }
 
-static inline int copydesc_kern(void *dst, const void *src, size_t len)
+static inline int copydesc_kern(const struct vringh *vrh,
+                               void *dst, const void *src, size_t len)
 {
        memcpy(dst, src, len);
        return 0;
 }
 
-static inline int putused_kern(struct vring_used_elem *dst,
+static inline int putused_kern(const struct vringh *vrh,
+                              struct vring_used_elem *dst,
                               const struct vring_used_elem *src,
                               unsigned int num)
 {
        return 0;
 }
 
-static inline int xfer_kern(void *src, void *dst, size_t len)
+static inline int xfer_kern(const struct vringh *vrh, void *src,
+                           void *dst, size_t len)
 {
        memcpy(dst, src, len);
        return 0;
 }
 
-static inline int kern_xfer(void *dst, void *src, size_t len)
+static inline int kern_xfer(const struct vringh *vrh, void *dst,
+                           void *src, size_t len)
 {
        memcpy(dst, src, len);
        return 0;
  */
 ssize_t vringh_iov_pull_kern(struct vringh_kiov *riov, void *dst, size_t len)
 {
-       return vringh_iov_xfer(riov, dst, len, xfer_kern);
+       return vringh_iov_xfer(NULL, riov, dst, len, xfer_kern);
 }
 EXPORT_SYMBOL(vringh_iov_pull_kern);
 
 ssize_t vringh_iov_push_kern(struct vringh_kiov *wiov,
                             const void *src, size_t len)
 {
-       return vringh_iov_xfer(wiov, (void *)src, len, kern_xfer);
+       return vringh_iov_xfer(NULL, wiov, (void *)src, len, kern_xfer);
 }
 EXPORT_SYMBOL(vringh_iov_push_kern);
 
 }
 EXPORT_SYMBOL(vringh_need_notify_kern);
 
+static int iotlb_translate(const struct vringh *vrh,
+                          u64 addr, u64 len, struct bio_vec iov[],
+                          int iov_size, u32 perm)
+{
+       struct vhost_iotlb_map *map;
+       struct vhost_iotlb *iotlb = vrh->iotlb;
+       int ret = 0;
+       u64 s = 0;
+
+       while (len > s) {
+               u64 size, pa, pfn;
+
+               if (unlikely(ret >= iov_size)) {
+                       ret = -ENOBUFS;
+                       break;
+               }
+
+               map = vhost_iotlb_itree_first(iotlb, addr,
+                                             addr + len - 1);
+               if (!map || map->start > addr) {
+                       ret = -EINVAL;
+                       break;
+               } else if (!(map->perm & perm)) {
+                       ret = -EPERM;
+                       break;
+               }
+
+               size = map->size - addr + map->start;
+               pa = map->addr + addr - map->start;
+               pfn = pa >> PAGE_SHIFT;
+               iov[ret].bv_page = pfn_to_page(pfn);
+               iov[ret].bv_len = min(len - s, size);
+               iov[ret].bv_offset = pa & (PAGE_SIZE - 1);
+               s += size;
+               addr += size;
+               ++ret;
+       }
+
+       return ret;
+}
+
+static inline int copy_from_iotlb(const struct vringh *vrh, void *dst,
+                                 void *src, size_t len)
+{
+       struct iov_iter iter;
+       struct bio_vec iov[16];
+       int ret;
+
+       ret = iotlb_translate(vrh, (u64)(uintptr_t)src,
+                             len, iov, 16, VHOST_MAP_RO);
+       if (ret < 0)
+               return ret;
+
+       iov_iter_bvec(&iter, READ, iov, ret, len);
+
+       ret = copy_from_iter(dst, len, &iter);
+
+       return ret;
+}
+
+static inline int copy_to_iotlb(const struct vringh *vrh, void *dst,
+                               void *src, size_t len)
+{
+       struct iov_iter iter;
+       struct bio_vec iov[16];
+       int ret;
+
+       ret = iotlb_translate(vrh, (u64)(uintptr_t)dst,
+                             len, iov, 16, VHOST_MAP_WO);
+       if (ret < 0)
+               return ret;
+
+       iov_iter_bvec(&iter, WRITE, iov, ret, len);
+
+       return copy_to_iter(src, len, &iter);
+}
+
+static inline int getu16_iotlb(const struct vringh *vrh,
+                              u16 *val, const __virtio16 *p)
+{
+       struct bio_vec iov;
+       void *kaddr, *from;
+       int ret;
+
+       /* Atomic read is needed for getu16 */
+       ret = iotlb_translate(vrh, (u64)(uintptr_t)p, sizeof(*p),
+                             &iov, 1, VHOST_MAP_RO);
+       if (ret < 0)
+               return ret;
+
+       kaddr = kmap_atomic(iov.bv_page);
+       from = kaddr + iov.bv_offset;
+       *val = vringh16_to_cpu(vrh, READ_ONCE(*(__virtio16 *)from));
+       kunmap_atomic(kaddr);
+
+       return 0;
+}
+
+static inline int putu16_iotlb(const struct vringh *vrh,
+                              __virtio16 *p, u16 val)
+{
+       struct bio_vec iov;
+       void *kaddr, *to;
+       int ret;
+
+       /* Atomic write is needed for putu16 */
+       ret = iotlb_translate(vrh, (u64)(uintptr_t)p, sizeof(*p),
+                             &iov, 1, VHOST_MAP_WO);
+       if (ret < 0)
+               return ret;
+
+       kaddr = kmap_atomic(iov.bv_page);
+       to = kaddr + iov.bv_offset;
+       WRITE_ONCE(*(__virtio16 *)to, cpu_to_vringh16(vrh, val));
+       kunmap_atomic(kaddr);
+
+       return 0;
+}
+
+static inline int copydesc_iotlb(const struct vringh *vrh,
+                                void *dst, const void *src, size_t len)
+{
+       int ret;
+
+       ret = copy_from_iotlb(vrh, dst, (void *)src, len);
+       if (ret != len)
+               return -EFAULT;
+
+       return 0;
+}
+
+static inline int xfer_from_iotlb(const struct vringh *vrh, void *src,
+                                 void *dst, size_t len)
+{
+       int ret;
+
+       ret = copy_from_iotlb(vrh, dst, src, len);
+       if (ret != len)
+               return -EFAULT;
+
+       return 0;
+}
+
+static inline int xfer_to_iotlb(const struct vringh *vrh,
+                              void *dst, void *src, size_t len)
+{
+       int ret;
+
+       ret = copy_to_iotlb(vrh, dst, src, len);
+       if (ret != len)
+               return -EFAULT;
+
+       return 0;
+}
+
+static inline int putused_iotlb(const struct vringh *vrh,
+                               struct vring_used_elem *dst,
+                               const struct vring_used_elem *src,
+                               unsigned int num)
+{
+       int size = num * sizeof(*dst);
+       int ret;
+
+       ret = copy_to_iotlb(vrh, dst, (void *)src, num * sizeof(*dst));
+       if (ret != size)
+               return -EFAULT;
+
+       return 0;
+}
+
+/**
+ * vringh_init_iotlb - initialize a vringh for a ring with IOTLB.
+ * @vrh: the vringh to initialize.
+ * @features: the feature bits for this ring.
+ * @num: the number of elements.
+ * @weak_barriers: true if we only need memory barriers, not I/O.
+ * @desc: the userpace descriptor pointer.
+ * @avail: the userpace avail pointer.
+ * @used: the userpace used pointer.
+ *
+ * Returns an error if num is invalid.
+ */
+int vringh_init_iotlb(struct vringh *vrh, u64 features,
+                     unsigned int num, bool weak_barriers,
+                     struct vring_desc *desc,
+                     struct vring_avail *avail,
+                     struct vring_used *used)
+{
+       return vringh_init_kern(vrh, features, num, weak_barriers,
+                               desc, avail, used);
+}
+EXPORT_SYMBOL(vringh_init_iotlb);
+
+/**
+ * vringh_set_iotlb - initialize a vringh for a ring with IOTLB.
+ * @vrh: the vring
+ * @iotlb: iotlb associated with this vring
+ */
+void vringh_set_iotlb(struct vringh *vrh, struct vhost_iotlb *iotlb)
+{
+       vrh->iotlb = iotlb;
+}
+EXPORT_SYMBOL(vringh_set_iotlb);
+
+/**
+ * vringh_getdesc_iotlb - get next available descriptor from ring with
+ * IOTLB.
+ * @vrh: the kernelspace vring.
+ * @riov: where to put the readable descriptors (or NULL)
+ * @wiov: where to put the writable descriptors (or NULL)
+ * @head: head index we received, for passing to vringh_complete_iotlb().
+ * @gfp: flags for allocating larger riov/wiov.
+ *
+ * Returns 0 if there was no descriptor, 1 if there was, or -errno.
+ *
+ * Note that on error return, you can tell the difference between an
+ * invalid ring and a single invalid descriptor: in the former case,
+ * *head will be vrh->vring.num.  You may be able to ignore an invalid
+ * descriptor, but there's not much you can do with an invalid ring.
+ *
+ * Note that you may need to clean up riov and wiov, even on error!
+ */
+int vringh_getdesc_iotlb(struct vringh *vrh,
+                        struct vringh_kiov *riov,
+                        struct vringh_kiov *wiov,
+                        u16 *head,
+                        gfp_t gfp)
+{
+       int err;
+
+       err = __vringh_get_head(vrh, getu16_iotlb, &vrh->last_avail_idx);
+       if (err < 0)
+               return err;
+
+       /* Empty... */
+       if (err == vrh->vring.num)
+               return 0;
+
+       *head = err;
+       err = __vringh_iov(vrh, *head, riov, wiov, no_range_check, NULL,
+                          gfp, copydesc_iotlb);
+       if (err)
+               return err;
+
+       return 1;
+}
+EXPORT_SYMBOL(vringh_getdesc_iotlb);
+
+/**
+ * vringh_iov_pull_iotlb - copy bytes from vring_iov.
+ * @vrh: the vring.
+ * @riov: the riov as passed to vringh_getdesc_iotlb() (updated as we consume)
+ * @dst: the place to copy.
+ * @len: the maximum length to copy.
+ *
+ * Returns the bytes copied <= len or a negative errno.
+ */
+ssize_t vringh_iov_pull_iotlb(struct vringh *vrh,
+                             struct vringh_kiov *riov,
+                             void *dst, size_t len)
+{
+       return vringh_iov_xfer(vrh, riov, dst, len, xfer_from_iotlb);
+}
+EXPORT_SYMBOL(vringh_iov_pull_iotlb);
+
+/**
+ * vringh_iov_push_iotlb - copy bytes into vring_iov.
+ * @vrh: the vring.
+ * @wiov: the wiov as passed to vringh_getdesc_iotlb() (updated as we consume)
+ * @dst: the place to copy.
+ * @len: the maximum length to copy.
+ *
+ * Returns the bytes copied <= len or a negative errno.
+ */
+ssize_t vringh_iov_push_iotlb(struct vringh *vrh,
+                             struct vringh_kiov *wiov,
+                             const void *src, size_t len)
+{
+       return vringh_iov_xfer(vrh, wiov, (void *)src, len, xfer_to_iotlb);
+}
+EXPORT_SYMBOL(vringh_iov_push_iotlb);
+
+/**
+ * vringh_abandon_iotlb - we've decided not to handle the descriptor(s).
+ * @vrh: the vring.
+ * @num: the number of descriptors to put back (ie. num
+ *      vringh_get_iotlb() to undo).
+ *
+ * The next vringh_get_iotlb() will return the old descriptor(s) again.
+ */
+void vringh_abandon_iotlb(struct vringh *vrh, unsigned int num)
+{
+       /* We only update vring_avail_event(vr) when we want to be notified,
+        * so we haven't changed that yet.
+        */
+       vrh->last_avail_idx -= num;
+}
+EXPORT_SYMBOL(vringh_abandon_iotlb);
+
+/**
+ * vringh_complete_iotlb - we've finished with descriptor, publish it.
+ * @vrh: the vring.
+ * @head: the head as filled in by vringh_getdesc_iotlb.
+ * @len: the length of data we have written.
+ *
+ * You should check vringh_need_notify_iotlb() after one or more calls
+ * to this function.
+ */
+int vringh_complete_iotlb(struct vringh *vrh, u16 head, u32 len)
+{
+       struct vring_used_elem used;
+
+       used.id = cpu_to_vringh32(vrh, head);
+       used.len = cpu_to_vringh32(vrh, len);
+
+       return __vringh_complete(vrh, &used, 1, putu16_iotlb, putused_iotlb);
+}
+EXPORT_SYMBOL(vringh_complete_iotlb);
+
+/**
+ * vringh_notify_enable_iotlb - we want to know if something changes.
+ * @vrh: the vring.
+ *
+ * This always enables notifications, but returns false if there are
+ * now more buffers available in the vring.
+ */
+bool vringh_notify_enable_iotlb(struct vringh *vrh)
+{
+       return __vringh_notify_enable(vrh, getu16_iotlb, putu16_iotlb);
+}
+EXPORT_SYMBOL(vringh_notify_enable_iotlb);
+
+/**
+ * vringh_notify_disable_iotlb - don't tell us if something changes.
+ * @vrh: the vring.
+ *
+ * This is our normal running state: we disable and then only enable when
+ * we're going to sleep.
+ */
+void vringh_notify_disable_iotlb(struct vringh *vrh)
+{
+       __vringh_notify_disable(vrh, putu16_iotlb);
+}
+EXPORT_SYMBOL(vringh_notify_disable_iotlb);
+
+/**
+ * vringh_need_notify_iotlb - must we tell the other side about used buffers?
+ * @vrh: the vring we've called vringh_complete_iotlb() on.
+ *
+ * Returns -errno or 0 if we don't need to tell the other side, 1 if we do.
+ */
+int vringh_need_notify_iotlb(struct vringh *vrh)
+{
+       return __vringh_need_notify(vrh, getu16_iotlb);
+}
+EXPORT_SYMBOL(vringh_need_notify_iotlb);
+
+
 MODULE_LICENSE("GPL");