#include <crypto/ctr.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <crypto/xts.h>
 #include <linux/module.h>
 
                                     int rounds, int blocks);
 asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);
+asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
+                                    u32 const rk1[], int rounds, int bytes,
+                                    u32 const rk2[], u8 iv[], int first); /* bytes: byte count; callers pass non-block-multiple lengths */
+asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
+                                    u32 const rk1[], int rounds, int bytes,
+                                    u32 const rk2[], u8 iv[], int first); /* first: __xts_crypt() passes 1 or 2 */
 
 struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
 struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
+       struct crypto_aes_ctx   cts;    /* first key half, expanded for neon_aes_xts_{en,de}crypt() */
 };
 
 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                return err;
 
        key_len /= 2;
+       err = aes_expandkey(&ctx->cts, in_key, key_len);
+       if (err)
+               return err;
+
        err = aes_expandkey(&rk, in_key + key_len, key_len);
        if (err)
                return err;
        return ctr_encrypt(req);
 }
 
-static int __xts_crypt(struct skcipher_request *req,
+/*
+ * Shared XTS encrypt/decrypt path with ciphertext stealing (CTS) support.
+ *
+ * @fn is only used for runs of more than six blocks; @encrypt selects
+ * neon_aes_xts_encrypt() or neon_aes_xts_decrypt() for the bytes of the
+ * final walk step that @fn did not process.
+ *
+ * If the bytes left over after whole groups of eight blocks amount to less
+ * than one block, the first walk is limited (via a subrequest) to
+ * DIV_ROUND_UP(cryptlen, AES_BLOCK_SIZE) - 2 blocks, and the last full
+ * block plus the partial tail are processed in a second walk.
+ *
+ * Returns -EINVAL for requests shorter than one AES block, otherwise the
+ * status reported by the skcipher walk.
+ */
+static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                        void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 iv[]))
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+       int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct scatterlist *src, *dst;
        struct skcipher_walk walk;
-       int err;
+       int nbytes, err;
+       int first = 1;
+       u8 *out, *in;
+
+       if (req->cryptlen < AES_BLOCK_SIZE)
+               return -EINVAL;
+
+       /* ensure that the cts tail is covered by a single step */
+       if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
+               int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+                                             AES_BLOCK_SIZE) - 2;
+
+               skcipher_request_set_tfm(&subreq, tfm);
+               skcipher_request_set_callback(&subreq,
+                                             skcipher_request_flags(req),
+                                             NULL, NULL);
+               skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                          xts_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+               req = &subreq;
+       } else {
+               tail = 0;
+       }
 
        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
 
-       kernel_neon_begin();
-       neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
-       kernel_neon_end();
-
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 
-               if (walk.nbytes < walk.total)
+               if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
 
+               out = walk.dst.virt.addr;
+               in = walk.src.virt.addr;
+               nbytes = walk.nbytes;
+
                kernel_neon_begin();
-               fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
-                  ctx->key.rounds, blocks, walk.iv);
+               if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
+                       /* encrypt the IV with twkey once, before fn's first use */
+                       if (first)
+                               neon_aes_ecb_encrypt(walk.iv, walk.iv,
+                                                    ctx->twkey,
+                                                    ctx->key.rounds, 1);
+                       first = 0;
+
+                       fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
+                          walk.iv);
+
+                       out += blocks * AES_BLOCK_SIZE;
+                       in += blocks * AES_BLOCK_SIZE;
+                       nbytes -= blocks * AES_BLOCK_SIZE;
+               }
+
+               /* last step with bytes left: finish them in the tail code */
+               if (walk.nbytes == walk.total && nbytes > 0)
+                       goto xts_tail;
+
                kernel_neon_end();
-               err = skcipher_walk_done(&walk,
-                                        walk.nbytes - blocks * AES_BLOCK_SIZE);
+               /* keep the walk status: the check below returns it */
+               err = skcipher_walk_done(&walk, nbytes);
        }
-       return err;
+
+       if (err || likely(!tail))
+               return err;
+
+       /* handle ciphertext stealing */
+       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+       if (req->dst != req->src)
+               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
+
+       out = walk.dst.virt.addr;
+       in = walk.src.virt.addr;
+       nbytes = walk.nbytes;
+
+       kernel_neon_begin();
+xts_tail:
+       /* last argument: 1 if the IV was never encrypted above, else 2 */
+       if (encrypt)
+               neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
+                                    nbytes, ctx->twkey, walk.iv, first ?: 2);
+       else
+               neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
+                                    nbytes, ctx->twkey, walk.iv, first ?: 2);
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
 }
 
 static int xts_encrypt(struct skcipher_request *req)
 {
-       return __xts_crypt(req, aesbs_xts_encrypt);
+       return __xts_crypt(req, true, aesbs_xts_encrypt); /* true: tail uses neon_aes_xts_encrypt() */
 }
 
 static int xts_decrypt(struct skcipher_request *req)
 {
-       return __xts_crypt(req, aesbs_xts_decrypt);
+       return __xts_crypt(req, false, aesbs_xts_decrypt); /* false: tail uses neon_aes_xts_decrypt() */
 }
 
 static struct skcipher_alg aes_algs[] = { {