vhost_net_disable_vq(n, vq);
                rcu_assign_pointer(vq->private_data, sock);
                vhost_net_enable_vq(n, vq);
+
+               r = vhost_init_used(vq);
+               if (r)
+                       goto err_vq;
        }
 
        mutex_unlock(&vq->mutex);
 
                                                    lockdep_is_held(&vq->mutex));
                rcu_assign_pointer(vq->private_data, priv);
 
+               r = vhost_init_used(&n->vqs[index]);
+
                mutex_unlock(&vq->mutex);
 
+               if (r)
+                       goto err;
+
                if (oldpriv) {
                        vhost_test_flush_vq(n, index);
                }
 
        return 0;
 }
 
-static int init_used(struct vhost_virtqueue *vq,
-                    struct vring_used __user *used)
+int vhost_init_used(struct vhost_virtqueue *vq)
 {
-       int r = put_user(vq->used_flags, &used->flags);
+       int r;
+       if (!vq->private_data)
+               return 0;
 
+       r = put_user(vq->used_flags, &vq->used->flags);
        if (r)
                return r;
        vq->signalled_used_valid = false;
-       return get_user(vq->last_used_idx, &used->idx);
+       return get_user(vq->last_used_idx, &vq->used->idx);
 }
 
 static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        }
                }
 
-               r = init_used(vq, (struct vring_used __user *)(unsigned long)
-                             a.used_user_addr);
-               if (r)
-                       break;
                vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
                vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
                vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
 
                      struct vhost_log *log, unsigned int *log_num);
 void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
 
+int vhost_init_used(struct vhost_virtqueue *);
 int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len);
 int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads,
                     unsigned count);