                __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
                put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
-               while (walk.nbytes >= AES_BLOCK_SIZE) {
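+               /* anything shorter than two full blocks is handled below */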
+               while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) {
                        int blocks = walk.nbytes / AES_BLOCK_SIZE;
                        u8 *dst = walk.dst.virt.addr;
                        u8 *src = walk.src.virt.addr;
                                        NULL);
 
                        err = skcipher_walk_done(&walk,
-                                                walk.nbytes % AES_BLOCK_SIZE);
+                                                walk.nbytes % (2 * AES_BLOCK_SIZE));
                }
-               if (walk.nbytes)
+               if (walk.nbytes) {
                        __aes_arm64_encrypt(ctx->aes_key.key_enc, ks, iv,
                                            nrounds);
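+                       /*
+                        * The remaining data may span two blocks, so
+                        * derive a second keystream block from the next
+                        * counter value for the tail code below.
+                        */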
+                       if (walk.nbytes > AES_BLOCK_SIZE) {
+                               crypto_inc(iv, AES_BLOCK_SIZE);
+                               __aes_arm64_encrypt(ctx->aes_key.key_enc,
+                                                   ks + AES_BLOCK_SIZE, iv,
+                                                   nrounds);
+                       }
+               }
        }
 
        /* handle the tail */
                __aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
                put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
-               while (walk.nbytes >= AES_BLOCK_SIZE) {
+               while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) {
                        int blocks = walk.nbytes / AES_BLOCK_SIZE;
                        u8 *dst = walk.dst.virt.addr;
                        u8 *src = walk.src.virt.addr;
                        } while (--blocks > 0);
 
                        err = skcipher_walk_done(&walk,
-                                                walk.nbytes % AES_BLOCK_SIZE);
+                                                walk.nbytes % (2 * AES_BLOCK_SIZE));
                }
-               if (walk.nbytes)
+               if (walk.nbytes) {
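+                       /*
+                        * Up to two blocks may remain: encrypt the counter
+                        * values in place so that iv (and iv2, if needed)
+                        * hold the keystream for the tail code below.
+                        */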
+                       if (walk.nbytes > AES_BLOCK_SIZE) {
+                               u8 *iv2 = iv + AES_BLOCK_SIZE;
+
+                               memcpy(iv2, iv, AES_BLOCK_SIZE);
+                               crypto_inc(iv2, AES_BLOCK_SIZE);
+
+                               __aes_arm64_encrypt(ctx->aes_key.key_enc, iv2,
+                                                   iv2, nrounds);
+                       }
                        __aes_arm64_encrypt(ctx->aes_key.key_enc, iv, iv,
                                            nrounds);
+               }
        }
 
        /* handle the tail */