*/
 static void sock_shutdown(struct nbd_device *nbd)
 {
+       struct socket *sock;
+
        spin_lock_irq(&nbd->sock_lock);
 
        if (!nbd->sock) {
                return;
        }
 
+       sock = nbd->sock;
        dev_warn(disk_to_dev(nbd->disk), "shutting down socket\n");
-       kernel_sock_shutdown(nbd->sock, SHUT_RDWR);
-       sockfd_put(nbd->sock);
        nbd->sock = NULL;
        spin_unlock_irq(&nbd->sock_lock);
 
+       kernel_sock_shutdown(sock, SHUT_RDWR);
+       sockfd_put(sock);
+
        del_timer(&nbd->timeout_timer);
 }
 
 static void nbd_xmit_timeout(unsigned long arg)
 {
        struct nbd_device *nbd = (struct nbd_device *)arg;
+       struct socket *sock = NULL;
        unsigned long flags;
 
        if (!atomic_read(&nbd->outstanding_cmds))
 
        nbd->timedout = true;
 
-       if (nbd->sock)
-               kernel_sock_shutdown(nbd->sock, SHUT_RDWR);
+       if (nbd->sock) {
+               sock = nbd->sock;
+               get_file(sock->file);
+       }
 
        spin_unlock_irqrestore(&nbd->sock_lock, flags);
+       if (sock) {
+               kernel_sock_shutdown(sock, SHUT_RDWR);
+               sockfd_put(sock);
+       }
 
        dev_err(nbd_to_dev(nbd), "Connection timed out, shutting down connection\n");
 }