}
 EXPORT_SYMBOL_GPL(xprt_force_disconnect);
 
+static unsigned int
+xprt_connect_cookie(struct rpc_xprt *xprt)
+{
+       return READ_ONCE(xprt->connect_cookie);
+}
+
+static bool
+xprt_request_retransmit_after_disconnect(struct rpc_task *task)
+{
+       struct rpc_rqst *req = task->tk_rqstp;
+       struct rpc_xprt *xprt = req->rq_xprt;
+
+       return req->rq_connect_cookie != xprt_connect_cookie(xprt) ||
+               !xprt_connected(xprt);
+}
+
 /**
  * xprt_conditional_disconnect - force a transport to disconnect
  * @xprt: transport to disconnect
                task->tk_status = 0;
 }
 
+/**
+ * xprt_request_wait_receive - wait for the reply to an RPC request
+ * @task: RPC task about to send a request
+ *
+ */
+void xprt_request_wait_receive(struct rpc_task *task)
+{
+       struct rpc_rqst *req = task->tk_rqstp;
+       struct rpc_xprt *xprt = req->rq_xprt;
+
+       if (!test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate))
+               return;
+       /*
+        * Sleep on the pending queue if we're expecting a reply.
+        * The spinlock ensures atomicity between the test of
+        * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on().
+        */
+       spin_lock(&xprt->queue_lock);
+       if (test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) {
+               xprt->ops->set_retrans_timeout(task);
+               rpc_sleep_on(&xprt->pending, task, xprt_timer);
+               /*
+                * Send an extra queue wakeup call if the
+                * connection was dropped in case the call to
+                * rpc_sleep_on() raced.
+                */
+               if (xprt_request_retransmit_after_disconnect(task))
+                       rpc_wake_up_queued_task_set_status(&xprt->pending,
+                                       task, -ENOTCONN);
+       }
+       spin_unlock(&xprt->queue_lock);
+}
+
 /**
  * xprt_prepare_transmit - reserve the transport before sending a request
  * @task: RPC task about to send a request
                        task->tk_status = req->rq_reply_bytes_recvd;
                        goto out_unlock;
                }
-               if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT)
-                   && xprt_connected(xprt)
-                   && req->rq_connect_cookie == xprt->connect_cookie) {
+               if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) &&
+                   !xprt_request_retransmit_after_disconnect(task)) {
                        xprt->ops->set_retrans_timeout(task);
                        rpc_sleep_on(&xprt->pending, task, xprt_timer);
                        goto out_unlock;
        task->tk_flags |= RPC_TASK_SENT;
        spin_lock_bh(&xprt->transport_lock);
 
-       xprt->ops->set_retrans_timeout(task);
-
        xprt->stat.sends++;
        xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs;
        xprt->stat.bklog_u += xprt->backlog.qlen;
        spin_unlock_bh(&xprt->transport_lock);
 
        req->rq_connect_cookie = connect_cookie;
-       if (test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) {
-               /*
-                * Sleep on the pending queue if we're expecting a reply.
-                * The spinlock ensures atomicity between the test of
-                * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on().
-                */
-               spin_lock(&xprt->queue_lock);
-               if (test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) {
-                       rpc_sleep_on(&xprt->pending, task, xprt_timer);
-                       /* Wake up immediately if the connection was dropped */
-                       if (!xprt_connected(xprt))
-                               rpc_wake_up_queued_task_set_status(&xprt->pending,
-                                               task, -ENOTCONN);
-               }
-               spin_unlock(&xprt->queue_lock);
-       }
 }
 
 static void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task)
        req->rq_xprt    = xprt;
        req->rq_buffer  = NULL;
        req->rq_xid     = xprt_alloc_xid(xprt);
-       req->rq_connect_cookie = xprt->connect_cookie - 1;
+       req->rq_connect_cookie = xprt_connect_cookie(xprt) - 1;
        req->rq_bytes_sent = 0;
        req->rq_snd_buf.len = 0;
        req->rq_snd_buf.buflen = 0;