.endm
 
        /*
-        * aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+        * aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[], int rounds,
         *                 int blocks)
-        * aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+        * aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[], int rounds,
         *                 int blocks)
         */
 ENTRY(ce_aes_ecb_encrypt)
 ENDPROC(ce_aes_ecb_decrypt)
 
        /*
-        * aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+        * aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[], int rounds,
         *                 int blocks, u8 iv[])
-        * aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+        * aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[], int rounds,
         *                 int blocks, u8 iv[])
         */
 ENTRY(ce_aes_cbc_encrypt)
 ENDPROC(ce_aes_cbc_decrypt)
 
        /*
-        * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+        * aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[], int rounds,
         *                 int blocks, u8 ctr[])
         */
 ENTRY(ce_aes_ctr_encrypt)
 ENDPROC(ce_aes_ctr_encrypt)
 
        /*
-        * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
-        *                 int blocks, u8 iv[], u8 const rk2[], int first)
-        * aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
-        *                 int blocks, u8 iv[], u8 const rk2[], int first)
+        * aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[], int rounds,
+        *                 int blocks, u8 iv[], u32 const rk2[], int first)
+        * aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[], int rounds,
+        *                 int blocks, u8 iv[], u32 const rk2[], int first)
         */
 
        .macro          next_tweak, out, in, const, tmp
 
 asmlinkage u32 ce_aes_sub(u32 input);
 asmlinkage void ce_aes_invert(void *dst, void *src);
 
-asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks);
-asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks);
 
-asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 iv[]);
-asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 iv[]);
 
-asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 ctr[]);
 
-asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
+asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                   int rounds, int blocks, u8 iv[],
-                                  u8 const rk2[], int first);
-asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
+                                  u32 const rk2[], int first);
+asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                   int rounds, int blocks, u8 iv[],
-                                  u8 const rk2[], int first);
+                                  u32 const rk2[], int first);
 
 struct aes_block {
        u8 b[AES_BLOCK_SIZE];
        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  (u8 *)ctx->key_enc, num_rounds(ctx), blocks);
+                                  ctx->key_enc, num_rounds(ctx), blocks);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  (u8 *)ctx->key_dec, num_rounds(ctx), blocks);
+                                  ctx->key_dec, num_rounds(ctx), blocks);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
+                                  ctx->key_enc, num_rounds(ctx), blocks,
                                   walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  (u8 *)ctx->key_dec, num_rounds(ctx), blocks,
+                                  ctx->key_dec, num_rounds(ctx), blocks,
                                   walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
+                                  ctx->key_enc, num_rounds(ctx), blocks,
                                   walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
                 */
                blocks = -1;
 
-               ce_aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc,
-                                  num_rounds(ctx), blocks, walk.iv);
+               ce_aes_ctr_encrypt(tail, NULL, ctx->key_enc, num_rounds(ctx),
+                                  blocks, walk.iv);
                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
                err = skcipher_walk_done(&walk, 0);
        }
        kernel_neon_begin();
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  (u8 *)ctx->key1.key_enc, rounds, blocks,
-                                  walk.iv, (u8 *)ctx->key2.key_enc, first);
+                                  ctx->key1.key_enc, rounds, blocks, walk.iv,
+                                  ctx->key2.key_enc, first);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        kernel_neon_begin();
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  (u8 *)ctx->key1.key_dec, rounds, blocks,
-                                  walk.iv, (u8 *)ctx->key2.key_enc, first);
+                                  ctx->key1.key_dec, rounds, blocks, walk.iv,
+                                  ctx->key2.key_enc, first);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();