unsigned long flags;
 
        /* cache cold stuff */
+       struct proto *sk_proto;
+
        void (*sk_destruct)(struct sock *sk);
        void (*sk_proto_close)(struct sock *sk, long timeout);
 
 
        struct list_head list;
        refcount_t refcount;
+
+       struct work_struct gc;
 };
 
 enum tls_offload_ctx_dir {
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 int tls_sw_sendpage(struct sock *sk, struct page *page,
                    int offset, size_t size, int flags);
-void tls_sw_close(struct sock *sk, long timeout);
 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
 void tls_sw_release_resources_tx(struct sock *sk);
 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
 
        kfree(ctx);
 }
 
+static void tls_ctx_free_deferred(struct work_struct *gc)
+{
+       struct tls_context *ctx = container_of(gc, struct tls_context, gc);
+
+       /* Ensure any remaining work items are completed. The sk will
+        * already have lost its tls_ctx reference by the time we get
+        * here so no xmit operation will actually be performed.
+        */
+       if (ctx->tx_conf == TLS_SW) {
+               tls_sw_cancel_work_tx(ctx);
+               tls_sw_free_ctx_tx(ctx);
+       }
+
+       if (ctx->rx_conf == TLS_SW) {
+               tls_sw_strparser_done(ctx);
+               tls_sw_free_ctx_rx(ctx);
+       }
+
+       tls_ctx_free(ctx);
+}
+
+static void tls_ctx_free_wq(struct tls_context *ctx)
+{
+       INIT_WORK(&ctx->gc, tls_ctx_free_deferred);
+       schedule_work(&ctx->gc);
+}
+
 static void tls_sk_proto_cleanup(struct sock *sk,
                                 struct tls_context *ctx, long timeo)
 {
 #endif
 }
 
+static void tls_sk_proto_unhash(struct sock *sk)
+{
+       struct inet_connection_sock *icsk = inet_csk(sk);
+       long timeo = sock_sndtimeo(sk, 0);
+       struct tls_context *ctx;
+
+       if (unlikely(!icsk->icsk_ulp_data)) {
+               if (sk->sk_prot->unhash)
+                       sk->sk_prot->unhash(sk);
+       }
+
+       ctx = tls_get_ctx(sk);
+       tls_sk_proto_cleanup(sk, ctx, timeo);
+       icsk->icsk_ulp_data = NULL;
+
+       if (ctx->sk_proto->unhash)
+               ctx->sk_proto->unhash(sk);
+       tls_ctx_free_wq(ctx);
+}
+
 static void tls_sk_proto_close(struct sock *sk, long timeout)
 {
        void (*sk_proto_close)(struct sock *sk, long timeout);
        if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
                tls_sk_proto_cleanup(sk, ctx, timeo);
 
+       sk->sk_prot = ctx->sk_proto;
        release_sock(sk);
        if (ctx->tx_conf == TLS_SW)
                tls_sw_free_ctx_tx(ctx);
        ctx->setsockopt = sk->sk_prot->setsockopt;
        ctx->getsockopt = sk->sk_prot->getsockopt;
        ctx->sk_proto_close = sk->sk_prot->close;
+       ctx->unhash = sk->sk_prot->unhash;
        return ctx;
 }
 
        prot[TLS_BASE][TLS_BASE].setsockopt     = tls_setsockopt;
        prot[TLS_BASE][TLS_BASE].getsockopt     = tls_getsockopt;
        prot[TLS_BASE][TLS_BASE].close          = tls_sk_proto_close;
+       prot[TLS_BASE][TLS_BASE].unhash         = tls_sk_proto_unhash;
 
        prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
        prot[TLS_SW][TLS_BASE].sendmsg          = tls_sw_sendmsg;
 
 #ifdef CONFIG_TLS_DEVICE
        prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
+       prot[TLS_HW][TLS_BASE].unhash           = base->unhash;
        prot[TLS_HW][TLS_BASE].sendmsg          = tls_device_sendmsg;
        prot[TLS_HW][TLS_BASE].sendpage         = tls_device_sendpage;
 
        prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
+       prot[TLS_HW][TLS_SW].unhash             = base->unhash;
        prot[TLS_HW][TLS_SW].sendmsg            = tls_device_sendmsg;
        prot[TLS_HW][TLS_SW].sendpage           = tls_device_sendpage;
 
        prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
+       prot[TLS_BASE][TLS_HW].unhash           = base->unhash;
 
        prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
+       prot[TLS_SW][TLS_HW].unhash             = base->unhash;
 
        prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 #endif
        tls_build_proto(sk);
        ctx->tx_conf = TLS_BASE;
        ctx->rx_conf = TLS_BASE;
+       ctx->sk_proto = sk->sk_prot;
        update_sk_prot(sk, ctx);
 out:
        return rc;