#include <asm/types.h>
 
 #include <linux/nbd.h>
+#include <linux/nbd-netlink.h>
+#include <net/genetlink.h>
 
 static DEFINE_IDR(nbd_index_idr);
 static DEFINE_MUTEX(nbd_index_mutex);
 #define NBD_DISCONNECT_REQUESTED       1
 #define NBD_DISCONNECTED               2
 #define NBD_HAS_PID_FILE               3
+#define NBD_HAS_CONFIG_REF             4
+#define NBD_BOUND                      5
 
 struct nbd_config {
        u32 flags;
 struct nbd_device {
        struct blk_mq_tag_set tag_set;
 
+       int index;
        refcount_t config_refs;
        struct nbd_config *config;
        struct mutex config_lock;
 static int nbd_dev_dbg_init(struct nbd_device *nbd);
 static void nbd_dev_dbg_close(struct nbd_device *nbd);
 static void nbd_config_put(struct nbd_device *nbd);
+static void nbd_connect_reply(struct genl_info *info, int index);
 
 static inline struct device *nbd_to_dev(struct nbd_device *nbd)
 {
        return ret;
 }
 
-static int nbd_add_socket(struct nbd_device *nbd, unsigned long arg)
+static int nbd_add_socket(struct nbd_device *nbd, unsigned long arg,
+                         bool netlink)
 {
        struct nbd_config *config = nbd->config;
        struct socket *sock;
        if (!sock)
                return err;
 
-       if (!nbd->task_setup)
+       if (!netlink && !nbd->task_setup &&
+           !test_bit(NBD_BOUND, &config->runtime_flags))
                nbd->task_setup = current;
-       if (nbd->task_setup != current) {
+
+       if (!netlink &&
+           (nbd->task_setup != current ||
+            test_bit(NBD_BOUND, &config->runtime_flags))) {
                dev_err(disk_to_dev(nbd->disk),
                        "Device being setup by another task");
                sockfd_put(sock);
-               return -EINVAL;
+               return -EBUSY;
        }
 
        socks = krealloc(config->socks, (config->num_connections + 1) *
        }
 }
 
-static int nbd_start_device(struct nbd_device *nbd, struct block_device *bdev)
+static int nbd_start_device(struct nbd_device *nbd)
 {
        struct nbd_config *config = nbd->config;
        int num_connections = config->num_connections;
                return -EINVAL;
        }
 
-       if (max_part)
-               bdev->bd_invalidated = 1;
        blk_mq_update_nr_hw_queues(&nbd->tag_set, config->num_connections);
        nbd->task_recv = current;
-       mutex_unlock(&nbd->config_lock);
 
        nbd_parse_flags(nbd);
 
                dev_err(disk_to_dev(nbd->disk), "device_create_file failed!\n");
                return error;
        }
-
        set_bit(NBD_HAS_PID_FILE, &config->runtime_flags);
-       if (max_part)
-               bdev->bd_invalidated = 1;
-       bd_set_size(bdev, config->bytesize);
 
        nbd_dev_dbg_init(nbd);
        for (i = 0; i < num_connections; i++) {
                args->index = i;
                queue_work(recv_workqueue, &args->work);
        }
-       error = wait_event_interruptible(config->recv_wq,
+       return error;
+}
+
+static int nbd_start_device_ioctl(struct nbd_device *nbd, struct block_device *bdev)
+{
+       struct nbd_config *config = nbd->config;
+       int ret;
+
+       ret = nbd_start_device(nbd);
+       if (ret)
+               return ret;
+
+       bd_set_size(bdev, config->bytesize);
+       if (max_part)
+               bdev->bd_invalidated = 1;
+       mutex_unlock(&nbd->config_lock);
+       ret = wait_event_interruptible(config->recv_wq,
                                         atomic_read(&config->recv_threads) == 0);
-       if (error)
+       if (ret)
                sock_shutdown(nbd);
        mutex_lock(&nbd->config_lock);
-
+       bd_set_size(bdev, 0);
        /* user requested, ignore socket errors */
        if (test_bit(NBD_DISCONNECT_REQUESTED, &config->runtime_flags))
-               error = 0;
+               ret = 0;
        if (test_bit(NBD_TIMEDOUT, &config->runtime_flags))
-               error = -ETIMEDOUT;
-       return error;
+               ret = -ETIMEDOUT;
+       return ret;
 }
 
 static void nbd_clear_sock_ioctl(struct nbd_device *nbd,
        nbd_clear_sock(nbd);
        kill_bdev(bdev);
        nbd_bdev_reset(bdev);
+       if (test_and_clear_bit(NBD_HAS_CONFIG_REF,
+                              &nbd->config->runtime_flags))
+               nbd_config_put(nbd);
 }
 
 /* Must be called with config_lock held */
                nbd_clear_sock_ioctl(nbd, bdev);
                return 0;
        case NBD_SET_SOCK:
-               return nbd_add_socket(nbd, arg);
+               return nbd_add_socket(nbd, arg, false);
        case NBD_SET_BLKSIZE:
                nbd_size_set(nbd, arg,
                             div_s64(config->bytesize, arg));
                config->flags = arg;
                return 0;
        case NBD_DO_IT:
-               return nbd_start_device(nbd, bdev);
+               return nbd_start_device_ioctl(nbd, bdev);
        case NBD_CLEAR_QUE:
                /*
                 * This is for compatibility only.  The queue is always cleared
                     unsigned int cmd, unsigned long arg)
 {
        struct nbd_device *nbd = bdev->bd_disk->private_data;
-       int error;
+       struct nbd_config *config = nbd->config;
+       int error = -EINVAL;
 
        if (!capable(CAP_SYS_ADMIN))
                return -EPERM;
 
        mutex_lock(&nbd->config_lock);
-       error = __nbd_ioctl(bdev, nbd, cmd, arg);
+
+       /* Don't allow ioctl operations on a nbd device that was created with
+        * netlink, unless it's DISCONNECT or CLEAR_SOCK, which are fine.
+        */
+       if (!test_bit(NBD_BOUND, &config->runtime_flags) ||
+           (cmd == NBD_DISCONNECT || cmd == NBD_CLEAR_SOCK))
+               error = __nbd_ioctl(bdev, nbd, cmd, arg);
+       else
+               dev_err(nbd_to_dev(nbd), "Cannot use ioctl interface on a netlink controlled device.\n");
        mutex_unlock(&nbd->config_lock);
        return error;
 }
        if (err < 0)
                goto out_free_disk;
 
+       nbd->index = index;
        nbd->disk = disk;
        nbd->tag_set.ops = &nbd_mq_ops;
        nbd->tag_set.nr_hw_queues = 1;
        return err;
 }
 
-/*
- * And here should be modules and kernel interface 
- *  (Just smiley confuses emacs :-)
- */
+static int find_free_cb(int id, void *ptr, void *data)
+{
+       struct nbd_device *nbd = ptr;
+       struct nbd_device **found = data;
+
+       if (!refcount_read(&nbd->config_refs)) {
+               *found = nbd;
+               return 1;
+       }
+       return 0;
+}
+
+/* Netlink interface. */
+static struct nla_policy nbd_attr_policy[NBD_ATTR_MAX + 1] = {
+       [NBD_ATTR_INDEX]                =       { .type = NLA_U32 },
+       [NBD_ATTR_SIZE_BYTES]           =       { .type = NLA_U64 },
+       [NBD_ATTR_BLOCK_SIZE_BYTES]     =       { .type = NLA_U64 },
+       [NBD_ATTR_TIMEOUT]              =       { .type = NLA_U64 },
+       [NBD_ATTR_SERVER_FLAGS]         =       { .type = NLA_U64 },
+       [NBD_ATTR_CLIENT_FLAGS]         =       { .type = NLA_U64 },
+       [NBD_ATTR_SOCKETS]              =       { .type = NLA_NESTED},
+};
+
+static struct nla_policy nbd_sock_policy[NBD_SOCK_MAX + 1] = {
+       [NBD_SOCK_FD]                   =       { .type = NLA_U32 },
+};
+
+static int nbd_genl_connect(struct sk_buff *skb, struct genl_info *info)
+{
+       struct nbd_device *nbd = NULL;
+       struct nbd_config *config;
+       int index = -1;
+       int ret;
+
+       if (!netlink_capable(skb, CAP_SYS_ADMIN))
+               return -EPERM;
+
+       if (info->attrs[NBD_ATTR_INDEX])
+               index = nla_get_u32(info->attrs[NBD_ATTR_INDEX]);
+       if (!info->attrs[NBD_ATTR_SOCKETS]) {
+               printk(KERN_ERR "nbd: must specify at least one socket\n");
+               return -EINVAL;
+       }
+       if (!info->attrs[NBD_ATTR_SIZE_BYTES]) {
+               printk(KERN_ERR "nbd: must specify a size in bytes for the device\n");
+               return -EINVAL;
+       }
+again:
+       mutex_lock(&nbd_index_mutex);
+       if (index == -1) {
+               ret = idr_for_each(&nbd_index_idr, &find_free_cb, &nbd);
+               if (ret == 0) {
+                       int new_index;
+                       new_index = nbd_dev_add(-1);
+                       if (new_index < 0) {
+                               mutex_unlock(&nbd_index_mutex);
+                               printk(KERN_ERR "nbd: failed to add new device\n");
+                               return ret;
+                       }
+                       nbd = idr_find(&nbd_index_idr, new_index);
+               }
+       } else {
+               nbd = idr_find(&nbd_index_idr, index);
+       }
+       mutex_unlock(&nbd_index_mutex);
+       if (!nbd) {
+               printk(KERN_ERR "nbd: couldn't find device at index %d\n",
+                      index);
+               return -EINVAL;
+       }
+
+       mutex_lock(&nbd->config_lock);
+       if (refcount_read(&nbd->config_refs)) {
+               mutex_unlock(&nbd->config_lock);
+               if (index == -1)
+                       goto again;
+               printk(KERN_ERR "nbd: nbd%d already in use\n", index);
+               return -EBUSY;
+       }
+       if (WARN_ON(nbd->config)) {
+               mutex_unlock(&nbd->config_lock);
+               return -EINVAL;
+       }
+       config = nbd->config = nbd_alloc_config();
+       if (!nbd->config) {
+               mutex_unlock(&nbd->config_lock);
+               printk(KERN_ERR "nbd: couldn't allocate config\n");
+               return -ENOMEM;
+       }
+       refcount_set(&nbd->config_refs, 1);
+       set_bit(NBD_BOUND, &config->runtime_flags);
+
+       if (info->attrs[NBD_ATTR_SIZE_BYTES]) {
+               u64 bytes = nla_get_u64(info->attrs[NBD_ATTR_SIZE_BYTES]);
+               nbd_size_set(nbd, config->blksize,
+                            div64_u64(bytes, config->blksize));
+       }
+       if (info->attrs[NBD_ATTR_BLOCK_SIZE_BYTES]) {
+               u64 bsize =
+                       nla_get_u64(info->attrs[NBD_ATTR_BLOCK_SIZE_BYTES]);
+               nbd_size_set(nbd, bsize, div64_u64(config->bytesize, bsize));
+       }
+       if (info->attrs[NBD_ATTR_TIMEOUT]) {
+               u64 timeout = nla_get_u64(info->attrs[NBD_ATTR_TIMEOUT]);
+               nbd->tag_set.timeout = timeout * HZ;
+               blk_queue_rq_timeout(nbd->disk->queue, timeout * HZ);
+       }
+       if (info->attrs[NBD_ATTR_SERVER_FLAGS])
+               config->flags =
+                       nla_get_u64(info->attrs[NBD_ATTR_SERVER_FLAGS]);
+       if (info->attrs[NBD_ATTR_SOCKETS]) {
+               struct nlattr *attr;
+               int rem, fd;
+
+               nla_for_each_nested(attr, info->attrs[NBD_ATTR_SOCKETS],
+                                   rem) {
+                       struct nlattr *socks[NBD_SOCK_MAX+1];
+
+                       if (nla_type(attr) != NBD_SOCK_ITEM) {
+                               printk(KERN_ERR "nbd: socks must be embedded in a SOCK_ITEM attr\n");
+                               ret = -EINVAL;
+                               goto out;
+                       }
+                       ret = nla_parse_nested(socks, NBD_SOCK_MAX, attr,
+                                              nbd_sock_policy);
+                       if (ret != 0) {
+                               printk(KERN_ERR "nbd: error processing sock list\n");
+                               ret = -EINVAL;
+                               goto out;
+                       }
+                       if (!socks[NBD_SOCK_FD])
+                               continue;
+                       fd = (int)nla_get_u32(socks[NBD_SOCK_FD]);
+                       ret = nbd_add_socket(nbd, fd, true);
+                       if (ret)
+                               goto out;
+               }
+       }
+       ret = nbd_start_device(nbd);
+out:
+       mutex_unlock(&nbd->config_lock);
+       if (!ret) {
+               set_bit(NBD_HAS_CONFIG_REF, &config->runtime_flags);
+               refcount_inc(&nbd->config_refs);
+               nbd_connect_reply(info, nbd->index);
+       }
+       nbd_config_put(nbd);
+       return ret;
+}
+
+static int nbd_genl_disconnect(struct sk_buff *skb, struct genl_info *info)
+{
+       struct nbd_device *nbd;
+       int index;
+
+       if (!netlink_capable(skb, CAP_SYS_ADMIN))
+               return -EPERM;
+
+       if (!info->attrs[NBD_ATTR_INDEX]) {
+               printk(KERN_ERR "nbd: must specify an index to disconnect\n");
+               return -EINVAL;
+       }
+       index = nla_get_u32(info->attrs[NBD_ATTR_INDEX]);
+       mutex_lock(&nbd_index_mutex);
+       nbd = idr_find(&nbd_index_idr, index);
+       mutex_unlock(&nbd_index_mutex);
+       if (!nbd) {
+               printk(KERN_ERR "nbd: couldn't find device at index %d\n",
+                      index);
+               return -EINVAL;
+       }
+       if (!refcount_inc_not_zero(&nbd->config_refs))
+               return 0;
+       mutex_lock(&nbd->config_lock);
+       nbd_disconnect(nbd);
+       mutex_unlock(&nbd->config_lock);
+       if (test_and_clear_bit(NBD_HAS_CONFIG_REF,
+                              &nbd->config->runtime_flags))
+               nbd_config_put(nbd);
+       nbd_config_put(nbd);
+       return 0;
+}
+
+static const struct genl_ops nbd_connect_genl_ops[] = {
+       {
+               .cmd    = NBD_CMD_CONNECT,
+               .policy = nbd_attr_policy,
+               .doit   = nbd_genl_connect,
+       },
+       {
+               .cmd    = NBD_CMD_DISCONNECT,
+               .policy = nbd_attr_policy,
+               .doit   = nbd_genl_disconnect,
+       },
+};
+
+static struct genl_family nbd_genl_family __ro_after_init = {
+       .hdrsize        = 0,
+       .name           = NBD_GENL_FAMILY_NAME,
+       .version        = NBD_GENL_VERSION,
+       .module         = THIS_MODULE,
+       .ops            = nbd_connect_genl_ops,
+       .n_ops          = ARRAY_SIZE(nbd_connect_genl_ops),
+       .maxattr        = NBD_ATTR_MAX,
+};
+
+static void nbd_connect_reply(struct genl_info *info, int index)
+{
+       struct sk_buff *skb;
+       void *msg_head;
+       int ret;
+
+       skb = genlmsg_new(nla_total_size(sizeof(u32)), GFP_KERNEL);
+       if (!skb)
+               return;
+       msg_head = genlmsg_put_reply(skb, info, &nbd_genl_family, 0,
+                                    NBD_CMD_CONNECT);
+       if (!msg_head) {
+               nlmsg_free(skb);
+               return;
+       }
+       ret = nla_put_u32(skb, NBD_ATTR_INDEX, index);
+       if (ret) {
+               nlmsg_free(skb);
+               return;
+       }
+       genlmsg_end(skb, msg_head);
+       genlmsg_reply(skb, info);
+}
 
 static int __init nbd_init(void)
 {
                return -EIO;
        }
 
+       if (genl_register_family(&nbd_genl_family)) {
+               unregister_blkdev(NBD_MAJOR, "nbd");
+               destroy_workqueue(recv_workqueue);
+               return -EINVAL;
+       }
        nbd_dbg_init();
 
        mutex_lock(&nbd_index_mutex);
 
        idr_for_each(&nbd_index_idr, &nbd_exit_cb, NULL);
        idr_destroy(&nbd_index_idr);
+       genl_unregister_family(&nbd_genl_family);
        destroy_workqueue(recv_workqueue);
        unregister_blkdev(NBD_MAJOR, "nbd");
 }