const struct tls_crypto_info *crypto_info,
                   const struct tls_cipher_desc *cipher_desc,
                   int mode);
-int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
+int tls_set_sw_offload(struct sock *sk, int tx);
 void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
 void tls_sw_strparser_done(struct tls_context *tls_ctx);
 
        context->resync_nh_reset = 1;
 
        ctx->priv_ctx_rx = context;
-       rc = tls_set_sw_offload(sk, ctx, 0);
+       rc = tls_set_sw_offload(sk, 0);
        if (rc)
                goto release_ctx;
 
 
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
                } else {
-                       rc = tls_set_sw_offload(sk, ctx, 1);
+                       rc = tls_set_sw_offload(sk, 1);
                        if (rc)
                                goto err_crypto_info;
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
                } else {
-                       rc = tls_set_sw_offload(sk, ctx, 0);
+                       rc = tls_set_sw_offload(sk, 0);
                        if (rc)
                                goto err_crypto_info;
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
 
        return 0;
 }
 
-int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
+int tls_set_sw_offload(struct sock *sk, int tx)
 {
-       struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_prot_info *prot = &tls_ctx->prot_info;
-       struct tls_crypto_info *crypto_info;
        struct tls_sw_context_tx *sw_ctx_tx = NULL;
        struct tls_sw_context_rx *sw_ctx_rx = NULL;
+       const struct tls_cipher_desc *cipher_desc;
+       struct tls_crypto_info *crypto_info;
+       char *iv, *rec_seq, *key, *salt;
        struct cipher_context *cctx;
+       struct tls_prot_info *prot;
        struct crypto_aead **aead;
+       struct tls_context *ctx;
        struct crypto_tfm *tfm;
-       char *iv, *rec_seq, *key, *salt;
-       const struct tls_cipher_desc *cipher_desc;
        int rc = 0;
 
-       if (!ctx) {
-               rc = -EINVAL;
-               goto out;
-       }
+       ctx = tls_get_ctx(sk);
+       prot = &ctx->prot_info;
 
        if (tx) {
                ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);