/*
         * aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
-        *                   int rounds, int blocks, u8 iv[], bool final)
+        *                   int rounds, int blocks, u8 iv[], u8 final[])
         */
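        /*
         * Per AAPCS64, the arguments arrive in x0-x6: x0 = out, x1 = in,
         * x2 = rk, x3 = rounds, x4 = blocks, x5 = iv, x6 = final.
         * With this change, x6 is a pointer that is NULL unless the
         * request ends in a partial block.
         */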
 ENTRY(aesbs_ctr_encrypt)
        stp             x29, x30, [sp, #-16]!
        mov             x29, sp
 
-       add             x4, x4, x6              // do one extra block if final
+       cmp             x6, #0
+       cset            x10, ne
+       add             x4, x4, x10             // do one extra block if final
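        // x10 = (final != NULL): cmp/cset convert the pointer in x6 into
        // a 0/1 increment, so one extra keystream block is generated only
        // when a partial tail needs it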
 
        ldp             x7, x8, [x5]
        ld1             {v0.16b}, [x5]
        csel            x4, x4, xzr, pl
        csel            x9, x9, xzr, le
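        // x4 and x9 are set up by a subs/shift sequence elided from this
        // excerpt: x4 counts the blocks still to do, clamped at zero, and
        // x9 holds a one-hot mask marking the first block past this
        // batch, cleared when more than a full batch of eight remains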
 
        tbnz            x9, #1, 0f
        next_ctr        v1
        tbnz            x9, #2, 0f
        next_ctr        v2
        tbnz            x9, #3, 0f
        next_ctr        v3
        tbnz            x9, #4, 0f
        next_ctr        v4
        tbnz            x9, #5, 0f
        next_ctr        v5
        tbnz            x9, #6, 0f
        next_ctr        v6
        tbnz            x9, #7, 0f
        next_ctr        v7
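        // v0 already holds the counter loaded from [x5]; each tbnz stops
        // preparing further counters at the first block that is not part
        // of this batch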
 
 0:     mov             bskey, x2
        mov             rounds, x3
        bl              aesbs_encrypt8
 
-       lsr             x9, x9, x6              // disregard the extra block
+       lsr             x9, x9, x10             // disregard the extra block
        tbnz            x9, #0, 0f
 
        ld1             {v15.16b}, [x1], #16
        eor             v5.16b, v5.16b, v15.16b
        st1             {v5.16b}, [x0], #16
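        // one ld1/eor/st1 group like this exists for each of the eight
        // keystream registers (only the final v5/v15 group is shown
        // here); tbnz tests on x9 between the groups branch to labels
        // 1f-7f below once all requested blocks have been written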
 
-       next_ctr        v0
+8:     next_ctr        v0
        cbnz            x4, 99b
 
 0:     st1             {v0.16b}, [x5]
-8:     ldp             x29, x30, [sp], #16
+       ldp             x29, x30, [sp], #16
        ret
 
        /*
-        * If we are handling the tail of the input (x6 == 1), return the
-        * final keystream block back to the caller via the IV buffer.
+        * If we are handling the tail of the input (x6 != NULL), return the
+        * final keystream block back to the caller.
         */
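        /*
         * Labels 1-7 are reached when the input ran out at block N; the
         * unused keystream block is handed back through x6 rather than
         * through the IV buffer at x5, which now keeps the next counter
         * value (stored at label 0 above).
         */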
 1:     cbz             x6, 8b
-       st1             {v1.16b}, [x5]
+       st1             {v1.16b}, [x6]
        b               8b
 2:     cbz             x6, 8b
-       st1             {v4.16b}, [x5]
+       st1             {v4.16b}, [x6]
        b               8b
 3:     cbz             x6, 8b
-       st1             {v6.16b}, [x5]
+       st1             {v6.16b}, [x6]
        b               8b
 4:     cbz             x6, 8b
-       st1             {v3.16b}, [x5]
+       st1             {v3.16b}, [x6]
        b               8b
 5:     cbz             x6, 8b
-       st1             {v7.16b}, [x5]
+       st1             {v7.16b}, [x6]
        b               8b
 6:     cbz             x6, 8b
-       st1             {v2.16b}, [x5]
+       st1             {v2.16b}, [x6]
        b               8b
 7:     cbz             x6, 8b
-       st1             {v5.16b}, [x5]
+       st1             {v5.16b}, [x6]
        b               8b
 ENDPROC(aesbs_ctr_encrypt)
 
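 /*
  * C glue side: the final[] argument is either NULL or a 16-byte stack
  * buffer that receives the last keystream block of the request.
  */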
                                  int rounds, int blocks, u8 iv[]);
 
 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
-                                 int rounds, int blocks, u8 iv[], bool final);
+                                 int rounds, int blocks, u8 iv[], u8 final[]);
 
 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
+       u8 buf[AES_BLOCK_SIZE];
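        /* holds the final keystream block for a partial tail */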
        int err;
 
        err = skcipher_walk_virt(&walk, req, true);
        kernel_neon_begin();
        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
-               bool final = (walk.total % AES_BLOCK_SIZE) != 0;
+               u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
 
                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
-                       final = false;
+                       final = NULL;
                }
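                /*
                 * Mid-walk chunks are always block-aligned; the tail is
                 * only reached on the final pass, so final must stay NULL
                 * until then.
                 */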
 
                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        if (dst != src)
                                memcpy(dst, src, walk.total % AES_BLOCK_SIZE);
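                        /*
                         * XOR the tail with the keystream block returned
                         * in buf (via final), instead of reusing walk.iv,
                         * which must carry the next IV.
                         */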
-                       crypto_xor(dst, walk.iv, walk.total % AES_BLOCK_SIZE);
+                       crypto_xor(dst, final, walk.total % AES_BLOCK_SIZE);
 
                        err = skcipher_walk_done(&walk, 0);
                        break;