/* defined in aes-modes.S */
 asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
-                               int rounds, int blocks, int first);
+                               int rounds, int blocks);
 asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
-                               int rounds, int blocks, int first);
+                               int rounds, int blocks);
 
 asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
-                               int rounds, int blocks, u8 iv[], int first);
+                               int rounds, int blocks, u8 iv[]);
 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
-                               int rounds, int blocks, u8 iv[], int first);
+                               int rounds, int blocks, u8 iv[]);
 
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
-                               int rounds, int blocks, u8 ctr[], int first);
+                               int rounds, int blocks, u8 ctr[]);
 
 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
                                int rounds, int blocks, u8 const rk2[], u8 iv[],
                                int first);
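+
+/*
+ * Note that the ECB/CBC/CTR helpers no longer take a 'first' argument,
+ * while the XTS helpers keep it: the first call of a walk still has to
+ * turn the IV into the initial tweak using rk2, and the updated tweak
+ * is passed back to the caller through iv[].
+ */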

 static int ecb_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       int err, first, rounds = 6 + ctx->key_length / 4;
+       int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
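+       /*
+        * Kernel-mode NEON is now enabled and disabled around each burst
+        * of blocks rather than around the whole walk, so the walk above
+        * no longer needs to be atomic and preemption is not held off
+        * across skcipher_walk_done().
+        */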
+       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+               kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_enc, rounds, blocks, first);
+                               (u8 *)ctx->key_enc, rounds, blocks);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
        return err;
 }
 
 static int ecb_decrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       int err, first, rounds = 6 + ctx->key_length / 4;
+       int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+               kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_dec, rounds, blocks, first);
+                               (u8 *)ctx->key_dec, rounds, blocks);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
        return err;
 }
 
 static int cbc_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       int err, first, rounds = 6 + ctx->key_length / 4;
+       int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+               kernel_neon_begin();
                aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_enc, rounds, blocks, walk.iv,
-                               first);
+                               (u8 *)ctx->key_enc, rounds, blocks, walk.iv);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
        return err;
 }
 
 static int cbc_decrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       int err, first, rounds = 6 + ctx->key_length / 4;
+       int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+       while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+               kernel_neon_begin();
                aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_dec, rounds, blocks, walk.iv,
-                               first);
+                               (u8 *)ctx->key_dec, rounds, blocks, walk.iv);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
        return err;
 }
 
 static int ctr_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
-       int err, first, rounds = 6 + ctx->key_length / 4;
+       int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        int blocks;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       first = 1;
-       kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+               kernel_neon_begin();
                aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_enc, rounds, blocks, walk.iv,
-                               first);
+                               (u8 *)ctx->key_enc, rounds, blocks, walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
-               first = 0;
+               kernel_neon_end();
        }
        if (walk.nbytes) {
                u8 __aligned(8) tail[AES_BLOCK_SIZE];
                unsigned int nbytes = walk.nbytes;
                u8 *tdst = walk.dst.virt.addr;
                u8 *tsrc = walk.src.virt.addr;

                /* tell aes_ctr_encrypt() to process a tail block */
                blocks = -1;
 
+               kernel_neon_begin();
                aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc, rounds,
-                               blocks, walk.iv, first);
+                               blocks, walk.iv);
+               kernel_neon_end();
                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
                err = skcipher_walk_done(&walk, 0);
        }
-       kernel_neon_end();
 
        return err;
 }
        struct skcipher_walk walk;
        unsigned int blocks;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+               kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                (u8 *)ctx->key1.key_enc, rounds, blocks,
                                (u8 *)ctx->key2.key_enc, walk.iv, first);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
 
        return err;
 }
        struct skcipher_walk walk;
        unsigned int blocks;
 
-       err = skcipher_walk_virt(&walk, req, true);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       kernel_neon_begin();
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+               kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                (u8 *)ctx->key1.key_dec, rounds, blocks,
                                (u8 *)ctx->key2.key_enc, walk.iv, first);
+               kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
-       kernel_neon_end();
 
        return err;
 }
 
        /* encrypt the zero vector */
        kernel_neon_begin();
-       aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, rk, rounds, 1, 1);
+       aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, rk, rounds, 1);
        kernel_neon_end();
 
        cmac_gf128_mul_by_x(consts, consts);
                return err;
 
        kernel_neon_begin();
-       aes_ecb_encrypt(key, ks[0], rk, rounds, 1, 1);
-       aes_ecb_encrypt(ctx->consts, ks[1], rk, rounds, 2, 0);
+       aes_ecb_encrypt(key, ks[0], rk, rounds, 1);
+       aes_ecb_encrypt(ctx->consts, ks[1], rk, rounds, 2);
        kernel_neon_end();
 
        return cbcmac_setkey(tfm, key, sizeof(key));
 
 #if INTERLEAVE == 2
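+
+       /*
+        * The block routines now use x8 as their scratch register: x6
+        * stays live across them, holding the byte-swapped counter in
+        * the CTR code and the tweak buffer pointer in the XTS code.
+        */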
 
 aes_encrypt_block2x:
-       encrypt_block2x v0, v1, w3, x2, x6, w7
+       encrypt_block2x v0, v1, w3, x2, x8, w7
        ret
 ENDPROC(aes_encrypt_block2x)
 
 aes_decrypt_block2x:
-       decrypt_block2x v0, v1, w3, x2, x6, w7
+       decrypt_block2x v0, v1, w3, x2, x8, w7
        ret
 ENDPROC(aes_decrypt_block2x)
 
 #elif INTERLEAVE == 4
 
 aes_encrypt_block4x:
-       encrypt_block4x v0, v1, v2, v3, w3, x2, x6, w7
+       encrypt_block4x v0, v1, v2, v3, w3, x2, x8, w7
        ret
 ENDPROC(aes_encrypt_block4x)
 
 aes_decrypt_block4x:
-       decrypt_block4x v0, v1, v2, v3, w3, x2, x6, w7
+       decrypt_block4x v0, v1, v2, v3, w3, x2, x8, w7
        ret
 ENDPROC(aes_decrypt_block4x)
 
 #define FRAME_POP
 
        .macro          do_encrypt_block2x
-       encrypt_block2x v0, v1, w3, x2, x6, w7
+       encrypt_block2x v0, v1, w3, x2, x8, w7
        .endm
 
        .macro          do_decrypt_block2x
-       decrypt_block2x v0, v1, w3, x2, x6, w7
+       decrypt_block2x v0, v1, w3, x2, x8, w7
        .endm
 
        .macro          do_encrypt_block4x
-       encrypt_block4x v0, v1, v2, v3, w3, x2, x6, w7
+       encrypt_block4x v0, v1, v2, v3, w3, x2, x8, w7
        .endm
 
        .macro          do_decrypt_block4x
-       decrypt_block4x v0, v1, v2, v3, w3, x2, x6, w7
+       decrypt_block4x v0, v1, v2, v3, w3, x2, x8, w7
        .endm
 
 #endif
 
        /*
         * aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-        *                 int blocks, int first)
+        *                 int blocks)
         * aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-        *                 int blocks, int first)
+        *                 int blocks)
         */
 
 AES_ENTRY(aes_ecb_encrypt)
        FRAME_PUSH
-       cbz             w5, .LecbencloopNx
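+       /*
+        * The caller may have given up the NEON unit since the previous
+        * call, so the key schedule is now prepared on every invocation
+        * rather than only on the first call of a walk.
+        */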
 
        enc_prepare     w3, x2, x5
 
 
 AES_ENTRY(aes_ecb_decrypt)
        FRAME_PUSH
-       cbz             w5, .LecbdecloopNx
 
        dec_prepare     w3, x2, x5
 
 
        /*
         * aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-        *                 int blocks, u8 iv[], int first)
+        *                 int blocks, u8 iv[])
         * aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-        *                 int blocks, u8 iv[], int first)
+        *                 int blocks, u8 iv[])
         */
 
 AES_ENTRY(aes_cbc_encrypt)
-       cbz             w6, .Lcbcencloop
-
        ld1             {v0.16b}, [x5]                  /* get iv */
        enc_prepare     w3, x2, x6
 
 
 AES_ENTRY(aes_cbc_decrypt)
        FRAME_PUSH
-       cbz             w6, .LcbcdecloopNx
 
        ld1             {v7.16b}, [x5]                  /* get iv */
        dec_prepare     w3, x2, x6
 
        /*
         * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-        *                 int blocks, u8 ctr[], int first)
+        *                 int blocks, u8 ctr[])
         */
 
 AES_ENTRY(aes_ctr_encrypt)
        FRAME_PUSH
-       cbz             w6, .Lctrnotfirst       /* 1st time around? */
+
        enc_prepare     w3, x2, x6
        ld1             {v4.16b}, [x5]
 
-.Lctrnotfirst:
-       umov            x8, v4.d[1]             /* keep swabbed ctr in reg */
-       rev             x8, x8
+       umov            x6, v4.d[1]             /* keep swabbed ctr in reg */
+       rev             x6, x6
 #if INTERLEAVE >= 2
-       cmn             w8, w4                  /* 32 bit overflow? */
+       cmn             w6, w4                  /* 32 bit overflow? */
        bcs             .Lctrloop
 .LctrloopNx:
        subs            w4, w4, #INTERLEAVE
 #if INTERLEAVE == 2
        mov             v0.8b, v4.8b
        mov             v1.8b, v4.8b
-       rev             x7, x8
-       add             x8, x8, #1
+       rev             x7, x6
+       add             x6, x6, #1
        ins             v0.d[1], x7
-       rev             x7, x8
-       add             x8, x8, #1
+       rev             x7, x6
+       add             x6, x6, #1
        ins             v1.d[1], x7
        ld1             {v2.16b-v3.16b}, [x1], #32      /* get 2 input blocks */
        do_encrypt_block2x
        st1             {v0.16b-v1.16b}, [x0], #32
 #else
        ldr             q8, =0x30000000200000001        /* addends 1,2,3[,0] */
-       dup             v7.4s, w8
+       dup             v7.4s, w6
        mov             v0.16b, v4.16b
        add             v7.4s, v7.4s, v8.4s
        mov             v1.16b, v4.16b
        eor             v2.16b, v7.16b, v2.16b
        eor             v3.16b, v5.16b, v3.16b
        st1             {v0.16b-v3.16b}, [x0], #64
-       add             x8, x8, #INTERLEAVE
+       add             x6, x6, #INTERLEAVE
 #endif
-       rev             x7, x8
+       rev             x7, x6
        ins             v4.d[1], x7
        cbz             w4, .Lctrout
        b               .LctrloopNx
 #endif
 .Lctrloop:
        mov             v0.16b, v4.16b
-       encrypt_block   v0, w3, x2, x6, w7
+       encrypt_block   v0, w3, x2, x8, w7
 
-       adds            x8, x8, #1              /* increment BE ctr */
-       rev             x7, x8
+       adds            x6, x6, #1              /* increment BE ctr */
+       rev             x7, x6
        ins             v4.d[1], x7
        bcs             .Lctrcarry              /* overflow? */
 
 
 AES_ENTRY(aes_xts_encrypt)
        FRAME_PUSH
-       cbz             w7, .LxtsencloopNx
-
        ld1             {v4.16b}, [x6]
-       enc_prepare     w3, x5, x6
-       encrypt_block   v4, w3, x5, x6, w7              /* first tweak */
-       enc_switch_key  w3, x2, x6
+       cbz             w7, .Lxtsencnotfirst
+
+       enc_prepare     w3, x5, x8
+       encrypt_block   v4, w3, x5, x8, w7              /* first tweak */
+       enc_switch_key  w3, x2, x8
        ldr             q7, .Lxts_mul_x
        b               .LxtsencNx
 
+.Lxtsencnotfirst:
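+       /*
+        * Not the first call of the walk: the tweak saved at .Lxtsencout
+        * has already been reloaded from [x6] above, so only the key
+        * schedule in x2 needs to be set up here.
+        */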
+       enc_prepare     w3, x2, x8
 .LxtsencloopNx:
        ldr             q7, .Lxts_mul_x
        next_tweak      v4, v4, v7, v8
 .Lxtsencloop:
        ld1             {v1.16b}, [x1], #16
        eor             v0.16b, v1.16b, v4.16b
-       encrypt_block   v0, w3, x2, x6, w7
+       encrypt_block   v0, w3, x2, x8, w7
        eor             v0.16b, v0.16b, v4.16b
        st1             {v0.16b}, [x0], #16
        subs            w4, w4, #1
        next_tweak      v4, v4, v7, v8
        b               .Lxtsencloop
 .Lxtsencout:
+       st1             {v4.16b}, [x6]          /* return the next tweak */
        FRAME_POP
        ret
 AES_ENDPROC(aes_xts_encrypt)
 
 AES_ENTRY(aes_xts_decrypt)
        FRAME_PUSH
-       cbz             w7, .LxtsdecloopNx
-
        ld1             {v4.16b}, [x6]
-       enc_prepare     w3, x5, x6
-       encrypt_block   v4, w3, x5, x6, w7              /* first tweak */
-       dec_prepare     w3, x2, x6
+       cbz             w7, .Lxtsdecnotfirst
+
+       enc_prepare     w3, x5, x8
+       encrypt_block   v4, w3, x5, x8, w7              /* first tweak */
+       dec_prepare     w3, x2, x8
        ldr             q7, .Lxts_mul_x
        b               .LxtsdecNx
 
+.Lxtsdecnotfirst:
+       dec_prepare     w3, x2, x8
 .LxtsdecloopNx:
        ldr             q7, .Lxts_mul_x
        next_tweak      v4, v4, v7, v8
 .Lxtsdecloop:
        ld1             {v1.16b}, [x1], #16
        eor             v0.16b, v1.16b, v4.16b
-       decrypt_block   v0, w3, x2, x6, w7
+       decrypt_block   v0, w3, x2, x8, w7
        eor             v0.16b, v0.16b, v4.16b
        st1             {v0.16b}, [x0], #16
        subs            w4, w4, #1
        next_tweak      v4, v4, v7, v8
        b               .Lxtsdecloop
 .Lxtsdecout:
+       st1             {v4.16b}, [x6]          /* return the next tweak */
        FRAME_POP
        ret
 AES_ENDPROC(aes_xts_decrypt)
 
 
 /* borrowed from aes-neon-blk.ko */
 asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
-                                    int rounds, int blocks, int first);
+                                    int rounds, int blocks);
 asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
-                                    int rounds, int blocks, u8 iv[],
-                                    int first);
+                                    int rounds, int blocks, u8 iv[]);
 
 struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
-       int err, first = 1;
+       int err;
 
        err = skcipher_walk_virt(&walk, req, true);
 
 
                /*
                 * CBC encryption is inherently sequential, so fall back
                 * to the non-bitsliced NEON implementation for these
                 * blocks.
                 */
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                    ctx->enc, ctx->key.rounds, blocks, walk.iv,
-                                    first);
+                                    ctx->enc, ctx->key.rounds, blocks,
+                                    walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
-               first = 0;
        }
        kernel_neon_end();
        return err;
        kernel_neon_begin();
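+       /* encrypt the IV with the tweak key to produce the initial tweak */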
 
        neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey,
-                            ctx->key.rounds, 1, 1);
+                            ctx->key.rounds, 1);
 
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;