struct mutex            sk_lock;        /* Protects .sk */
        struct sock __rcu       *sk;            /* Pointer to the session PPPoX socket */
        struct sock             *__sk;          /* Copy of .sk, for cleanup */
-       struct rcu_head         rcu;            /* For asynchronous release */
 };
 
 static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb);
        if (!sk)
                return NULL;
 
-       sock_hold(sk);
-       session = (struct l2tp_session *)(sk->sk_user_data);
-       if (!session) {
-               sock_put(sk);
-               goto out;
-       }
-       if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) {
-               session = NULL;
-               sock_put(sk);
-               goto out;
+       rcu_read_lock();
+       session = rcu_dereference_sk_user_data(sk);
+       if (session && refcount_inc_not_zero(&session->ref_count)) {
+               rcu_read_unlock();
+               WARN_ON_ONCE(session->magic != L2TP_SESSION_MAGIC);
+               return session;
        }
+       rcu_read_unlock();
 
-out:
-       return session;
+       return NULL;
 }
 
 /*****************************************************************************
        l2tp_xmit_skb(session, skb);
        local_bh_enable();
 
-       sock_put(sk);
+       l2tp_session_dec_refcount(session);
 
        return total_len;
 
 error_put_sess:
-       sock_put(sk);
+       l2tp_session_dec_refcount(session);
 error:
        return error;
 }
        l2tp_xmit_skb(session, skb);
        local_bh_enable();
 
-       sock_put(sk);
+       l2tp_session_dec_refcount(session);
 
        return 1;
 
 abort_put_sess:
-       sock_put(sk);
+       l2tp_session_dec_refcount(session);
 abort:
        /* Free the original skb */
        kfree_skb(skb);
  * Session (and tunnel control) socket create/destroy.
  *****************************************************************************/
 
-static void pppol2tp_put_sk(struct rcu_head *head)
-{
-       struct pppol2tp_session *ps;
-
-       ps = container_of(head, typeof(*ps), rcu);
-       sock_put(ps->__sk);
-}
-
 /* Really kill the session socket. (Called from sock_put() if
  * refcnt == 0.)
  */
 static void pppol2tp_session_destruct(struct sock *sk)
 {
-       struct l2tp_session *session = sk->sk_user_data;
-
        skb_queue_purge(&sk->sk_receive_queue);
        skb_queue_purge(&sk->sk_write_queue);
+}
 
-       if (session) {
-               sk->sk_user_data = NULL;
-               if (WARN_ON(session->magic != L2TP_SESSION_MAGIC))
-                       return;
+static void pppol2tp_session_close(struct l2tp_session *session)
+{
+       struct pppol2tp_session *ps;
+
+       ps = l2tp_session_priv(session);
+       mutex_lock(&ps->sk_lock);
+       ps->__sk = rcu_dereference_protected(ps->sk,
+                                            lockdep_is_held(&ps->sk_lock));
+       RCU_INIT_POINTER(ps->sk, NULL);
+       mutex_unlock(&ps->sk_lock);
+       if (ps->__sk) {
+               /* detach socket */
+               rcu_assign_sk_user_data(ps->__sk, NULL);
+               sock_put(ps->__sk);
+
+               /* drop ref taken when we referenced socket via sk_user_data */
                l2tp_session_dec_refcount(session);
        }
 }
 
        session = pppol2tp_sock_to_session(sk);
        if (session) {
-               struct pppol2tp_session *ps;
-
                l2tp_session_delete(session);
-
-               ps = l2tp_session_priv(session);
-               mutex_lock(&ps->sk_lock);
-               ps->__sk = rcu_dereference_protected(ps->sk,
-                                                    lockdep_is_held(&ps->sk_lock));
-               RCU_INIT_POINTER(ps->sk, NULL);
-               mutex_unlock(&ps->sk_lock);
-               call_rcu(&ps->rcu, pppol2tp_put_sk);
-
-               /* Rely on the sock_put() call at the end of the function for
-                * dropping the reference held by pppol2tp_sock_to_session().
-                * The last reference will be dropped by pppol2tp_put_sk().
-                */
+               /* drop ref taken by pppol2tp_sock_to_session */
+               l2tp_session_dec_refcount(session);
        }
 
        release_sock(sk);
 
-       /* This will delete the session context via
-        * pppol2tp_session_destruct() if the socket's refcnt drops to
-        * zero.
-        */
        sock_put(sk);
 
        return 0;
                goto out;
 
        sock_init_data(sock, sk);
+       sock_set_flag(sk, SOCK_RCU_FREE);
 
        sock->state  = SS_UNCONNECTED;
        sock->ops    = &pppol2tp_ops;
        struct pppol2tp_session *ps;
 
        session->recv_skb = pppol2tp_recv;
+       session->session_close = pppol2tp_session_close;
        if (IS_ENABLED(CONFIG_L2TP_DEBUGFS))
                session->show = pppol2tp_show;
 
 
 out_no_ppp:
        /* This is how we get the session context from the socket. */
-       sk->sk_user_data = session;
+       sock_hold(sk);
+       rcu_assign_sk_user_data(sk, session);
        rcu_assign_pointer(ps->sk, sk);
        mutex_unlock(&ps->sk_lock);
 
        /* Keep the reference we've grabbed on the session: sk doesn't expect
-        * the session to disappear. pppol2tp_session_destruct() is responsible
+        * the session to disappear. pppol2tp_session_close() is responsible
         * for dropping it.
         */
        drop_refcnt = false;
 
        error = len;
 
-       sock_put(sk);
+       l2tp_session_dec_refcount(session);
 end:
        return error;
 }
                err = pppol2tp_session_setsockopt(sk, session, optname, val);
        }
 
-       sock_put(sk);
+       l2tp_session_dec_refcount(session);
 end:
        return err;
 }
        err = 0;
 
 end_put_sess:
-       sock_put(sk);
+       l2tp_session_dec_refcount(session);
 end:
        return err;
 }