/*
         * aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[], int rounds,
-        *                 int blocks, u8 iv[], u32 const rk2[], int first)
+        *                 int bytes, u8 iv[], u32 const rk2[], int first)
         * aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[], int rounds,
-        *                 int blocks, u8 iv[], u32 const rk2[], int first)
+        *                 int bytes, u8 iv[], u32 const rk2[], int first)
         */
 
        .macro          next_tweak, out, in, const, tmp
 .Lxtsencloop4x:
        next_tweak      q4, q4, q15, q10
 .Lxtsenc4x:
-       subs            r4, r4, #4
+       subs            r4, r4, #64             @ r4 is a byte count: 64 = four 16-byte blocks per pass
        bmi             .Lxtsenc1x
        vld1.8          {q0-q1}, [r1]!          @ get 4 pt blocks
        vld1.8          {q2-q3}, [r1]!
        vst1.8          {q2-q3}, [r0]!
        vmov            q4, q7
        teq             r4, #0
-       beq             .Lxtsencout
+       beq             .Lxtsencret             @ everything stored already: bypass the q0 store at .Lxtsencout
        b               .Lxtsencloop4x
 .Lxtsenc1x:
-       adds            r4, r4, #4
+       adds            r4, r4, #64             @ undo the subtraction above: r4 = bytes left (< 64)
        beq             .Lxtsencout
+       subs            r4, r4, #16             @ account for one block up front
+       bmi             .LxtsencctsNx           @ fewer than 16 bytes left: ciphertext stealing
 .Lxtsencloop:
        vld1.8          {q0}, [r1]!
+.Lxtsencctsout:                                @ the CTS code branches back here with its merged block in q0
        veor            q0, q0, q4
        bl              aes_encrypt
        veor            q0, q0, q4
-       vst1.8          {q0}, [r0]!
-       subs            r4, r4, #1
+       teq             r4, #0                  @ r4 == 0: q0 is the final block
        beq             .Lxtsencout
+       subs            r4, r4, #16             @ flags from this subs are consumed by the bmi below
        next_tweak      q4, q4, q15, q6
+       bmi             .Lxtsenccts             @ partial block follows; q0 is deliberately not stored yet
+       vst1.8          {q0}, [r0]!
        b               .Lxtsencloop
 .Lxtsencout:
+       vst1.8          {q0}, [r0]              @ store final block, r0 not advanced. NOTE(review): also reached via the beq after .Lxtsenc1x when r4 == 0
+.Lxtsencret:
        vst1.8          {q4}, [r5]
        pop             {r4-r6, pc}
+
+.LxtsencctsNx:                                 @ CTS entry taken when fewer than 16 bytes remain at .Lxtsenc1x
+       vmov            q0, q3                  @ q0 = q3, the last block the 4x loop stored
+       sub             r0, r0, #16             @ step r0 back onto that stored block
+.Lxtsenccts:                                   @ in: r4 = tail length - 16 (negative), q0 = output of the last full block
+       movw            ip, :lower16:.Lcts_permute_table        @ ip = address of the permute table
+       movt            ip, :upper16:.Lcts_permute_table
+
+       add             r1, r1, r4              @ rewind input pointer (r4 < 0) so the 16-byte load below ends at the last input byte
+       add             r4, r4, #16             @ # bytes in final block
+       add             lr, ip, #32
+       add             ip, ip, r4              @ ip = table + tail length
+       sub             lr, lr, r4              @ lr = table + 32 - tail length
+       add             r4, r0, r4              @ output address of final block
+
+       vld1.8          {q1}, [r1]              @ load final partial block
+       vld1.8          {q2}, [ip]              @ indices placing q0[0..tail-1] in the top lanes, 0xff elsewhere
+       vld1.8          {q3}, [lr]              @ indices placing the top `tail` bytes of q1 in the low lanes, 0xff elsewhere
+
+       vtbl.8          d4, {d0-d1}, d4         @ q2 = leading bytes of q0 moved to the top lanes (0xff index yields 0)
+       vtbl.8          d5, {d0-d1}, d5
+       vtbx.8          d0, {d2-d3}, d6         @ q0 = tail input bytes followed by the rest of old q0 (0xff index keeps the lane)
+       vtbx.8          d1, {d2-d3}, d7
+
+       vst1.8          {q2}, [r4]              @ overlapping stores
+       mov             r4, #0                  @ makes the teq after .Lxtsencctsout take the exit branch
+       b               .Lxtsencctsout          @ encrypt the merged block; .Lxtsencout then stores it at [r0]
 ENDPROC(ce_aes_xts_encrypt)
 
 
        prepare_key     r2, r3
        vmov            q4, q0
 
+       /* subtract 16 bytes if we are doing CTS */
+       tst             r4, #0xf                @ byte count not a multiple of 16?
+       subne           r4, r4, #0x10           @ then r4 goes negative one full block early, leaving that block to .Lxtsdeccts
+
        teq             r6, #0                  @ start of a block?
        bne             .Lxtsdec4x
 
 .Lxtsdecloop4x:
        next_tweak      q4, q4, q15, q10
 .Lxtsdec4x:
-       subs            r4, r4, #4
+       subs            r4, r4, #64             @ r4 is a byte count: 64 = four 16-byte blocks per pass
        bmi             .Lxtsdec1x
        vld1.8          {q0-q1}, [r1]!          @ get 4 ct blocks
        vld1.8          {q2-q3}, [r1]!
        beq             .Lxtsdecout
        b               .Lxtsdecloop4x
 .Lxtsdec1x:
-       adds            r4, r4, #4
+       adds            r4, r4, #64             @ undo the subtraction above: r4 = bytes left (< 64)
        beq             .Lxtsdecout
+       subs            r4, r4, #16             @ flags from this subs are consumed by the bmi after the load
 .Lxtsdecloop:
        vld1.8          {q0}, [r1]!
+       bmi             .Lxtsdeccts             @ r4 < 0: q0 holds the last full block ahead of a partial one
+.Lxtsdecctsout:                                @ the CTS code branches back here with its merged block in q0
        veor            q0, q0, q4
-       add             ip, r2, #32             @ 3rd round key
        bl              aes_decrypt
        veor            q0, q0, q4
        vst1.8          {q0}, [r0]!
-       subs            r4, r4, #1
+       teq             r4, #0                  @ r4 == 0: that was the final block
        beq             .Lxtsdecout
+       subs            r4, r4, #16             @ sets the flags tested by the bmi after the next load
        next_tweak      q4, q4, q15, q6
        b               .Lxtsdecloop
 .Lxtsdecout:
        vst1.8          {q4}, [r5]
        pop             {r4-r6, pc}
+
+.Lxtsdeccts:                                   @ in: r4 = tail length - 16 (negative), q0 = last full input block
+       movw            ip, :lower16:.Lcts_permute_table        @ ip = address of the permute table
+       movt            ip, :upper16:.Lcts_permute_table
+
+       add             r1, r1, r4              @ rewind input pointer (r4 < 0) so the 16-byte load below ends at the last input byte
+       add             r4, r4, #16             @ # bytes in final block
+       add             lr, ip, #32
+       add             ip, ip, r4              @ ip = table + tail length
+       sub             lr, lr, r4              @ lr = table + 32 - tail length
+       add             r4, r0, r4              @ output address of final block
+
+       next_tweak      q5, q4, q15, q6         @ q5 = next_tweak(q4); q4 is used later at .Lxtsdecctsout
+
+       vld1.8          {q1}, [r1]              @ load final partial block
+       vld1.8          {q2}, [ip]              @ indices placing q0[0..tail-1] in the top lanes, 0xff elsewhere
+       vld1.8          {q3}, [lr]              @ indices placing the top `tail` bytes of q1 in the low lanes, 0xff elsewhere
+
+       veor            q0, q0, q5              @ decrypt the last full block under tweak q5
+       bl              aes_decrypt
+       veor            q0, q0, q5
+
+       vtbl.8          d4, {d0-d1}, d4         @ q2 = leading bytes of q0 moved to the top lanes (0xff index yields 0)
+       vtbl.8          d5, {d0-d1}, d5
+       vtbx.8          d0, {d2-d3}, d6         @ q0 = tail input bytes followed by the rest of old q0 (0xff index keeps the lane)
+       vtbx.8          d1, {d2-d3}, d7
+
+       vst1.8          {q2}, [r4]              @ overlapping stores
+       mov             r4, #0                  @ makes the teq after .Lxtsdecctsout take the exit branch
+       b               .Lxtsdecctsout          @ decrypt the merged block under q4; it is then stored at [r0]
 ENDPROC(ce_aes_xts_decrypt)
 
        /*
        vst1.32         {q0}, [r0]
        bx              lr
 ENDPROC(ce_aes_invert)
+
+       .section        ".rodata", "a"
+       .align          6                       @ NOTE(review): on ARM gas the operand is a power-of-two exponent, i.e. 64-byte alignment - confirm
+.Lcts_permute_table:                           @ 48 bytes, read by the CTS code with 16-byte loads at offsets `tail` and `32 - tail`
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff         @ bytes 0-15: 0xff, out of range as a vtbl/vtbx index
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte            0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7         @ bytes 16-31: identity lane indices 0..15
+       .byte            0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff         @ bytes 32-47: 0xff again
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
 
 #include <crypto/ctr.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <linux/cpufeature.h>
 #include <linux/module.h>
 #include <crypto/xts.h>
                                   int rounds, int blocks, u8 ctr[]);
 
 asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
-                                  int rounds, int blocks, u8 iv[],
+                                  int rounds, int bytes, u8 iv[],
                                   u32 const rk2[], int first);
 asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
-                                  int rounds, int blocks, u8 iv[],
+                                  int rounds, int bytes, u8 iv[],
                                   u32 const rk2[], int first);
 
 struct aes_block {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = num_rounds(&ctx->key1);
+       int tail = req->cryptlen % AES_BLOCK_SIZE;      /* bytes beyond the last full block */
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct scatterlist *src, *dst;
        struct skcipher_walk walk;
-       unsigned int blocks;
+
+       if (req->cryptlen < AES_BLOCK_SIZE)             /* less than one block is rejected */
+               return -EINVAL;
 
        err = skcipher_walk_virt(&walk, req, false);
 
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+       if (unlikely(tail > 0 && walk.nbytes < walk.total)) {   /* partial tail and the first walk step does not cover the whole request */
+               int xts_blocks = DIV_ROUND_UP(req->cryptlen,    /* block count excluding the last full block and the partial one */
+                                             AES_BLOCK_SIZE) - 2;
+
+               skcipher_walk_abort(&walk);                     /* give up the walk started above */
+
+               skcipher_request_set_tfm(&subreq, tfm);
+               skcipher_request_set_callback(&subreq,
+                                             skcipher_request_flags(req),
+                                             NULL, NULL);
+               skcipher_request_set_crypt(&subreq, req->src, req->dst, /* subreq covers only the leading xts_blocks blocks */
+                                          xts_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+               req = &subreq;                                  /* from here on req is the shortened request */
+               err = skcipher_walk_virt(&walk, req, false);
+       } else {
+               tail = 0;                                       /* no separate tail pass; the loop below passes the full length */
+       }
+
+       for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+               int nbytes = walk.nbytes;
+
+               if (walk.nbytes < walk.total)                   /* not the final walk step: */
+                       nbytes &= ~(AES_BLOCK_SIZE - 1);        /* round down to whole blocks */
+
                kernel_neon_begin();
                ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  ctx->key1.key_enc, rounds, blocks, walk.iv,
+                                  ctx->key1.key_enc, rounds, nbytes, walk.iv,  /* length is passed in bytes */
                                   ctx->key2.key_enc, first);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);  /* hand back the bytes not processed in this step */
        }
-       return err;
+
+       if (err || likely(!tail))                       /* error, or nothing left for a tail pass */
+               return err;
+
+       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen); /* advance past the subreq length handled above */
+       if (req->dst != req->src)
+               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,       /* remaining data: one full block plus the tail */
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+
+       kernel_neon_begin();
+       ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                          ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
+                          ctx->key2.key_enc, first);   /* first is still 1 here if the loop above never ran */
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
 }
 
 static int xts_decrypt(struct skcipher_request *req)
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = num_rounds(&ctx->key1);
+       int tail = req->cryptlen % AES_BLOCK_SIZE;      /* bytes beyond the last full block */
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct scatterlist *src, *dst;
        struct skcipher_walk walk;
-       unsigned int blocks;
+
+       if (req->cryptlen < AES_BLOCK_SIZE)             /* less than one block is rejected */
+               return -EINVAL;
 
        err = skcipher_walk_virt(&walk, req, false);
 
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+       if (unlikely(tail > 0 && walk.nbytes < walk.total)) {   /* partial tail and the first walk step does not cover the whole request */
+               int xts_blocks = DIV_ROUND_UP(req->cryptlen,    /* block count excluding the last full block and the partial one */
+                                             AES_BLOCK_SIZE) - 2;
+
+               skcipher_walk_abort(&walk);                     /* give up the walk started above */
+
+               skcipher_request_set_tfm(&subreq, tfm);
+               skcipher_request_set_callback(&subreq,
+                                             skcipher_request_flags(req),
+                                             NULL, NULL);
+               skcipher_request_set_crypt(&subreq, req->src, req->dst, /* subreq covers only the leading xts_blocks blocks */
+                                          xts_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+               req = &subreq;                                  /* from here on req is the shortened request */
+               err = skcipher_walk_virt(&walk, req, false);
+       } else {
+               tail = 0;                                       /* no separate tail pass; the loop below passes the full length */
+       }
+
+       for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+               int nbytes = walk.nbytes;
+
+               if (walk.nbytes < walk.total)                   /* not the final walk step: */
+                       nbytes &= ~(AES_BLOCK_SIZE - 1);        /* round down to whole blocks */
+
                kernel_neon_begin();
                ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  ctx->key1.key_dec, rounds, blocks, walk.iv,
+                                  ctx->key1.key_dec, rounds, nbytes, walk.iv,  /* length is passed in bytes */
                                   ctx->key2.key_enc, first);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);  /* hand back the bytes not processed in this step */
        }
-       return err;
+
+       if (err || likely(!tail))                       /* error, or nothing left for a tail pass */
+               return err;
+
+       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen); /* advance past the subreq length handled above */
+       if (req->dst != req->src)
+               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,       /* remaining data: one full block plus the tail */
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+
+       kernel_neon_begin();
+       ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                          ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
+                          ctx->key2.key_enc, first);   /* first is still 1 here if the loop above never ran */
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
 }
 
 static struct skcipher_alg aes_algs[] = { {
        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
+       .walksize               = 2 * AES_BLOCK_SIZE,
        .setkey                 = xts_set_key,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,