 asmlinkage void chacha_4block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
                                        unsigned int len, int nrounds);
 asmlinkage void hchacha_block_ssse3(const u32 *state, u32 *out, int nrounds);
-#ifdef CONFIG_AS_AVX2
+
 asmlinkage void chacha_2block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
                                       unsigned int len, int nrounds);
 asmlinkage void chacha_4block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
                                       unsigned int len, int nrounds);
 asmlinkage void chacha_8block_xor_avx2(u32 *state, u8 *dst, const u8 *src,
                                       unsigned int len, int nrounds);
-static bool chacha_use_avx2;
-#ifdef CONFIG_AS_AVX512
+
 asmlinkage void chacha_2block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
                                           unsigned int len, int nrounds);
 asmlinkage void chacha_4block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
                                           unsigned int len, int nrounds);
 asmlinkage void chacha_8block_xor_avx512vl(u32 *state, u8 *dst, const u8 *src,
                                           unsigned int len, int nrounds);
-static bool chacha_use_avx512vl;
-#endif
-#endif
+
+static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_simd);
+static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx2);
+static __ro_after_init DEFINE_STATIC_KEY_FALSE(chacha_use_avx512vl);
 
 static unsigned int chacha_advance(unsigned int len, unsigned int maxblocks)
 {
        len = min(len, maxblocks * CHACHA_BLOCK_SIZE);
        return round_up(len, CHACHA_BLOCK_SIZE) / CHACHA_BLOCK_SIZE;
 }
 
 static void chacha_dosimd(u32 *state, u8 *dst, const u8 *src,
                          unsigned int bytes, int nrounds)
 {
-#ifdef CONFIG_AS_AVX2
-#ifdef CONFIG_AS_AVX512
-       if (chacha_use_avx512vl) {
+       if (IS_ENABLED(CONFIG_AS_AVX512) &&
+           static_branch_likely(&chacha_use_avx512vl)) {
                while (bytes >= CHACHA_BLOCK_SIZE * 8) {
                        chacha_8block_xor_avx512vl(state, dst, src, bytes,
                                                   nrounds);
                        bytes -= CHACHA_BLOCK_SIZE * 8;
                        dst += CHACHA_BLOCK_SIZE * 8;
                        src += CHACHA_BLOCK_SIZE * 8;
                        state[12] += 8;
                }
        }
-#endif
-       if (chacha_use_avx2) {
+
+       if (IS_ENABLED(CONFIG_AS_AVX2) &&
+           static_branch_likely(&chacha_use_avx2)) {
                while (bytes >= CHACHA_BLOCK_SIZE * 8) {
                        chacha_8block_xor_avx2(state, dst, src, bytes, nrounds);
                        bytes -= CHACHA_BLOCK_SIZE * 8;
                        dst += CHACHA_BLOCK_SIZE * 8;
                        src += CHACHA_BLOCK_SIZE * 8;
                        state[12] += 8;
                }
        }
-#endif
+
        while (bytes >= CHACHA_BLOCK_SIZE * 4) {
                chacha_4block_xor_ssse3(state, dst, src, bytes, nrounds);
                bytes -= CHACHA_BLOCK_SIZE * 4;
                dst += CHACHA_BLOCK_SIZE * 4;
                src += CHACHA_BLOCK_SIZE * 4;
                state[12] += 4;
        }
 }
 
+void hchacha_block_arch(const u32 *state, u32 *stream, int nrounds)
+{
+       state = PTR_ALIGN(state, CHACHA_STATE_ALIGN);
+
+       if (!static_branch_likely(&chacha_use_simd) || !crypto_simd_usable()) {
+               hchacha_block_generic(state, stream, nrounds);
+       } else {
+               kernel_fpu_begin();
+               hchacha_block_ssse3(state, stream, nrounds);
+               kernel_fpu_end();
+       }
+}
+EXPORT_SYMBOL(hchacha_block_arch);
+
+void chacha_init_arch(u32 *state, const u32 *key, const u8 *iv)
+{
+       state = PTR_ALIGN(state, CHACHA_STATE_ALIGN);
+
+       chacha_init_generic(state, key, iv);
+}
+EXPORT_SYMBOL(chacha_init_arch);
+
+void chacha_crypt_arch(u32 *state, u8 *dst, const u8 *src, unsigned int bytes,
+                      int nrounds)
+{
+       state = PTR_ALIGN(state, CHACHA_STATE_ALIGN);
+
+       if (!static_branch_likely(&chacha_use_simd) || !crypto_simd_usable() ||
+           bytes <= CHACHA_BLOCK_SIZE)
+               return chacha_crypt_generic(state, dst, src, bytes, nrounds);
+
+       kernel_fpu_begin();
+       chacha_dosimd(state, dst, src, bytes, nrounds);
+       kernel_fpu_end();
+}
+EXPORT_SYMBOL(chacha_crypt_arch);
+
 static int chacha_simd_stream_xor(struct skcipher_request *req,
                                  const struct chacha_ctx *ctx, const u8 *iv)
 {
                if (nbytes < walk.total)
                        nbytes = round_down(nbytes, walk.stride);
 
-               if (!crypto_simd_usable()) {
+               if (!static_branch_likely(&chacha_use_simd) ||
+                   !crypto_simd_usable()) {
                        chacha_crypt_generic(state, walk.dst.virt.addr,
                                             walk.src.virt.addr, nbytes,
                                             ctx->nrounds);
 static int __init chacha_simd_mod_init(void)
 {
        if (!boot_cpu_has(X86_FEATURE_SSSE3))
-               return -ENODEV;
-
-#ifdef CONFIG_AS_AVX2
-       chacha_use_avx2 = boot_cpu_has(X86_FEATURE_AVX) &&
-                         boot_cpu_has(X86_FEATURE_AVX2) &&
-                         cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL);
-#ifdef CONFIG_AS_AVX512
-       chacha_use_avx512vl = chacha_use_avx2 &&
-                             boot_cpu_has(X86_FEATURE_AVX512VL) &&
-                             boot_cpu_has(X86_FEATURE_AVX512BW); /* kmovq */
-#endif
-#endif
+               return 0;
+
+       static_branch_enable(&chacha_use_simd);
+
+       if (IS_ENABLED(CONFIG_AS_AVX2) &&
+           boot_cpu_has(X86_FEATURE_AVX) &&
+           boot_cpu_has(X86_FEATURE_AVX2) &&
+           cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM, NULL)) {
+               static_branch_enable(&chacha_use_avx2);
+
+               if (IS_ENABLED(CONFIG_AS_AVX512) &&
+                   boot_cpu_has(X86_FEATURE_AVX512VL) &&
+                   boot_cpu_has(X86_FEATURE_AVX512BW)) /* kmovq */
+                       static_branch_enable(&chacha_use_avx512vl);
+       }
        return crypto_register_skciphers(algs, ARRAY_SIZE(algs));
 }
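
A minimal caller sketch (illustrative only, not part of the patch) showing how the
interface exported above would be driven. It assumes <crypto/chacha.h> provides
CHACHA_STATE_WORDS and CHACHA_IV_SIZE; the helper name and the explicit 16-byte
stack alignment are assumptions made for the example.

#include <crypto/chacha.h>

/* Hypothetical helper: XOR a buffer with a 20-round ChaCha keystream. */
static void chacha20_xor_example(const u32 *key, const u8 iv[CHACHA_IV_SIZE],
                                 u8 *dst, const u8 *src, unsigned int len)
{
        /* 16-byte alignment keeps the PTR_ALIGN() in the arch code a no-op. */
        u32 state[CHACHA_STATE_WORDS] __aligned(16);

        chacha_init_arch(state, key, iv);               /* generic init for now */
        chacha_crypt_arch(state, dst, src, len, 20);    /* SIMD when usable */
}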