 #include <crypto/aes.h>
 #include <crypto/skcipher.h>
+#include <crypto/internal/skcipher.h>
 
 #include "safexcel.h"
 
        unsigned int key_len;
 };
 
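+/* Per-request state, kept in the skcipher request context. */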
+struct safexcel_cipher_req {
+       bool needs_inv;
+};
+
 static void safexcel_cipher_token(struct safexcel_cipher_ctx *ctx,
                                  struct crypto_async_request *async,
                                  struct safexcel_command_desc *cdesc,
        return 0;
 }
 
-static int safexcel_handle_result(struct safexcel_crypto_priv *priv, int ring,
-                                 struct crypto_async_request *async,
-                                 bool *should_complete, int *ret)
+static int safexcel_handle_req_result(struct safexcel_crypto_priv *priv, int ring,
+                                     struct crypto_async_request *async,
+                                     bool *should_complete, int *ret)
 {
        struct skcipher_request *req = skcipher_request_cast(async);
        struct safexcel_result_desc *rdesc;
        spin_unlock_bh(&priv->ring[ring].egress_lock);
 
        request->req = &req->base;
-       ctx->base.handle_result = safexcel_handle_result;
 
        *commands = n_cdesc;
        *results = n_rdesc;
 
        ring = safexcel_select_ring(priv);
        ctx->base.ring = ring;
-       ctx->base.needs_inv = false;
-       ctx->base.send = safexcel_aes_send;
 
        spin_lock_bh(&priv->ring[ring].queue_lock);
        enq_ret = crypto_enqueue_request(&priv->ring[ring].queue, async);
        return ndesc;
 }
 
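+/*
+ * Route the completion to the invalidation or the regular result handler,
+ * depending on whether this request carried a context invalidation.
+ */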
+static int safexcel_handle_result(struct safexcel_crypto_priv *priv, int ring,
+                                 struct crypto_async_request *async,
+                                 bool *should_complete, int *ret)
+{
+       struct skcipher_request *req = skcipher_request_cast(async);
+       struct safexcel_cipher_req *sreq = skcipher_request_ctx(req);
+       int err;
+
+       if (sreq->needs_inv) {
+               sreq->needs_inv = false;
+               err = safexcel_handle_inv_result(priv, ring, async,
+                                                should_complete, ret);
+       } else {
+               err = safexcel_handle_req_result(priv, ring, async,
+                                                should_complete, ret);
+       }
+
+       return err;
+}
+
 static int safexcel_cipher_send_inv(struct crypto_async_request *async,
                                    int ring, struct safexcel_request *request,
                                    int *commands, int *results)
        struct safexcel_crypto_priv *priv = ctx->priv;
        int ret;
 
-       ctx->base.handle_result = safexcel_handle_inv_result;
-
        ret = safexcel_invalidate_cache(async, &ctx->base, priv,
                                        ctx->base.ctxr_dma, ring, request);
        if (unlikely(ret))
        return 0;
 }
 
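+/* Send either a context invalidation command or the cipher request itself. */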
+static int safexcel_send(struct crypto_async_request *async,
+                        int ring, struct safexcel_request *request,
+                        int *commands, int *results)
+{
+       struct skcipher_request *req = skcipher_request_cast(async);
+       struct safexcel_cipher_req *sreq = skcipher_request_ctx(req);
+       int ret;
+
+       if (sreq->needs_inv)
+               ret = safexcel_cipher_send_inv(async, ring, request,
+                                              commands, results);
+       else
+               ret = safexcel_aes_send(async, ring, request,
+                                       commands, results);
+       return ret;
+}
+
 static int safexcel_cipher_exit_inv(struct crypto_tfm *tfm)
 {
        struct safexcel_cipher_ctx *ctx = crypto_tfm_ctx(tfm);
        struct safexcel_crypto_priv *priv = ctx->priv;
        struct skcipher_request req;
+       struct safexcel_cipher_req *sreq = skcipher_request_ctx(&req);
        struct safexcel_inv_result result = {};
        int ring = ctx->base.ring;
 
        skcipher_request_set_tfm(&req, __crypto_skcipher_cast(tfm));
        ctx = crypto_tfm_ctx(req.base.tfm);
        ctx->base.exit_inv = true;
-       ctx->base.send = safexcel_cipher_send_inv;
+       sreq->needs_inv = true;
 
        spin_lock_bh(&priv->ring[ring].queue_lock);
        crypto_enqueue_request(&priv->ring[ring].queue, &req.base);
                        enum safexcel_cipher_direction dir, u32 mode)
 {
        struct safexcel_cipher_ctx *ctx = crypto_tfm_ctx(req->base.tfm);
+       struct safexcel_cipher_req *sreq = skcipher_request_ctx(req);
        struct safexcel_crypto_priv *priv = ctx->priv;
        int ret, ring;
 
+       sreq->needs_inv = false;
        ctx->direction = dir;
        ctx->mode = mode;
 
        if (ctx->base.ctxr) {
-               if (ctx->base.needs_inv)
-                       ctx->base.send = safexcel_cipher_send_inv;
+               if (ctx->base.needs_inv) {
+                       sreq->needs_inv = true;
+                       ctx->base.needs_inv = false;
+               }
        } else {
                ctx->base.ring = safexcel_select_ring(priv);
-               ctx->base.send = safexcel_aes_send;
-
                ctx->base.ctxr = dma_pool_zalloc(priv->context_pool,
                                                 EIP197_GFP_FLAGS(req->base),
                                                 &ctx->base.ctxr_dma);
                             alg.skcipher.base);
 
        ctx->priv = tmpl->priv;
+       ctx->base.send = safexcel_send;
+       ctx->base.handle_result = safexcel_handle_result;
+
+       crypto_skcipher_set_reqsize(__crypto_skcipher_cast(tfm),
+                                   sizeof(struct safexcel_cipher_req));
 
        return 0;
 }
 
        bool last_req;
        bool finish;
        bool hmac;
+       bool needs_inv; /* invalidate the context before processing this request */
 
        u8 state_sz;    /* expected state size, only set once */
        u32 state[SHA256_DIGEST_SIZE / sizeof(u32)];
        }
 }
 
-static int safexcel_handle_result(struct safexcel_crypto_priv *priv, int ring,
-                                 struct crypto_async_request *async,
-                                 bool *should_complete, int *ret)
+static int safexcel_handle_req_result(struct safexcel_crypto_priv *priv, int ring,
+                                     struct crypto_async_request *async,
+                                     bool *should_complete, int *ret)
 {
        struct safexcel_result_desc *rdesc;
        struct ahash_request *areq = ahash_request_cast(async);
        return 1;
 }
 
-static int safexcel_ahash_send(struct crypto_async_request *async, int ring,
-                              struct safexcel_request *request, int *commands,
-                              int *results)
+static int safexcel_ahash_send_req(struct crypto_async_request *async, int ring,
+                                  struct safexcel_request *request,
+                                  int *commands, int *results)
 {
        struct ahash_request *areq = ahash_request_cast(async);
        struct crypto_ahash *ahash = crypto_ahash_reqtfm(areq);
 
        req->processed += len;
        request->req = &areq->base;
-       ctx->base.handle_result = safexcel_handle_result;
 
        *commands = n_cdesc;
        *results = 1;
 
        ring = safexcel_select_ring(priv);
        ctx->base.ring = ring;
-       ctx->base.needs_inv = false;
-       ctx->base.send = safexcel_ahash_send;
 
        spin_lock_bh(&priv->ring[ring].queue_lock);
        enq_ret = crypto_enqueue_request(&priv->ring[ring].queue, async);
        return 1;
 }
 
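+/*
+ * Pick the invalidation or the regular result handler, based on the
+ * per-request needs_inv flag.
+ */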
+static int safexcel_handle_result(struct safexcel_crypto_priv *priv, int ring,
+                                 struct crypto_async_request *async,
+                                 bool *should_complete, int *ret)
+{
+       struct ahash_request *areq = ahash_request_cast(async);
+       struct safexcel_ahash_req *req = ahash_request_ctx(areq);
+       int err;
+
+       if (req->needs_inv) {
+               req->needs_inv = false;
+               err = safexcel_handle_inv_result(priv, ring, async,
+                                                should_complete, ret);
+       } else {
+               err = safexcel_handle_req_result(priv, ring, async,
+                                                should_complete, ret);
+       }
+
+       return err;
+}
+
 static int safexcel_ahash_send_inv(struct crypto_async_request *async,
                                   int ring, struct safexcel_request *request,
                                   int *commands, int *results)
        struct safexcel_ahash_ctx *ctx = crypto_ahash_ctx(crypto_ahash_reqtfm(areq));
        int ret;
 
-       ctx->base.handle_result = safexcel_handle_inv_result;
        ret = safexcel_invalidate_cache(async, &ctx->base, ctx->priv,
                                        ctx->base.ctxr_dma, ring, request);
        if (unlikely(ret))
        return 0;
 }
 
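+/* Send either a context invalidation command or the hash request itself. */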
+static int safexcel_ahash_send(struct crypto_async_request *async,
+                              int ring, struct safexcel_request *request,
+                              int *commands, int *results)
+{
+       struct ahash_request *areq = ahash_request_cast(async);
+       struct safexcel_ahash_req *req = ahash_request_ctx(areq);
+       int ret;
+
+       if (req->needs_inv)
+               ret = safexcel_ahash_send_inv(async, ring, request,
+                                             commands, results);
+       else
+               ret = safexcel_ahash_send_req(async, ring, request,
+                                             commands, results);
+       return ret;
+}
+
 static int safexcel_ahash_exit_inv(struct crypto_tfm *tfm)
 {
        struct safexcel_ahash_ctx *ctx = crypto_tfm_ctx(tfm);
        struct safexcel_crypto_priv *priv = ctx->priv;
        struct ahash_request req;
+       struct safexcel_ahash_req *rctx = ahash_request_ctx(&req);
        struct safexcel_inv_result result = {};
        int ring = ctx->base.ring;
 
        ahash_request_set_tfm(&req, __crypto_ahash_cast(tfm));
        ctx = crypto_tfm_ctx(req.base.tfm);
        ctx->base.exit_inv = true;
-       ctx->base.send = safexcel_ahash_send_inv;
+       rctx->needs_inv = true;
 
        spin_lock_bh(&priv->ring[ring].queue_lock);
        crypto_enqueue_request(&priv->ring[ring].queue, &req.base);
        struct safexcel_crypto_priv *priv = ctx->priv;
        int ret, ring;
 
-       ctx->base.send = safexcel_ahash_send;
+       req->needs_inv = false;
 
        if (req->processed && ctx->digest == CONTEXT_CONTROL_DIGEST_PRECOMPUTED)
                ctx->base.needs_inv = safexcel_ahash_needs_inv_get(areq);
 
        if (ctx->base.ctxr) {
-               if (ctx->base.needs_inv)
-                       ctx->base.send = safexcel_ahash_send_inv;
+               if (ctx->base.needs_inv) {
+                       ctx->base.needs_inv = false;
+                       req->needs_inv = true;
+               }
        } else {
                ctx->base.ring = safexcel_select_ring(priv);
                ctx->base.ctxr = dma_pool_zalloc(priv->context_pool,
                             struct safexcel_alg_template, alg.ahash);
 
        ctx->priv = tmpl->priv;
+       ctx->base.send = safexcel_ahash_send;
+       ctx->base.handle_result = safexcel_handle_result;
 
        crypto_ahash_set_reqsize(__crypto_ahash_cast(tfm),
                                 sizeof(struct safexcel_ahash_req));