#include <crypto/scatterwalk.h>
 #include <linux/err.h>
 #include <linux/init.h>
+#include <linux/jump_label.h>
 #include <linux/kernel.h>
 #include <linux/module.h>
 #include <linux/scatterlist.h>
        union aegis_block key;
 };
 
-struct aegis128_ops {
-       int (*skcipher_walk_init)(struct skcipher_walk *walk,
-                                 struct aead_request *req, bool atomic);
-
-       void (*crypt_chunk)(struct aegis_state *state, u8 *dst,
-                           const u8 *src, unsigned int size);
-};
-
-static bool have_simd;
+static __ro_after_init DEFINE_STATIC_KEY_FALSE(have_simd);
 
 static const union aegis_block crypto_aegis_const[2] = {
        { .words64 = {
 static bool aegis128_do_simd(void)
 {
 #ifdef CONFIG_CRYPTO_AEGIS128_SIMD
-       if (have_simd)
+       if (static_branch_likely(&have_simd))
                return crypto_simd_usable();
 #endif
        return false;
        }
 }
 
-static void crypto_aegis128_process_crypt(struct aegis_state *state,
-                                         struct aead_request *req,
-                                         const struct aegis128_ops *ops)
+static __always_inline
+int crypto_aegis128_process_crypt(struct aegis_state *state,
+                                 struct aead_request *req,
+                                 struct skcipher_walk *walk,
+                                 void (*crypt)(struct aegis_state *state,
+                                               u8 *dst, const u8 *src,
+                                               unsigned int size))
 {
-       struct skcipher_walk walk;
+       int err = 0;
 
-       ops->skcipher_walk_init(&walk, req, false);
+       while (walk->nbytes) {
+               unsigned int nbytes = walk->nbytes;
 
-       while (walk.nbytes) {
-               unsigned int nbytes = walk.nbytes;
+               if (nbytes < walk->total)
+                       nbytes = round_down(nbytes, walk->stride);
 
-               if (nbytes < walk.total)
-                       nbytes = round_down(nbytes, walk.stride);
+               crypt(state, walk->dst.virt.addr, walk->src.virt.addr, nbytes);
 
-               ops->crypt_chunk(state, walk.dst.virt.addr, walk.src.virt.addr,
-                                nbytes);
-
-               skcipher_walk_done(&walk, walk.nbytes - nbytes);
+               err = skcipher_walk_done(walk, walk->nbytes - nbytes);
        }
+       return err;
 }
 
 static void crypto_aegis128_final(struct aegis_state *state,
        return 0;
 }
 
-static void crypto_aegis128_crypt(struct aead_request *req,
-                                 union aegis_block *tag_xor,
-                                 unsigned int cryptlen,
-                                 const struct aegis128_ops *ops)
+static int crypto_aegis128_encrypt(struct aead_request *req)
 {
        struct crypto_aead *tfm = crypto_aead_reqtfm(req);
+       union aegis_block tag = {};
+       unsigned int authsize = crypto_aead_authsize(tfm);
        struct aegis_ctx *ctx = crypto_aead_ctx(tfm);
+       unsigned int cryptlen = req->cryptlen;
+       struct skcipher_walk walk;
        struct aegis_state state;
 
        crypto_aegis128_init(&state, &ctx->key, req->iv);
        crypto_aegis128_process_ad(&state, req->src, req->assoclen);
-       crypto_aegis128_process_crypt(&state, req, ops);
-       crypto_aegis128_final(&state, tag_xor, req->assoclen, cryptlen);
-}
-
-static int crypto_aegis128_encrypt(struct aead_request *req)
-{
-       const struct aegis128_ops *ops = &(struct aegis128_ops){
-               .skcipher_walk_init = skcipher_walk_aead_encrypt,
-               .crypt_chunk = crypto_aegis128_encrypt_chunk,
-       };
-
-       struct crypto_aead *tfm = crypto_aead_reqtfm(req);
-       union aegis_block tag = {};
-       unsigned int authsize = crypto_aead_authsize(tfm);
-       unsigned int cryptlen = req->cryptlen;
 
+       skcipher_walk_aead_encrypt(&walk, req, false);
        if (aegis128_do_simd())
-               ops = &(struct aegis128_ops){
-                       .skcipher_walk_init = skcipher_walk_aead_encrypt,
-                       .crypt_chunk = crypto_aegis128_encrypt_chunk_simd };
-
-       crypto_aegis128_crypt(req, &tag, cryptlen, ops);
+               crypto_aegis128_process_crypt(&state, req, &walk,
+                                             crypto_aegis128_encrypt_chunk_simd);
+       else
+               crypto_aegis128_process_crypt(&state, req, &walk,
+                                             crypto_aegis128_encrypt_chunk);
+       crypto_aegis128_final(&state, &tag, req->assoclen, cryptlen);
 
        scatterwalk_map_and_copy(tag.bytes, req->dst, req->assoclen + cryptlen,
                                 authsize, 1);
 
 static int crypto_aegis128_decrypt(struct aead_request *req)
 {
-       const struct aegis128_ops *ops = &(struct aegis128_ops){
-               .skcipher_walk_init = skcipher_walk_aead_decrypt,
-               .crypt_chunk = crypto_aegis128_decrypt_chunk,
-       };
        static const u8 zeros[AEGIS128_MAX_AUTH_SIZE] = {};
-
        struct crypto_aead *tfm = crypto_aead_reqtfm(req);
        union aegis_block tag;
        unsigned int authsize = crypto_aead_authsize(tfm);
        unsigned int cryptlen = req->cryptlen - authsize;
+       struct aegis_ctx *ctx = crypto_aead_ctx(tfm);
+       struct skcipher_walk walk;
+       struct aegis_state state;
 
        scatterwalk_map_and_copy(tag.bytes, req->src, req->assoclen + cryptlen,
                                 authsize, 0);
 
-       if (aegis128_do_simd())
-               ops = &(struct aegis128_ops){
-                       .skcipher_walk_init = skcipher_walk_aead_decrypt,
-                       .crypt_chunk = crypto_aegis128_decrypt_chunk_simd };
+       crypto_aegis128_init(&state, &ctx->key, req->iv);
+       crypto_aegis128_process_ad(&state, req->src, req->assoclen);
 
-       crypto_aegis128_crypt(req, &tag, cryptlen, ops);
+       skcipher_walk_aead_decrypt(&walk, req, false);
+       if (aegis128_do_simd())
+               crypto_aegis128_process_crypt(&state, req, &walk,
+                                             crypto_aegis128_decrypt_chunk_simd);
+       else
+               crypto_aegis128_process_crypt(&state, req, &walk,
+                                             crypto_aegis128_decrypt_chunk);
+       crypto_aegis128_final(&state, &tag, req->assoclen, cryptlen);
 
        return crypto_memneq(tag.bytes, zeros, authsize) ? -EBADMSG : 0;
 }
 
 static int __init crypto_aegis128_module_init(void)
 {
-       if (IS_ENABLED(CONFIG_CRYPTO_AEGIS128_SIMD))
-               have_simd = crypto_aegis128_have_simd();
+       if (IS_ENABLED(CONFIG_CRYPTO_AEGIS128_SIMD) &&
+           crypto_aegis128_have_simd())
+               static_branch_enable(&have_simd);
 
        return crypto_register_aead(&crypto_aegis128_alg);
 }