/* Per driver */
 struct blkvsc_driver_context {
-       /* !! These must be the first 2 fields !! */
-       /* FIXME this is a bug! */
-       struct driver_context drv_ctx;
        struct storvsc_driver_object drv_obj;
 };
 
 static int blkvsc_drv_init(int (*drv_init)(struct hv_driver *drv))
 {
        struct storvsc_driver_object *storvsc_drv_obj = &g_blkvsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_blkvsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_blkvsc_drv.drv_obj.base;
        int ret;
 
        storvsc_drv_obj->ring_buffer_size = blkvsc_ringbuffer_size;
 
+       drv->priv = storvsc_drv_obj;
+
        /* Callback to client driver to complete the initialization */
        drv_init(&storvsc_drv_obj->base);
 
-       drv_ctx->driver.name = storvsc_drv_obj->base.name;
-       memcpy(&drv_ctx->class_id, &storvsc_drv_obj->base.dev_type,
-              sizeof(struct hv_guid));
+       drv->driver.name = storvsc_drv_obj->base.name;
 
-       drv_ctx->driver.probe = blkvsc_probe;
-       drv_ctx->driver.remove = blkvsc_remove;
-       drv_ctx->driver.shutdown = blkvsc_shutdown;
+       drv->driver.probe = blkvsc_probe;
+       drv->driver.remove = blkvsc_remove;
+       drv->driver.shutdown = blkvsc_shutdown;
 
        /* The driver belongs to vmbus */
-       ret = vmbus_child_driver_register(&drv_ctx->driver);
+       ret = vmbus_child_driver_register(&drv->driver);
 
        return ret;
 }
 static void blkvsc_drv_exit(void)
 {
        struct storvsc_driver_object *storvsc_drv_obj = &g_blkvsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_blkvsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_blkvsc_drv.drv_obj.base;
        struct device *current_dev;
        int ret;
 
                current_dev = NULL;
 
                /* Get the device */
-               ret = driver_for_each_device(&drv_ctx->driver, NULL,
+               ret = driver_for_each_device(&drv->driver, NULL,
                                             (void *) ¤t_dev,
                                             blkvsc_drv_exit_cb);
 
        if (storvsc_drv_obj->base.cleanup)
                storvsc_drv_obj->base.cleanup(&storvsc_drv_obj->base);
 
-       vmbus_child_driver_unregister(&drv_ctx->driver);
+       vmbus_child_driver_unregister(&drv->driver);
 
        return;
 }
  */
 static int blkvsc_probe(struct device *device)
 {
-       struct driver_context *driver_ctx =
-                               driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+                               drv_to_hv_drv(device->driver);
        struct blkvsc_driver_context *blkvsc_drv_ctx =
-                               (struct blkvsc_driver_context *)driver_ctx;
+                               (struct blkvsc_driver_context *)drv->priv;
        struct storvsc_driver_object *storvsc_drv_obj =
                                &blkvsc_drv_ctx->drv_obj;
        struct vm_device *device_ctx = device_to_vm_device(device);
  */
 static int blkvsc_remove(struct device *device)
 {
-       struct driver_context *driver_ctx =
-                               driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+                               drv_to_hv_drv(device->driver);
        struct blkvsc_driver_context *blkvsc_drv_ctx =
-                               (struct blkvsc_driver_context *)driver_ctx;
+                               (struct blkvsc_driver_context *)drv->priv;
        struct storvsc_driver_object *storvsc_drv_obj =
                                &blkvsc_drv_ctx->drv_obj;
        struct vm_device *device_ctx = device_to_vm_device(device);
 {
        struct block_device_context *blkdev = blkvsc_req->dev;
        struct vm_device *device_ctx = blkdev->device_ctx;
-       struct driver_context *driver_ctx =
-                       driver_to_driver_context(device_ctx->device.driver);
+       struct hv_driver *drv =
+                       drv_to_hv_drv(device_ctx->device.driver);
        struct blkvsc_driver_context *blkvsc_drv_ctx =
-                       (struct blkvsc_driver_context *)driver_ctx;
+                       (struct blkvsc_driver_context *)drv->priv;
        struct storvsc_driver_object *storvsc_drv_obj =
                        &blkvsc_drv_ctx->drv_obj;
        struct hv_storvsc_request *storvsc_req;
 
 };
 
 struct mousevsc_driver_context {
-       struct driver_context   drv_ctx;
        struct mousevsc_drv_obj drv_obj;
 };
 
 {
        int ret = 0;
 
-       struct driver_context *driver_ctx =
-               driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+               drv_to_hv_drv(device->driver);
        struct mousevsc_driver_context *mousevsc_drv_ctx =
-               (struct mousevsc_driver_context *)driver_ctx;
+               (struct mousevsc_driver_context *)drv->priv;
        struct mousevsc_drv_obj *mousevsc_drv_obj = &mousevsc_drv_ctx->drv_obj;
 
        struct vm_device *device_ctx = device_to_vm_device(device);
 {
        int ret = 0;
 
-       struct driver_context *driver_ctx =
-               driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+               drv_to_hv_drv(device->driver);
        struct mousevsc_driver_context *mousevsc_drv_ctx =
-               (struct mousevsc_driver_context *)driver_ctx;
+               (struct mousevsc_driver_context *)drv->priv;
        struct mousevsc_drv_obj *mousevsc_drv_obj = &mousevsc_drv_ctx->drv_obj;
 
        struct vm_device *device_ctx = device_to_vm_device(device);
 static void mousevsc_drv_exit(void)
 {
        struct mousevsc_drv_obj *mousevsc_drv_obj = &g_mousevsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_mousevsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_mousevsc_drv.drv_obj.Base;
        int ret;
 
        struct device *current_dev = NULL;
                current_dev = NULL;
 
                /* Get the device */
-               ret = driver_for_each_device(&drv_ctx->driver, NULL,
+               ret = driver_for_each_device(&drv->driver, NULL,
                                             (void *)¤t_dev,
                                             mousevsc_drv_exit_cb);
                if (ret)
        if (mousevsc_drv_obj->Base.cleanup)
                mousevsc_drv_obj->Base.cleanup(&mousevsc_drv_obj->Base);
 
-       vmbus_child_driver_unregister(&drv_ctx->driver);
+       vmbus_child_driver_unregister(&drv->driver);
 
        return;
 }
 static int __init mousevsc_init(void)
 {
        struct mousevsc_drv_obj *input_drv_obj = &g_mousevsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_mousevsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_mousevsc_drv.drv_obj.Base;
 
        DPRINT_INFO(INPUTVSC_DRV, "Hyper-V Mouse driver initializing.");
 
        /* Callback to client driver to complete the initialization */
        mouse_vsc_initialize(&input_drv_obj->Base);
 
-       drv_ctx->driver.name = input_drv_obj->Base.name;
-       memcpy(&drv_ctx->class_id, &input_drv_obj->Base.dev_type,
-              sizeof(struct hv_guid));
+       drv->driver.name = input_drv_obj->Base.name;
+       drv->priv = input_drv_obj;
 
-       drv_ctx->driver.probe = mousevsc_probe;
-       drv_ctx->driver.remove = mousevsc_remove;
+       drv->driver.probe = mousevsc_probe;
+       drv->driver.remove = mousevsc_remove;
 
        /* The driver belongs to vmbus */
-       vmbus_child_driver_register(&drv_ctx->driver);
+       vmbus_child_driver_register(&drv->driver);
 
        return 0;
 }
 
 };
 
 struct netvsc_driver_context {
-       /* !! These must be the first 2 fields !! */
-       /* Which is a bug FIXME! */
-       struct driver_context drv_ctx;
        struct netvsc_driver drv_obj;
 };
 
 static int netvsc_start_xmit(struct sk_buff *skb, struct net_device *net)
 {
        struct net_device_context *net_device_ctx = netdev_priv(net);
-       struct driver_context *driver_ctx =
-           driver_to_driver_context(net_device_ctx->device_ctx->device.driver);
+       struct hv_driver *drv =
+           drv_to_hv_drv(net_device_ctx->device_ctx->device.driver);
        struct netvsc_driver_context *net_drv_ctx =
-               (struct netvsc_driver_context *)driver_ctx;
+               (struct netvsc_driver_context *)drv->priv;
        struct netvsc_driver *net_drv_obj = &net_drv_ctx->drv_obj;
        struct hv_netvsc_packet *packet;
        int ret;
 
 static int netvsc_probe(struct device *device)
 {
-       struct driver_context *driver_ctx =
-               driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+               drv_to_hv_drv(device->driver);
        struct netvsc_driver_context *net_drv_ctx =
-               (struct netvsc_driver_context *)driver_ctx;
+               (struct netvsc_driver_context *)drv->priv;
        struct netvsc_driver *net_drv_obj = &net_drv_ctx->drv_obj;
        struct vm_device *device_ctx = device_to_vm_device(device);
        struct hv_device *device_obj = &device_ctx->device_obj;
 
 static int netvsc_remove(struct device *device)
 {
-       struct driver_context *driver_ctx =
-               driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+               drv_to_hv_drv(device->driver);
        struct netvsc_driver_context *net_drv_ctx =
-               (struct netvsc_driver_context *)driver_ctx;
+               (struct netvsc_driver_context *)drv->priv;
        struct netvsc_driver *net_drv_obj = &net_drv_ctx->drv_obj;
        struct vm_device *device_ctx = device_to_vm_device(device);
        struct net_device *net = dev_get_drvdata(&device_ctx->device);
 static void netvsc_drv_exit(void)
 {
        struct netvsc_driver *netvsc_drv_obj = &g_netvsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_netvsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_netvsc_drv.drv_obj.base;
        struct device *current_dev;
        int ret;
 
                current_dev = NULL;
 
                /* Get the device */
-               ret = driver_for_each_device(&drv_ctx->driver, NULL,
+               ret = driver_for_each_device(&drv->driver, NULL,
                                             ¤t_dev, netvsc_drv_exit_cb);
                if (ret)
                        DPRINT_WARN(NETVSC_DRV,
        if (netvsc_drv_obj->base.cleanup)
                netvsc_drv_obj->base.cleanup(&netvsc_drv_obj->base);
 
-       vmbus_child_driver_unregister(&drv_ctx->driver);
+       vmbus_child_driver_unregister(&drv->driver);
 
        return;
 }
 static int netvsc_drv_init(int (*drv_init)(struct hv_driver *drv))
 {
        struct netvsc_driver *net_drv_obj = &g_netvsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_netvsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_netvsc_drv.drv_obj.base;
        int ret;
 
        net_drv_obj->ring_buf_size = ring_size * PAGE_SIZE;
        net_drv_obj->recv_cb = netvsc_recv_callback;
        net_drv_obj->link_status_change = netvsc_linkstatus_callback;
+       drv->priv = net_drv_obj;
 
        /* Callback to client driver to complete the initialization */
        drv_init(&net_drv_obj->base);
 
-       drv_ctx->driver.name = net_drv_obj->base.name;
-       memcpy(&drv_ctx->class_id, &net_drv_obj->base.dev_type,
-              sizeof(struct hv_guid));
+       drv->driver.name = net_drv_obj->base.name;
 
-       drv_ctx->driver.probe = netvsc_probe;
-       drv_ctx->driver.remove = netvsc_remove;
+       drv->driver.probe = netvsc_probe;
+       drv->driver.remove = netvsc_remove;
 
        /* The driver belongs to vmbus */
-       ret = vmbus_child_driver_register(&drv_ctx->driver);
+       ret = vmbus_child_driver_register(&drv->driver);
 
        return ret;
 }
 
 };
 
 struct storvsc_driver_context {
-       /* !! These must be the first 2 fields !! */
-       /* FIXME this is a bug... */
-       struct driver_context drv_ctx;
        struct storvsc_driver_object drv_obj;
 };
 
 {
        int ret;
        struct storvsc_driver_object *storvsc_drv_obj = &g_storvsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_storvsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_storvsc_drv.drv_obj.base;
 
        storvsc_drv_obj->ring_buffer_size = storvsc_ringbuffer_size;
 
        /* Callback to client driver to complete the initialization */
        drv_init(&storvsc_drv_obj->base);
 
+       drv->priv = storvsc_drv_obj;
+
        DPRINT_INFO(STORVSC_DRV,
                    "request extension size %u, max outstanding reqs %u",
                    storvsc_drv_obj->request_ext_size,
                return -1;
        }
 
-       drv_ctx->driver.name = storvsc_drv_obj->base.name;
-       memcpy(&drv_ctx->class_id, &storvsc_drv_obj->base.dev_type,
-              sizeof(struct hv_guid));
+       drv->driver.name = storvsc_drv_obj->base.name;
 
-       drv_ctx->driver.probe = storvsc_probe;
-       drv_ctx->driver.remove = storvsc_remove;
+       drv->driver.probe = storvsc_probe;
+       drv->driver.remove = storvsc_remove;
 
        /* The driver belongs to vmbus */
-       ret = vmbus_child_driver_register(&drv_ctx->driver);
+       ret = vmbus_child_driver_register(&drv->driver);
 
        return ret;
 }
 static void storvsc_drv_exit(void)
 {
        struct storvsc_driver_object *storvsc_drv_obj = &g_storvsc_drv.drv_obj;
-       struct driver_context *drv_ctx = &g_storvsc_drv.drv_ctx;
+       struct hv_driver *drv = &g_storvsc_drv.drv_obj.base;
        struct device *current_dev = NULL;
        int ret;
 
                current_dev = NULL;
 
                /* Get the device */
-               ret = driver_for_each_device(&drv_ctx->driver, NULL,
+               ret = driver_for_each_device(&drv->driver, NULL,
                                             (void *) ¤t_dev,
                                             storvsc_drv_exit_cb);
 
        if (storvsc_drv_obj->base.cleanup)
                storvsc_drv_obj->base.cleanup(&storvsc_drv_obj->base);
 
-       vmbus_child_driver_unregister(&drv_ctx->driver);
+       vmbus_child_driver_unregister(&drv->driver);
        return;
 }
 
 static int storvsc_probe(struct device *device)
 {
        int ret;
-       struct driver_context *driver_ctx =
-                               driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+                               drv_to_hv_drv(device->driver);
        struct storvsc_driver_context *storvsc_drv_ctx =
-                               (struct storvsc_driver_context *)driver_ctx;
+                               (struct storvsc_driver_context *)drv->priv;
        struct storvsc_driver_object *storvsc_drv_obj =
                                &storvsc_drv_ctx->drv_obj;
        struct vm_device *device_ctx = device_to_vm_device(device);
 static int storvsc_remove(struct device *device)
 {
        int ret;
-       struct driver_context *driver_ctx =
-                       driver_to_driver_context(device->driver);
+       struct hv_driver *drv =
+                       drv_to_hv_drv(device->driver);
        struct storvsc_driver_context *storvsc_drv_ctx =
-                       (struct storvsc_driver_context *)driver_ctx;
+                       (struct storvsc_driver_context *)drv->priv;
        struct storvsc_driver_object *storvsc_drv_obj =
                        &storvsc_drv_ctx->drv_obj;
        struct vm_device *device_ctx = device_to_vm_device(device);
        struct host_device_context *host_device_ctx =
                (struct host_device_context *)scmnd->device->host->hostdata;
        struct vm_device *device_ctx = host_device_ctx->device_ctx;
-       struct driver_context *driver_ctx =
-               driver_to_driver_context(device_ctx->device.driver);
+       struct hv_driver *drv =
+               drv_to_hv_drv(device_ctx->device.driver);
        struct storvsc_driver_context *storvsc_drv_ctx =
-               (struct storvsc_driver_context *)driver_ctx;
+               (struct storvsc_driver_context *)drv->priv;
        struct storvsc_driver_object *storvsc_drv_obj =
                &storvsc_drv_ctx->drv_obj;
        struct hv_storvsc_request *request;
 
 #include <linux/device.h>
 #include "vmbus_api.h"
 
-struct driver_context {
-       struct hv_guid class_id;
-
-       struct device_driver driver;
-
-};
 
 struct vm_device {
        struct work_struct probe_failed_work_item;
        return container_of(d, struct vm_device, device);
 }
 
-static inline struct driver_context *driver_to_driver_context(struct device_driver *d)
+static inline struct hv_driver *drv_to_hv_drv(struct device_driver *d)
 {
-       return container_of(d, struct driver_context, driver);
+       return container_of(d, struct hv_driver, driver);
 }
 
 
 
 #ifndef _VMBUS_API_H_
 #define _VMBUS_API_H_
 
+#include <linux/device.h>
+
 #define MAX_PAGE_BUFFER_COUNT                          16
 #define MAX_MULTIPAGE_BUFFER_COUNT                     32 /* 128K */
 
        /* the device type supported by this driver */
        struct hv_guid dev_type;
 
+       /*
+        * Device type specific drivers (net, blk etc.)
+        * need a mechanism to get a pointer to
+        * device type specific driver structure given
+        * a pointer to the base hyperv driver structure.
+        * The current code solves this problem using
+        * a hack. Support this need explicitly
+        */
+       void *priv;
+
+       struct device_driver driver;
+
        int (*dev_add)(struct hv_device *device, void *data);
        int (*dev_rm)(struct hv_device *device);
        void (*cleanup)(struct hv_driver *driver);
 
 
 /* Main vmbus driver data structure */
 struct vmbus_driver_context {
-       /* !! These must be the first 2 fields !! */
-       /* FIXME, this is a bug */
-       /* The driver field is not used in here. Instead, the bus field is */
-       /* used to represent the driver */
-       struct driver_context drv_ctx;
        struct hv_driver drv_obj;
 
        struct bus_type bus;
 static int vmbus_match(struct device *device, struct device_driver *driver)
 {
        int match = 0;
-       struct driver_context *driver_ctx = driver_to_driver_context(driver);
+       struct hv_driver *drv = drv_to_hv_drv(driver);
        struct vm_device *device_ctx = device_to_vm_device(device);
 
        /* We found our driver ? */
-       if (memcmp(&device_ctx->class_id, &driver_ctx->class_id,
+       if (memcmp(&device_ctx->class_id, &drv->dev_type,
                   sizeof(struct hv_guid)) == 0) {
-               /*
-                * !! NOTE: The driver_ctx is not a vmbus_drv_ctx. We typecast
-                * it here to access the struct hv_driver field
-                */
-               struct vmbus_driver_context *vmbus_drv_ctx =
-                       (struct vmbus_driver_context *)driver_ctx;
 
-               device_ctx->device_obj.drv = &vmbus_drv_ctx->drv_obj;
+               device_ctx->device_obj.drv = drv->priv;
                DPRINT_INFO(VMBUS_DRV,
                            "device object (%p) set to driver object (%p)",
                            &device_ctx->device_obj,
 static int vmbus_probe(struct device *child_device)
 {
        int ret = 0;
-       struct driver_context *driver_ctx =
-                       driver_to_driver_context(child_device->driver);
+       struct hv_driver *drv =
+                       drv_to_hv_drv(child_device->driver);
        struct vm_device *device_ctx =
                        device_to_vm_device(child_device);
 
        /* Let the specific open-source driver handles the probe if it can */
-       if (driver_ctx->driver.probe) {
+       if (drv->driver.probe) {
                ret = device_ctx->probe_error =
-               driver_ctx->driver.probe(child_device);
+               drv->driver.probe(child_device);
                if (ret != 0) {
                        DPRINT_ERR(VMBUS_DRV, "probe() failed for device %s "
                                   "(%p) on driver %s (%d)...",
 static int vmbus_remove(struct device *child_device)
 {
        int ret;
-       struct driver_context *driver_ctx;
+       struct hv_driver *drv;
 
        /* Special case root bus device */
        if (child_device->parent == NULL) {
        }
 
        if (child_device->driver) {
-               driver_ctx = driver_to_driver_context(child_device->driver);
+               drv = drv_to_hv_drv(child_device->driver);
 
                /*
                 * Let the specific open-source driver handles the removal if
                 * it can
                 */
-               if (driver_ctx->driver.remove) {
-                       ret = driver_ctx->driver.remove(child_device);
+               if (drv->driver.remove) {
+                       ret = drv->driver.remove(child_device);
                } else {
                        DPRINT_ERR(VMBUS_DRV,
                                   "remove() method not set for driver - %s",
  */
 static void vmbus_shutdown(struct device *child_device)
 {
-       struct driver_context *driver_ctx;
+       struct hv_driver *drv;
 
        /* Special case root bus device */
        if (child_device->parent == NULL) {
        if (!child_device->driver)
                return;
 
-       driver_ctx = driver_to_driver_context(child_device->driver);
+       drv = drv_to_hv_drv(child_device->driver);
 
        /* Let the specific open-source driver handles the removal if it can */
-       if (driver_ctx->driver.shutdown)
-               driver_ctx->driver.shutdown(child_device);
+       if (drv->driver.shutdown)
+               drv->driver.shutdown(child_device);
 
        return;
 }