void __iomem *metric_tbl_addr;
        struct semaphore hsmp_sem;
        char name[HSMP_ATTR_GRP_NAME_SIZE];
+       struct pci_dev *root;
        u16 sock_ind;
 };
 
 
 static struct hsmp_plat_device plat_dev;
 
-static int amd_hsmp_rdwr(struct pci_dev *root, u32 address,
+static int amd_hsmp_rdwr(struct hsmp_socket *sock, u32 address,
                         u32 *value, bool write)
 {
        int ret;
 
-       ret = pci_write_config_dword(root, HSMP_INDEX_REG, address);
+       if (!sock->root)
+               return -ENODEV;
+
+       ret = pci_write_config_dword(sock->root, HSMP_INDEX_REG, address);
        if (ret)
                return ret;
 
-       ret = (write ? pci_write_config_dword(root, HSMP_DATA_REG, *value)
-                    : pci_read_config_dword(root, HSMP_DATA_REG, value));
+       ret = (write ? pci_write_config_dword(sock->root, HSMP_DATA_REG, *value)
+                    : pci_read_config_dword(sock->root, HSMP_DATA_REG, value));
 
        return ret;
 }
  * Returns 0 for success and populates the requested number of arguments.
  * Returns a negative error code for failure.
  */
-static int __hsmp_send_message(struct pci_dev *root, struct hsmp_message *msg)
+static int __hsmp_send_message(struct hsmp_socket *sock, struct hsmp_message *msg)
 {
        unsigned long timeout, short_sleep;
        u32 mbox_status;
 
        /* Clear the status register */
        mbox_status = HSMP_STATUS_NOT_READY;
-       ret = amd_hsmp_rdwr(root, SMN_HSMP_MSG_RESP, &mbox_status, HSMP_WR);
+       ret = amd_hsmp_rdwr(sock, SMN_HSMP_MSG_RESP, &mbox_status, HSMP_WR);
        if (ret) {
                pr_err("Error %d clearing mailbox status register\n", ret);
                return ret;
        index = 0;
        /* Write any message arguments */
        while (index < msg->num_args) {
-               ret = amd_hsmp_rdwr(root, SMN_HSMP_MSG_DATA + (index << 2),
+               ret = amd_hsmp_rdwr(sock, SMN_HSMP_MSG_DATA + (index << 2),
                                    &msg->args[index], HSMP_WR);
                if (ret) {
                        pr_err("Error %d writing message argument %d\n", ret, index);
        }
 
        /* Write the message ID which starts the operation */
-       ret = amd_hsmp_rdwr(root, SMN_HSMP_MSG_ID, &msg->msg_id, HSMP_WR);
+       ret = amd_hsmp_rdwr(sock, SMN_HSMP_MSG_ID, &msg->msg_id, HSMP_WR);
        if (ret) {
                pr_err("Error %d writing message ID %u\n", ret, msg->msg_id);
                return ret;
        timeout = jiffies + msecs_to_jiffies(HSMP_MSG_TIMEOUT);
 
        while (time_before(jiffies, timeout)) {
-               ret = amd_hsmp_rdwr(root, SMN_HSMP_MSG_RESP, &mbox_status, HSMP_RD);
+               ret = amd_hsmp_rdwr(sock, SMN_HSMP_MSG_RESP, &mbox_status, HSMP_RD);
                if (ret) {
                        pr_err("Error %d reading mailbox status\n", ret);
                        return ret;
         */
        index = 0;
        while (index < msg->response_sz) {
-               ret = amd_hsmp_rdwr(root, SMN_HSMP_MSG_DATA + (index << 2),
+               ret = amd_hsmp_rdwr(sock, SMN_HSMP_MSG_DATA + (index << 2),
                                    &msg->args[index], HSMP_RD);
                if (ret) {
                        pr_err("Error %d reading response %u for message ID:%u\n",
 
 int hsmp_send_message(struct hsmp_message *msg)
 {
-       struct hsmp_socket *sock = &plat_dev.sock[msg->sock_ind];
-       struct amd_northbridge *nb;
+       struct hsmp_socket *sock;
        int ret;
 
        if (!msg)
                return -EINVAL;
-
-       nb = node_to_amd_nb(msg->sock_ind);
-       if (!nb || !nb->root)
-               return -ENODEV;
-
        ret = validate_message(msg);
        if (ret)
                return ret;
 
+       if (!plat_dev.sock || msg->sock_ind >= plat_dev.num_sockets)
+               return -ENODEV;
+       sock = &plat_dev.sock[msg->sock_ind];
+
        /*
         * The time taken by smu operation to complete is between
         * 10us to 1ms. Sometime it may take more time.
        if (ret < 0)
                return ret;
 
-       ret = __hsmp_send_message(nb->root, msg);
+       ret = __hsmp_send_message(sock, msg);
 
        up(&sock->hsmp_sem);
 
                sema_init(&plat_dev.sock[i].hsmp_sem, 1);
                plat_dev.sock[i].sock_ind = i;
 
+               if (!node_to_amd_nb(i))
+                       return -ENODEV;
+               plat_dev.sock[i].root = node_to_amd_nb(i)->root;
+
                /* Test the hsmp interface on each socket */
                ret = hsmp_test(i, 0xDEADBEEF);
                if (ret) {