struct skcipher_walk walk;
        int err;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
 
+               kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
 
        return err;
 }
        struct skcipher_walk walk;
        int err;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 
                /* fall back to the non-bitsliced NEON implementation */
+               kernel_neon_begin();
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                     ctx->enc, ctx->key.rounds, blocks,
                                     walk.iv);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
        return err;
 }
 
        struct skcipher_walk walk;
        int err;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
 
+               kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
 
        return err;
 }
        u8 buf[AES_BLOCK_SIZE];
        int err;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
                        final = NULL;
                }
 
+               kernel_neon_begin();
                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);
+               kernel_neon_end();
 
                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
-
        return err;
 }
 
        struct skcipher_walk walk;
        int err;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
        kernel_neon_begin();
-
-       neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey,
-                            ctx->key.rounds, 1);
+       neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
+       kernel_neon_end();
 
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
 
+               kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
                   ctx->key.rounds, blocks, walk.iv);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
-
        return err;
 }