/*
  * Forcibly shutdown the socket causing all listeners to error
  */
-static void sock_shutdown(struct nbd_device *nbd, int lock)
+static void sock_shutdown(struct nbd_device *nbd)
 {
-       if (lock)
-               mutex_lock(&nbd->tx_lock);
        if (nbd->sock) {
                dev_warn(disk_to_dev(nbd->disk), "shutting down socket\n");
                kernel_sock_shutdown(nbd->sock, SHUT_RDWR);
                nbd->sock = NULL;
                del_timer_sync(&nbd->timeout_timer);
        }
-       if (lock)
-               mutex_unlock(&nbd->tx_lock);
 }
 
 static void nbd_xmit_timeout(unsigned long arg)
                ret = dequeue_signal_lock(current, ¤t->blocked, &info);
                dev_warn(nbd_to_dev(nbd), "pid %d, %s, got signal %d\n",
                         task_pid_nr(current), current->comm, ret);
-               sock_shutdown(nbd, 1);
+               mutex_lock(&nbd->tx_lock);
+               sock_shutdown(nbd);
+               mutex_unlock(&nbd->tx_lock);
                ret = -ETIMEDOUT;
        }
 
                                                  &info);
                        dev_warn(nbd_to_dev(nbd), "pid %d, %s, got signal %d\n",
                                 task_pid_nr(current), current->comm, ret);
-                       sock_shutdown(nbd, 1);
+                       mutex_lock(&nbd->tx_lock);
+                       sock_shutdown(nbd);
+                       mutex_unlock(&nbd->tx_lock);
                        break;
                }
 
                mutex_lock(&nbd->tx_lock);
                if (error)
                        return error;
-               sock_shutdown(nbd, 0);
+               sock_shutdown(nbd);
                sock = nbd->sock;
                nbd->sock = NULL;
                nbd_clear_que(nbd);