static int crypto_blake2s_update_arm(struct shash_desc *desc,
                                     const u8 *in, unsigned int inlen)
 {
-       return crypto_blake2s_update(desc, in, inlen, blake2s_compress);
+       return crypto_blake2s_update(desc, in, inlen, false);
 }
 
 static int crypto_blake2s_final_arm(struct shash_desc *desc, u8 *out)
 {
-       return crypto_blake2s_final(desc, out, blake2s_compress);
+       return crypto_blake2s_final(desc, out, false);
 }
 
 #define BLAKE2S_ALG(name, driver_name, digest_size)                    \
 
 static int crypto_blake2s_update_x86(struct shash_desc *desc,
                                     const u8 *in, unsigned int inlen)
 {
-       return crypto_blake2s_update(desc, in, inlen, blake2s_compress);
+       return crypto_blake2s_update(desc, in, inlen, false);
 }
 
 static int crypto_blake2s_final_x86(struct shash_desc *desc, u8 *out)
 {
-       return crypto_blake2s_final(desc, out, blake2s_compress);
+       return crypto_blake2s_final(desc, out, false);
 }
 
 #define BLAKE2S_ALG(name, driver_name, digest_size)                    \
 
 static int crypto_blake2s_update_generic(struct shash_desc *desc,
                                         const u8 *in, unsigned int inlen)
 {
-       return crypto_blake2s_update(desc, in, inlen, blake2s_compress_generic);
+       return crypto_blake2s_update(desc, in, inlen, true);
 }
 
 static int crypto_blake2s_final_generic(struct shash_desc *desc, u8 *out)
 {
-       return crypto_blake2s_final(desc, out, blake2s_compress_generic);
+       return crypto_blake2s_final(desc, out, true);
 }
 
 #define BLAKE2S_ALG(name, driver_name, digest_size)                    \
 
        state->f[0] = -1;
 }
 
-typedef void (*blake2s_compress_t)(struct blake2s_state *state,
-                                  const u8 *block, size_t nblocks, u32 inc);
-
 /* Helper functions for BLAKE2s shared by the library and shash APIs */
 
-static inline void __blake2s_update(struct blake2s_state *state,
-                                   const u8 *in, size_t inlen,
-                                   blake2s_compress_t compress)
+static __always_inline void
+__blake2s_update(struct blake2s_state *state, const u8 *in, size_t inlen,
+                bool force_generic)
 {
        const size_t fill = BLAKE2S_BLOCK_SIZE - state->buflen;
 
                return;
        if (inlen > fill) {
                memcpy(state->buf + state->buflen, in, fill);
-               (*compress)(state, state->buf, 1, BLAKE2S_BLOCK_SIZE);
+               if (force_generic)
+                       blake2s_compress_generic(state, state->buf, 1,
+                                                BLAKE2S_BLOCK_SIZE);
+               else
+                       blake2s_compress(state, state->buf, 1,
+                                        BLAKE2S_BLOCK_SIZE);
                state->buflen = 0;
                in += fill;
                inlen -= fill;
        if (inlen > BLAKE2S_BLOCK_SIZE) {
                const size_t nblocks = DIV_ROUND_UP(inlen, BLAKE2S_BLOCK_SIZE);
                /* Hash one less (full) block than strictly possible */
-               (*compress)(state, in, nblocks - 1, BLAKE2S_BLOCK_SIZE);
+               if (force_generic)
+                       blake2s_compress_generic(state, in, nblocks - 1,
+                                                BLAKE2S_BLOCK_SIZE);
+               else
+                       blake2s_compress(state, in, nblocks - 1,
+                                        BLAKE2S_BLOCK_SIZE);
                in += BLAKE2S_BLOCK_SIZE * (nblocks - 1);
                inlen -= BLAKE2S_BLOCK_SIZE * (nblocks - 1);
        }
        state->buflen += inlen;
 }
 
-static inline void __blake2s_final(struct blake2s_state *state, u8 *out,
-                                  blake2s_compress_t compress)
+static __always_inline void
+__blake2s_final(struct blake2s_state *state, u8 *out, bool force_generic)
 {
        blake2s_set_lastblock(state);
        memset(state->buf + state->buflen, 0,
               BLAKE2S_BLOCK_SIZE - state->buflen); /* Padding */
-       (*compress)(state, state->buf, 1, state->buflen);
+       if (force_generic)
+               blake2s_compress_generic(state, state->buf, 1, state->buflen);
+       else
+               blake2s_compress(state, state->buf, 1, state->buflen);
        cpu_to_le32_array(state->h, ARRAY_SIZE(state->h));
        memcpy(out, state->h, state->outlen);
 }
 
 static inline int crypto_blake2s_update(struct shash_desc *desc,
                                        const u8 *in, unsigned int inlen,
-                                       blake2s_compress_t compress)
+                                       bool force_generic)
 {
        struct blake2s_state *state = shash_desc_ctx(desc);
 
-       __blake2s_update(state, in, inlen, compress);
+       __blake2s_update(state, in, inlen, force_generic);
        return 0;
 }
 
 static inline int crypto_blake2s_final(struct shash_desc *desc, u8 *out,
-                                      blake2s_compress_t compress)
+                                      bool force_generic)
 {
        struct blake2s_state *state = shash_desc_ctx(desc);
 
-       __blake2s_final(state, out, compress);
+       __blake2s_final(state, out, force_generic);
        return 0;
 }
 
 
 
 void blake2s_update(struct blake2s_state *state, const u8 *in, size_t inlen)
 {
-       __blake2s_update(state, in, inlen, blake2s_compress);
+       __blake2s_update(state, in, inlen, false);
 }
 EXPORT_SYMBOL(blake2s_update);
 
 void blake2s_final(struct blake2s_state *state, u8 *out)
 {
        WARN_ON(IS_ENABLED(DEBUG) && !out);
-       __blake2s_final(state, out, blake2s_compress);
+       __blake2s_final(state, out, false);
        memzero_explicit(state, sizeof(*state));
 }
 EXPORT_SYMBOL(blake2s_final);