asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
 
-asmlinkage void __aes_arm_encrypt(const u32 rk[], int rounds, const u8 in[],
-                                 u8 out[]);
-
 struct aesbs_ctx {
        int     rounds;
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
 
 struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
-       u32                     enc[AES_MAX_KEYLENGTH_U32];
+       struct crypto_cipher    *enc_tfm;
 };
 
 struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
-       u32                     twkey[AES_MAX_KEYLENGTH_U32];
+       struct crypto_cipher    *tweak_tfm;
 };
 
 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 
        ctx->key.rounds = 6 + key_len / 4;
 
-       memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));
-
        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();
 
-       return 0;
+       return crypto_cipher_setkey(ctx->enc_tfm, in_key, key_len);
 }
 
 static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
 {
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 
-       __aes_arm_encrypt(ctx->enc, ctx->key.rounds, src, dst);
+       crypto_cipher_encrypt_one(ctx->enc_tfm, dst, src);
 }
 
 static int cbc_encrypt(struct skcipher_request *req)
        return err;
 }
 
+static int cbc_init(struct crypto_tfm *tfm)
+{
+       struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
+
+       ctx->enc_tfm = crypto_alloc_cipher("aes", 0, 0);
+       if (IS_ERR(ctx->enc_tfm))
+               return PTR_ERR(ctx->enc_tfm);
+       return 0;
+}
+
+static void cbc_exit(struct crypto_tfm *tfm)
+{
+       struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
+
+       crypto_free_cipher(ctx->enc_tfm);
+}
+
 static int ctr_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
                            unsigned int key_len)
 {
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
-       struct crypto_aes_ctx rk;
        int err;
 
        err = xts_verify_key(tfm, in_key, key_len);
                return err;
 
        key_len /= 2;
-       err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
+       err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
        if (err)
                return err;
 
-       memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));
-
        return aesbs_setkey(tfm, in_key, key_len);
 }
 
+static int xts_init(struct crypto_tfm *tfm)
+{
+       struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
+
+       ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
+       if (IS_ERR(ctx->tweak_tfm))
+               return PTR_ERR(ctx->tweak_tfm);
+       return 0;
+}
+
+static void xts_exit(struct crypto_tfm *tfm)
+{
+       struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);
+
+       crypto_free_cipher(ctx->tweak_tfm);
+}
+
 static int __xts_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
 
        err = skcipher_walk_virt(&walk, req, true);
 
-       __aes_arm_encrypt(ctx->twkey, ctx->key.rounds, walk.iv, walk.iv);
+       crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);
 
        kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
+       .base.cra_init          = cbc_init,
+       .base.cra_exit          = cbc_exit,
 
        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
+       .base.cra_init          = xts_init,
+       .base.cra_exit          = xts_exit,
 
        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        return err;
 }
 
-module_init(aes_init);
+late_initcall(aes_init);
 module_exit(aes_exit);