T1              .req    v2
        T2              .req    v3
        MASK            .req    v4
-       XL              .req    v5
-       XM              .req    v6
+       XM              .req    v5
+       XL              .req    v6
        XH              .req    v7
        IN1             .req    v7
 
        __pmull_ghash   p8
 ENDPROC(pmull_ghash_update_p8)
 
-       KS0             .req    v12
-       KS1             .req    v13
-       INP0            .req    v14
-       INP1            .req    v15
-
-       .macro          load_round_keys, rounds, rk
-       cmp             \rounds, #12
-       blo             2222f           /* 128 bits */
-       beq             1111f           /* 192 bits */
-       ld1             {v17.4s-v18.4s}, [\rk], #32
-1111:  ld1             {v19.4s-v20.4s}, [\rk], #32
-2222:  ld1             {v21.4s-v24.4s}, [\rk], #64
-       ld1             {v25.4s-v28.4s}, [\rk], #64
-       ld1             {v29.4s-v31.4s}, [\rk]
+       KS0             .req    v8
+       KS1             .req    v9
+       KS2             .req    v10
+       KS3             .req    v11
+
+       INP0            .req    v21
+       INP1            .req    v22
+       INP2            .req    v23
+       INP3            .req    v24
+
+       K0              .req    v25
+       K1              .req    v26
+       K2              .req    v27
+       K3              .req    v28
+       K4              .req    v12
+       K5              .req    v13
+       K6              .req    v4
+       K7              .req    v5
+       K8              .req    v14
+       K9              .req    v15
+       KK              .req    v29
+       KL              .req    v30
+       KM              .req    v31
+
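+       /*
+        * Load the first six round keys into K0-K5 and the last three round
+        * keys into KK/KL/KM. The latter are taken from the end of the key
+        * schedule (\rk + \rounds * 16 - 32), so the same macro works for
+        * 128, 192 and 256 bit keys. The remaining keys (K6-K9) are loaded
+        * on demand by enc_block and pmull_gcm_enc_4x.
+        */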
+       .macro          load_round_keys, rounds, rk, tmp
+       add             \tmp, \rk, #64
+       ld1             {K0.4s-K3.4s}, [\rk]
+       ld1             {K4.4s-K5.4s}, [\tmp]
+       add             \tmp, \rk, \rounds, lsl #4
+       sub             \tmp, \tmp, #32
+       ld1             {KK.4s-KM.4s}, [\tmp]
        .endm
 
        .macro          enc_round, state, key
        aese            \state\().16b, \key\().16b
        aesmc           \state\().16b, \state\().16b
        .endm
 
-       .macro          enc_block, state, rounds
-       cmp             \rounds, #12
-       b.lo            2222f           /* 128 bits */
-       b.eq            1111f           /* 192 bits */
-       enc_round       \state, v17
-       enc_round       \state, v18
-1111:  enc_round       \state, v19
-       enc_round       \state, v20
-2222:  .irp            key, v21, v22, v23, v24, v25, v26, v27, v28, v29
+       .macro          enc_qround, s0, s1, s2, s3, key
+       enc_round       \s0, \key
+       enc_round       \s1, \key
+       enc_round       \s2, \key
+       enc_round       \s3, \key
+       .endm
+
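+       /*
+        * Encrypt a single block (used for the final tag). \rounds is 10, 12
+        * or 14, so bit 2 distinguishes AES-128 from AES-192/256 and bit 1
+        * distinguishes AES-192 from AES-256; the AES-192/256 tail lives in
+        * a subsection to keep the AES-128 path contiguous.
+        */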
+       .macro          enc_block, state, rounds, rk, tmp
+       add             \tmp, \rk, #96
+       ld1             {K6.4s-K7.4s}, [\tmp], #32
+       .irp            key, K0, K1, K2, K3, K4, K5
        enc_round       \state, \key
        .endr
-       aese            \state\().16b, v30.16b
-       eor             \state\().16b, \state\().16b, v31.16b
+
+       tbnz            \rounds, #2, .Lnot128_\@
+.Lout256_\@:
+       enc_round       \state, K6
+       enc_round       \state, K7
+
+.Lout192_\@:
+       enc_round       \state, KK
+       aese            \state\().16b, KL.16b
+       eor             \state\().16b, \state\().16b, KM.16b
+
+       .subsection     1
+.Lnot128_\@:
+       ld1             {K8.4s-K9.4s}, [\tmp], #32
+       enc_round       \state, K6
+       enc_round       \state, K7
+       ld1             {K6.4s-K7.4s}, [\tmp]
+       enc_round       \state, K8
+       enc_round       \state, K9
+       tbz             \rounds, #1, .Lout192_\@
+       b               .Lout256_\@
+       .previous
        .endm
 
+       .align          6
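+       /*
+        * Arguments (matching the C prototypes below): x0 = byte count,
+        * x1 = dst, x2 = src, x3 = ghash_key, x4 = dg[], x5 = counter block,
+        * x6 = round keys, x7 = rounds. The tag pointer is the ninth argument
+        * and is picked up from the caller's stack.
+        */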
        .macro          pmull_gcm_do_crypt, enc
-       ld1             {SHASH.2d}, [x4], #16
-       ld1             {HH.2d}, [x4]
-       ld1             {XL.2d}, [x1]
-       ldr             x8, [x5, #8]                    // load lower counter
+       stp             x29, x30, [sp, #-32]!
+       mov             x29, sp
+       str             x19, [sp, #24]
+
+       load_round_keys x7, x6, x8
+
+       ld1             {SHASH.2d}, [x3], #16
+       ld1             {HH.2d-HH4.2d}, [x3]
 
-       movi            MASK.16b, #0xe1
        trn1            SHASH2.2d, SHASH.2d, HH.2d
        trn2            T1.2d, SHASH.2d, HH.2d
-CPU_LE(        rev             x8, x8          )
-       shl             MASK.2d, MASK.2d, #57
        eor             SHASH2.16b, SHASH2.16b, T1.16b
 
-       .if             \enc == 1
-       ldr             x10, [sp]
-       ld1             {KS0.16b-KS1.16b}, [x10]
-       .endif
+       trn1            HH34.2d, HH3.2d, HH4.2d
+       trn2            T1.2d, HH3.2d, HH4.2d
+       eor             HH34.16b, HH34.16b, T1.16b
 
-       cbnz            x6, 4f
+       ld1             {XL.2d}, [x4]
 
-0:     ld1             {INP0.16b-INP1.16b}, [x3], #32
+       cbz             x0, 3f                          // tag only?
 
-       rev             x9, x8
-       add             x11, x8, #1
-       add             x8, x8, #2
+       ldr             w8, [x5, #12]                   // load lower counter
+CPU_LE(        rev             w8, w8          )
 
-       .if             \enc == 1
-       eor             INP0.16b, INP0.16b, KS0.16b     // encrypt input
-       eor             INP1.16b, INP1.16b, KS1.16b
+0:     mov             w9, #4                          // max blocks per round
+       add             x10, x0, #0xf
+       lsr             x10, x10, #4                    // remaining blocks
+
+       subs            x0, x0, #64
+       csel            w9, w10, w9, mi
+       add             w8, w8, w9
+
+       bmi             1f
+       ld1             {INP0.16b-INP3.16b}, [x2], #64
+       .subsection     1
+       /*
+        * Populate the four input registers right to left with up to 63 bytes
+        * of data, using overlapping loads to avoid branches.
+        *
+        *                INP0     INP1     INP2     INP3
+        *  1 byte     |        |        |        |x       |
+        * 16 bytes    |        |        |        |xxxxxxxx|
+        * 17 bytes    |        |        |xxxxxxxx|x       |
+        * 47 bytes    |        |xxxxxxxx|xxxxxxxx|xxxxxxx |
+        * etc etc
+        *
+        * Note that this code may read up to 15 bytes before the start of
+        * the input. It is up to the calling code to ensure this is safe if
+        * this happens in the first iteration of the loop (i.e., when the
+        * input size is < 16 bytes)
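+        *
+        * For example, with 17 bytes remaining: the tail (x19) is 1, so
+        * x14 = x15 = 0 and x16 = 1. INP0 and INP1 receive don't-care copies
+        * of the first block, INP2 receives bytes 0-15, and INP3 is loaded
+        * from offset 1 and permuted via .Lpermute_table so that byte 16
+        * lands in its first lane with the remaining lanes cleared. Only two
+        * blocks (w9 == 2) are then passed to the GHASH routine.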
+        */
+1:     mov             x15, #16
+       ands            x19, x0, #0xf
+       csel            x19, x19, x15, ne
+       adr_l           x17, .Lpermute_table + 16
+
+       sub             x11, x15, x19
+       add             x12, x17, x11
+       sub             x17, x17, x11
+       ld1             {T1.16b}, [x12]
+       sub             x10, x1, x11
+       sub             x11, x2, x11
+
+       cmp             x0, #-16
+       csel            x14, x15, xzr, gt
+       cmp             x0, #-32
+       csel            x15, x15, xzr, gt
+       cmp             x0, #-48
+       csel            x16, x19, xzr, gt
+       csel            x1, x1, x10, gt
+       csel            x2, x2, x11, gt
+
+       ld1             {INP0.16b}, [x2], x14
+       ld1             {INP1.16b}, [x2], x15
+       ld1             {INP2.16b}, [x2], x16
+       ld1             {INP3.16b}, [x2]
+       tbl             INP3.16b, {INP3.16b}, T1.16b
+       b               2f
+       .previous
+
+2:     .if             \enc == 0
+       bl              pmull_gcm_ghash_4x
        .endif
 
-       ld1             {KS0.8b}, [x5]                  // load upper counter
-       rev             x11, x11
-       sub             w0, w0, #2
-       mov             KS1.8b, KS0.8b
-       ins             KS0.d[1], x9                    // set lower counter
-       ins             KS1.d[1], x11
+       bl              pmull_gcm_enc_4x
 
-       rev64           T1.16b, INP1.16b
+       tbnz            x0, #63, 6f
+       st1             {INP0.16b-INP3.16b}, [x1], #64
+       .if             \enc == 1
+       bl              pmull_gcm_ghash_4x
+       .endif
+       bne             0b
 
-       cmp             w7, #12
-       b.ge            2f                              // AES-192/256?
+3:     ldp             x19, x10, [sp, #24]
+       cbz             x10, 5f                         // output tag?
 
-1:     enc_round       KS0, v21
-       ext             IN1.16b, T1.16b, T1.16b, #8
+       ld1             {INP3.16b}, [x10]               // load lengths[]
+       mov             w9, #1
+       bl              pmull_gcm_ghash_4x
 
-       enc_round       KS1, v21
-       pmull2          XH2.1q, SHASH.2d, IN1.2d        // a1 * b1
+       mov             w11, #(0x1 << 24)               // BE '1U'
+       ld1             {KS0.16b}, [x5]
+       mov             KS0.s[3], w11
 
-       enc_round       KS0, v22
-       eor             T1.16b, T1.16b, IN1.16b
+       enc_block       KS0, x7, x6, x12
 
-       enc_round       KS1, v22
-       pmull           XL2.1q, SHASH.1d, IN1.1d        // a0 * b0
+       ext             XL.16b, XL.16b, XL.16b, #8
+       rev64           XL.16b, XL.16b
+       eor             XL.16b, XL.16b, KS0.16b
+       st1             {XL.16b}, [x10]                 // store tag
 
-       enc_round       KS0, v23
-       pmull           XM2.1q, SHASH2.1d, T1.1d        // (a1 + a0)(b1 + b0)
+4:     ldp             x29, x30, [sp], #32
+       ret
 
-       enc_round       KS1, v23
-       rev64           T1.16b, INP0.16b
-       ext             T2.16b, XL.16b, XL.16b, #8
+5:
+CPU_LE(        rev             w8, w8          )
+       str             w8, [x5, #12]                   // store lower counter
+       st1             {XL.2d}, [x4]
+       b               4b
+
+6:     ld1             {T1.16b-T2.16b}, [x17], #32     // permute vectors
+       sub             x17, x17, x19, lsl #1
+
+       cmp             w9, #1
+       beq             7f
+       .subsection     1
+7:     ld1             {INP2.16b}, [x1]
+       tbx             INP2.16b, {INP3.16b}, T1.16b
+       mov             INP3.16b, INP2.16b
+       b               8f
+       .previous
+
+       st1             {INP0.16b}, [x1], x14
+       st1             {INP1.16b}, [x1], x15
+       st1             {INP2.16b}, [x1], x16
+       tbl             INP3.16b, {INP3.16b}, T1.16b
+       tbx             INP3.16b, {INP2.16b}, T2.16b
+8:     st1             {INP3.16b}, [x1]
 
-       enc_round       KS0, v24
-       ext             IN1.16b, T1.16b, T1.16b, #8
-       eor             T1.16b, T1.16b, T2.16b
+       .if             \enc == 1
+       ld1             {T1.16b}, [x17]
+       tbl             INP3.16b, {INP3.16b}, T1.16b    // clear non-data bits
+       bl              pmull_gcm_ghash_4x
+       .endif
+       b               3b
+       .endm
 
-       enc_round       KS1, v24
-       eor             XL.16b, XL.16b, IN1.16b
+       /*
+        * void pmull_gcm_encrypt(int bytes, u8 dst[], const u8 src[],
+        *                        struct ghash_key const *k, u64 dg[], u8 ctr[],
+        *                        u32 const rk[], int rounds, u8 tag[])
+        */
+ENTRY(pmull_gcm_encrypt)
+       pmull_gcm_do_crypt      1
+ENDPROC(pmull_gcm_encrypt)
 
-       enc_round       KS0, v25
-       eor             T1.16b, T1.16b, XL.16b
+       /*
+        * void pmull_gcm_decrypt(int bytes, u8 dst[], const u8 src[],
+        *                        struct ghash_key const *k, u64 dg[], u8 ctr[],
+        *                        u32 const rk[], int rounds, u8 tag[])
+        */
+ENTRY(pmull_gcm_decrypt)
+       pmull_gcm_do_crypt      0
+ENDPROC(pmull_gcm_decrypt)
 
-       enc_round       KS1, v25
-       pmull2          XH.1q, HH.2d, XL.2d             // a1 * b1
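+       /*
+        * Aggregated GHASH of up to four blocks: with X the current digest
+        * and I0..I3 the byte-reversed input blocks, this computes
+        *
+        *     X' = (X + I0).H^4 + I1.H^3 + I2.H^2 + I3.H
+        *
+        * (with correspondingly lower powers of H when fewer than four
+        * blocks remain). Each product is a Karatsuba-style set of three
+        * 64x64 polynomial multiplies accumulated into XH/XL/XM before a
+        * single reduction.
+        */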
+pmull_gcm_ghash_4x:
+       movi            MASK.16b, #0xe1
+       shl             MASK.2d, MASK.2d, #57
 
-       enc_round       KS0, v26
-       pmull           XL.1q, HH.1d, XL.1d             // a0 * b0
+       rev64           T1.16b, INP0.16b
+       rev64           T2.16b, INP1.16b
+       rev64           TT3.16b, INP2.16b
+       rev64           TT4.16b, INP3.16b
 
-       enc_round       KS1, v26
-       pmull2          XM.1q, SHASH2.2d, T1.2d         // (a1 + a0)(b1 + b0)
+       ext             XL.16b, XL.16b, XL.16b, #8
 
-       enc_round       KS0, v27
-       eor             XL.16b, XL.16b, XL2.16b
-       eor             XH.16b, XH.16b, XH2.16b
+       tbz             w9, #2, 0f                      // <4 blocks?
+       .subsection     1
+0:     movi            XH2.16b, #0
+       movi            XM2.16b, #0
+       movi            XL2.16b, #0
 
-       enc_round       KS1, v27
-       eor             XM.16b, XM.16b, XM2.16b
-       ext             T1.16b, XL.16b, XH.16b, #8
+       tbz             w9, #0, 1f                      // 2 blocks?
+       tbz             w9, #1, 2f                      // 1 block?
 
-       enc_round       KS0, v28
-       eor             T2.16b, XL.16b, XH.16b
-       eor             XM.16b, XM.16b, T1.16b
+       eor             T2.16b, T2.16b, XL.16b
+       ext             T1.16b, T2.16b, T2.16b, #8
+       b               .Lgh3
 
-       enc_round       KS1, v28
-       eor             XM.16b, XM.16b, T2.16b
+1:     eor             TT3.16b, TT3.16b, XL.16b
+       ext             T2.16b, TT3.16b, TT3.16b, #8
+       b               .Lgh2
 
-       enc_round       KS0, v29
-       pmull           T2.1q, XL.1d, MASK.1d
+2:     eor             TT4.16b, TT4.16b, XL.16b
+       ext             IN1.16b, TT4.16b, TT4.16b, #8
+       b               .Lgh1
+       .previous
 
-       enc_round       KS1, v29
-       mov             XH.d[0], XM.d[1]
-       mov             XM.d[1], XL.d[0]
+       eor             T1.16b, T1.16b, XL.16b
+       ext             IN1.16b, T1.16b, T1.16b, #8
 
-       aese            KS0.16b, v30.16b
-       eor             XL.16b, XM.16b, T2.16b
+       pmull2          XH2.1q, HH4.2d, IN1.2d          // a1 * b1
+       eor             T1.16b, T1.16b, IN1.16b
+       pmull           XL2.1q, HH4.1d, IN1.1d          // a0 * b0
+       pmull2          XM2.1q, HH34.2d, T1.2d          // (a1 + a0)(b1 + b0)
 
-       aese            KS1.16b, v30.16b
-       ext             T2.16b, XL.16b, XL.16b, #8
+       ext             T1.16b, T2.16b, T2.16b, #8
+.Lgh3: eor             T2.16b, T2.16b, T1.16b
+       pmull2          XH.1q, HH3.2d, T1.2d            // a1 * b1
+       pmull           XL.1q, HH3.1d, T1.1d            // a0 * b0
+       pmull           XM.1q, HH34.1d, T2.1d           // (a1 + a0)(b1 + b0)
 
-       eor             KS0.16b, KS0.16b, v31.16b
-       pmull           XL.1q, XL.1d, MASK.1d
-       eor             T2.16b, T2.16b, XH.16b
+       eor             XH2.16b, XH2.16b, XH.16b
+       eor             XL2.16b, XL2.16b, XL.16b
+       eor             XM2.16b, XM2.16b, XM.16b
 
-       eor             KS1.16b, KS1.16b, v31.16b
-       eor             XL.16b, XL.16b, T2.16b
+       ext             T2.16b, TT3.16b, TT3.16b, #8
+.Lgh2: eor             TT3.16b, TT3.16b, T2.16b
+       pmull2          XH.1q, HH.2d, T2.2d             // a1 * b1
+       pmull           XL.1q, HH.1d, T2.1d             // a0 * b0
+       pmull2          XM.1q, SHASH2.2d, TT3.2d        // (a1 + a0)(b1 + b0)
 
-       .if             \enc == 0
-       eor             INP0.16b, INP0.16b, KS0.16b
-       eor             INP1.16b, INP1.16b, KS1.16b
-       .endif
+       eor             XH2.16b, XH2.16b, XH.16b
+       eor             XL2.16b, XL2.16b, XL.16b
+       eor             XM2.16b, XM2.16b, XM.16b
 
-       st1             {INP0.16b-INP1.16b}, [x2], #32
+       ext             IN1.16b, TT4.16b, TT4.16b, #8
+.Lgh1: eor             TT4.16b, TT4.16b, IN1.16b
+       pmull           XL.1q, SHASH.1d, IN1.1d         // a0 * b0
+       pmull2          XH.1q, SHASH.2d, IN1.2d         // a1 * b1
+       pmull           XM.1q, SHASH2.1d, TT4.1d        // (a1 + a0)(b1 + b0)
 
-       cbnz            w0, 0b
+       eor             XH.16b, XH.16b, XH2.16b
+       eor             XL.16b, XL.16b, XL2.16b
+       eor             XM.16b, XM.16b, XM2.16b
 
-CPU_LE(        rev             x8, x8          )
-       st1             {XL.2d}, [x1]
-       str             x8, [x5, #8]                    // store lower counter
+       eor             T2.16b, XL.16b, XH.16b
+       ext             T1.16b, XL.16b, XH.16b, #8
+       eor             XM.16b, XM.16b, T2.16b
 
-       .if             \enc == 1
-       st1             {KS0.16b-KS1.16b}, [x10]
-       .endif
+       __pmull_reduce_p64
+
+       eor             T2.16b, T2.16b, XH.16b
+       eor             XL.16b, XL.16b, T2.16b
 
        ret
+ENDPROC(pmull_gcm_ghash_4x)
+
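+       /*
+        * Generate four blocks of AES-CTR keystream and XOR them into
+        * INP0-INP3. The caller has already advanced the counter in w8 by
+        * the number of blocks in this iteration, so the counter values used
+        * here are w8 - 4 .. w8 - 1; on a short final iteration this lines
+        * the right-aligned data blocks up with the correct counters.
+        */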
+pmull_gcm_enc_4x:
+       ld1             {KS0.16b}, [x5]                 // load upper counter
+       sub             w10, w8, #4
+       sub             w11, w8, #3
+       sub             w12, w8, #2
+       sub             w13, w8, #1
+       rev             w10, w10
+       rev             w11, w11
+       rev             w12, w12
+       rev             w13, w13
+       mov             KS1.16b, KS0.16b
+       mov             KS2.16b, KS0.16b
+       mov             KS3.16b, KS0.16b
+       ins             KS0.s[3], w10                   // set lower counter
+       ins             KS1.s[3], w11
+       ins             KS2.s[3], w12
+       ins             KS3.s[3], w13
+
+       add             x10, x6, #96                    // round key pointer
+       ld1             {K6.4s-K7.4s}, [x10], #32
+       .irp            key, K0, K1, K2, K3, K4, K5
+       enc_qround      KS0, KS1, KS2, KS3, \key
+       .endr
 
-2:     b.eq            3f                              // AES-192?
-       enc_round       KS0, v17
-       enc_round       KS1, v17
-       enc_round       KS0, v18
-       enc_round       KS1, v18
-3:     enc_round       KS0, v19
-       enc_round       KS1, v19
-       enc_round       KS0, v20
-       enc_round       KS1, v20
-       b               1b
+       tbnz            x7, #2, .Lnot128
+       .subsection     1
+.Lnot128:
+       ld1             {K8.4s-K9.4s}, [x10], #32
+       .irp            key, K6, K7
+       enc_qround      KS0, KS1, KS2, KS3, \key
+       .endr
+       ld1             {K6.4s-K7.4s}, [x10]
+       .irp            key, K8, K9
+       enc_qround      KS0, KS1, KS2, KS3, \key
+       .endr
+       tbz             x7, #1, .Lout192
+       b               .Lout256
+       .previous
 
-4:     load_round_keys w7, x6
-       b               0b
-       .endm
+.Lout256:
+       .irp            key, K6, K7
+       enc_qround      KS0, KS1, KS2, KS3, \key
+       .endr
 
-       /*
-        * void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[], const u8 src[],
-        *                        struct ghash_key const *k, u8 ctr[],
-        *                        int rounds, u8 ks[])
-        */
-ENTRY(pmull_gcm_encrypt)
-       pmull_gcm_do_crypt      1
-ENDPROC(pmull_gcm_encrypt)
+.Lout192:
+       enc_qround      KS0, KS1, KS2, KS3, KK
 
-       /*
-        * void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[], const u8 src[],
-        *                        struct ghash_key const *k, u8 ctr[],
-        *                        int rounds)
-        */
-ENTRY(pmull_gcm_decrypt)
-       pmull_gcm_do_crypt      0
-ENDPROC(pmull_gcm_decrypt)
+       aese            KS0.16b, KL.16b
+       aese            KS1.16b, KL.16b
+       aese            KS2.16b, KL.16b
+       aese            KS3.16b, KL.16b
+
+       eor             KS0.16b, KS0.16b, KM.16b
+       eor             KS1.16b, KS1.16b, KM.16b
+       eor             KS2.16b, KS2.16b, KM.16b
+       eor             KS3.16b, KS3.16b, KM.16b
+
+       eor             INP0.16b, INP0.16b, KS0.16b
+       eor             INP1.16b, INP1.16b, KS1.16b
+       eor             INP2.16b, INP2.16b, KS2.16b
+       eor             INP3.16b, INP3.16b, KS3.16b
 
-       /*
-        * void pmull_gcm_encrypt_block(u8 dst[], u8 src[], u8 rk[], int rounds)
-        */
-ENTRY(pmull_gcm_encrypt_block)
-       cbz             x2, 0f
-       load_round_keys w3, x2
-0:     ld1             {v0.16b}, [x1]
-       enc_block       v0, w3
-       st1             {v0.16b}, [x0]
        ret
-ENDPROC(pmull_gcm_encrypt_block)
+ENDPROC(pmull_gcm_enc_4x)
+
+       .section        ".rodata", "a"
+       .align          6
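+       /*
+        * Permute vectors for tbl/tbx: indexing the table at (32 - tail)
+        * rotates a partial block into place while the 0xff entries clear
+        * the lanes that hold no data.
+        */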
+.Lpermute_table:
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte            0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
+       .byte            0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte            0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
+       .byte            0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+       .previous
 
                                      struct ghash_key const *k,
                                      const char *head);
 
-asmlinkage void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[],
-                                 const u8 src[], struct ghash_key const *k,
+asmlinkage void pmull_gcm_encrypt(int bytes, u8 dst[], const u8 src[],
+                                 struct ghash_key const *k, u64 dg[],
                                  u8 ctr[], u32 const rk[], int rounds,
-                                 u8 ks[]);
+                                 u8 tag[]);
 
-asmlinkage void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[],
-                                 const u8 src[], struct ghash_key const *k,
-                                 u8 ctr[], u32 const rk[], int rounds);
-
-asmlinkage void pmull_gcm_encrypt_block(u8 dst[], u8 const src[],
-                                       u32 const rk[], int rounds);
+asmlinkage void pmull_gcm_decrypt(int bytes, u8 dst[], const u8 src[],
+                                 struct ghash_key const *k, u64 dg[],
+                                 u8 ctr[], u32 const rk[], int rounds,
+                                 u8 tag[]);
 
 static int ghash_init(struct shash_desc *desc)
 {
                                                struct ghash_key const *k,
                                                const char *head))
 {
-       if (likely(crypto_simd_usable())) {
+       if (likely(crypto_simd_usable() && simd_update)) {
                kernel_neon_begin();
                simd_update(blocks, dg, src, key, head);
                kernel_neon_end();
        }
 }
 
-static void gcm_final(struct aead_request *req, struct gcm_aes_ctx *ctx,
-                     u64 dg[], u8 tag[], int cryptlen)
-{
-       u8 mac[AES_BLOCK_SIZE];
-       u128 lengths;
-
-       lengths.a = cpu_to_be64(req->assoclen * 8);
-       lengths.b = cpu_to_be64(cryptlen * 8);
-
-       ghash_do_update(1, dg, (void *)&lengths, &ctx->ghash_key, NULL,
-                       pmull_ghash_update_p64);
-
-       put_unaligned_be64(dg[1], mac);
-       put_unaligned_be64(dg[0], mac + 8);
-
-       crypto_xor(tag, mac, AES_BLOCK_SIZE);
-}
-
 static int gcm_encrypt(struct aead_request *req)
 {
        struct crypto_aead *aead = crypto_aead_reqtfm(req);
        struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
+       int nrounds = num_rounds(&ctx->aes_key);
        struct skcipher_walk walk;
+       u8 buf[AES_BLOCK_SIZE];
        u8 iv[AES_BLOCK_SIZE];
-       u8 ks[2 * AES_BLOCK_SIZE];
-       u8 tag[AES_BLOCK_SIZE];
        u64 dg[2] = {};
-       int nrounds = num_rounds(&ctx->aes_key);
+       u128 lengths;
+       u8 *tag;
        int err;
 
+       lengths.a = cpu_to_be64(req->assoclen * 8);
+       lengths.b = cpu_to_be64(req->cryptlen * 8);
+
        if (req->assoclen)
                gcm_calculate_auth_mac(req, dg);
 
        memcpy(iv, req->iv, GCM_IV_SIZE);
-       put_unaligned_be32(1, iv + GCM_IV_SIZE);
+       put_unaligned_be32(2, iv + GCM_IV_SIZE);
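+       /*
+        * Counter value 1 is reserved for computing the tag, so bulk
+        * processing starts at counter value 2.
+        */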
 
        err = skcipher_walk_aead_encrypt(&walk, req, false);
 
-       if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) {
-               u32 const *rk = NULL;
-
-               kernel_neon_begin();
-               pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
-               put_unaligned_be32(2, iv + GCM_IV_SIZE);
-               pmull_gcm_encrypt_block(ks, iv, NULL, nrounds);
-               put_unaligned_be32(3, iv + GCM_IV_SIZE);
-               pmull_gcm_encrypt_block(ks + AES_BLOCK_SIZE, iv, NULL, nrounds);
-               put_unaligned_be32(4, iv + GCM_IV_SIZE);
-
+       if (likely(crypto_simd_usable())) {
                do {
-                       int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
+                       const u8 *src = walk.src.virt.addr;
+                       u8 *dst = walk.dst.virt.addr;
+                       int nbytes = walk.nbytes;
+
+                       tag = (u8 *)&lengths;
 
-                       if (rk)
-                               kernel_neon_begin();
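+                       /*
+                        * A partial final block is bounced through a stack
+                        * buffer so the asm code may use overlapping loads and
+                        * stores; for every call except the last, round the
+                        * byte count down to a whole number of blocks and pass
+                        * tag == NULL so the tag is not finalized yet.
+                        */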
+                       if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
+                               src = dst = memcpy(buf + sizeof(buf) - nbytes,
+                                                  src, nbytes);
+                       } else if (nbytes < walk.total) {
+                               nbytes &= ~(AES_BLOCK_SIZE - 1);
+                               tag = NULL;
+                       }
 
-                       pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr,
-                                         walk.src.virt.addr, &ctx->ghash_key,
-                                         iv, rk, nrounds, ks);
+                       kernel_neon_begin();
+                       pmull_gcm_encrypt(nbytes, dst, src, &ctx->ghash_key, dg,
+                                         iv, ctx->aes_key.key_enc, nrounds,
+                                         tag);
                        kernel_neon_end();
 
-                       err = skcipher_walk_done(&walk,
-                                       walk.nbytes % (2 * AES_BLOCK_SIZE));
+                       if (unlikely(!nbytes))
+                               break;
 
-                       rk = ctx->aes_key.key_enc;
-               } while (walk.nbytes >= 2 * AES_BLOCK_SIZE);
-       } else {
-               aes_encrypt(&ctx->aes_key, tag, iv);
-               put_unaligned_be32(2, iv + GCM_IV_SIZE);
+                       if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
+                               memcpy(walk.dst.virt.addr,
+                                      buf + sizeof(buf) - nbytes, nbytes);
 
-               while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) {
-                       const int blocks =
-                               walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
+                       err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+               } while (walk.nbytes);
+       } else {
+               while (walk.nbytes >= AES_BLOCK_SIZE) {
+                       int blocks = walk.nbytes / AES_BLOCK_SIZE;
+                       const u8 *src = walk.src.virt.addr;
                        u8 *dst = walk.dst.virt.addr;
-                       u8 *src = walk.src.virt.addr;
                        int remaining = blocks;
 
                        do {
-                               aes_encrypt(&ctx->aes_key, ks, iv);
-                               crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE);
+                               aes_encrypt(&ctx->aes_key, buf, iv);
+                               crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
                                crypto_inc(iv, AES_BLOCK_SIZE);
 
                                dst += AES_BLOCK_SIZE;
                                src += AES_BLOCK_SIZE;
                        } while (--remaining > 0);
 
-                       ghash_do_update(blocks, dg,
-                                       walk.dst.virt.addr, &ctx->ghash_key,
-                                       NULL, pmull_ghash_update_p64);
+                       ghash_do_update(blocks, dg, walk.dst.virt.addr,
+                                       &ctx->ghash_key, NULL, NULL);
 
                        err = skcipher_walk_done(&walk,
-                                                walk.nbytes % (2 * AES_BLOCK_SIZE));
-               }
-               if (walk.nbytes) {
-                       aes_encrypt(&ctx->aes_key, ks, iv);
-                       if (walk.nbytes > AES_BLOCK_SIZE) {
-                               crypto_inc(iv, AES_BLOCK_SIZE);
-                               aes_encrypt(&ctx->aes_key, ks + AES_BLOCK_SIZE, iv);
-                       }
+                                                walk.nbytes % AES_BLOCK_SIZE);
                }
-       }
 
-       /* handle the tail */
-       if (walk.nbytes) {
-               u8 buf[GHASH_BLOCK_SIZE];
-               unsigned int nbytes = walk.nbytes;
-               u8 *dst = walk.dst.virt.addr;
-               u8 *head = NULL;
+               /* handle the tail */
+               if (walk.nbytes) {
+                       aes_encrypt(&ctx->aes_key, buf, iv);
 
-               crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks,
-                              walk.nbytes);
+                       crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
+                                      buf, walk.nbytes);
 
-               if (walk.nbytes > GHASH_BLOCK_SIZE) {
-                       head = dst;
-                       dst += GHASH_BLOCK_SIZE;
-                       nbytes %= GHASH_BLOCK_SIZE;
+                       memcpy(buf, walk.dst.virt.addr, walk.nbytes);
+                       memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
                }
 
-               memcpy(buf, dst, nbytes);
-               memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
-               ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head,
-                               pmull_ghash_update_p64);
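+               /*
+                * Fallback finalization: fold the (zero-padded) tail block
+                * and the lengths block into the digest, then XOR the result
+                * with the encrypted initial counter block (counter value 1)
+                * to produce the tag.
+                */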
+               tag = (u8 *)&lengths;
+               ghash_do_update(1, dg, tag, &ctx->ghash_key,
+                               walk.nbytes ? buf : NULL, NULL);
 
-               err = skcipher_walk_done(&walk, 0);
+               if (walk.nbytes)
+                       err = skcipher_walk_done(&walk, 0);
+
+               put_unaligned_be64(dg[1], tag);
+               put_unaligned_be64(dg[0], tag + 8);
+               put_unaligned_be32(1, iv + GCM_IV_SIZE);
+               aes_encrypt(&ctx->aes_key, iv, iv);
+               crypto_xor(tag, iv, AES_BLOCK_SIZE);
        }
 
        if (err)
                return err;
 
-       gcm_final(req, ctx, dg, tag, req->cryptlen);
-
        /* copy authtag to end of dst */
        scatterwalk_map_and_copy(tag, req->dst, req->assoclen + req->cryptlen,
                                 crypto_aead_authsize(aead), 1);
        struct crypto_aead *aead = crypto_aead_reqtfm(req);
        struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
        unsigned int authsize = crypto_aead_authsize(aead);
+       int nrounds = num_rounds(&ctx->aes_key);
        struct skcipher_walk walk;
-       u8 iv[2 * AES_BLOCK_SIZE];
-       u8 tag[AES_BLOCK_SIZE];
-       u8 buf[2 * GHASH_BLOCK_SIZE];
+       u8 buf[AES_BLOCK_SIZE];
+       u8 iv[AES_BLOCK_SIZE];
        u64 dg[2] = {};
-       int nrounds = num_rounds(&ctx->aes_key);
+       u128 lengths;
+       u8 *tag;
        int err;
 
+       lengths.a = cpu_to_be64(req->assoclen * 8);
+       lengths.b = cpu_to_be64((req->cryptlen - authsize) * 8);
+
        if (req->assoclen)
                gcm_calculate_auth_mac(req, dg);
 
        memcpy(iv, req->iv, GCM_IV_SIZE);
-       put_unaligned_be32(1, iv + GCM_IV_SIZE);
+       put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
        err = skcipher_walk_aead_decrypt(&walk, req, false);
 
-       if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) {
-               u32 const *rk = NULL;
-
-               kernel_neon_begin();
-               pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
-               put_unaligned_be32(2, iv + GCM_IV_SIZE);
-
+       if (likely(crypto_simd_usable())) {
                do {
-                       int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
-                       int rem = walk.total - blocks * AES_BLOCK_SIZE;
-
-                       if (rk)
-                               kernel_neon_begin();
-
-                       pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr,
-                                         walk.src.virt.addr, &ctx->ghash_key,
-                                         iv, rk, nrounds);
-
-                       /* check if this is the final iteration of the loop */
-                       if (rem < (2 * AES_BLOCK_SIZE)) {
-                               u8 *iv2 = iv + AES_BLOCK_SIZE;
-
-                               if (rem > AES_BLOCK_SIZE) {
-                                       memcpy(iv2, iv, AES_BLOCK_SIZE);
-                                       crypto_inc(iv2, AES_BLOCK_SIZE);
-                               }
+                       const u8 *src = walk.src.virt.addr;
+                       u8 *dst = walk.dst.virt.addr;
+                       int nbytes = walk.nbytes;
 
-                               pmull_gcm_encrypt_block(iv, iv, NULL, nrounds);
+                       tag = (u8 *)&lengths;
 
-                               if (rem > AES_BLOCK_SIZE)
-                                       pmull_gcm_encrypt_block(iv2, iv2, NULL,
-                                                               nrounds);
+                       if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
+                               src = dst = memcpy(buf + sizeof(buf) - nbytes,
+                                                  src, nbytes);
+                       } else if (nbytes < walk.total) {
+                               nbytes &= ~(AES_BLOCK_SIZE - 1);
+                               tag = NULL;
                        }
 
+                       kernel_neon_begin();
+                       pmull_gcm_decrypt(nbytes, dst, src, &ctx->ghash_key, dg,
+                                         iv, ctx->aes_key.key_enc, nrounds,
+                                         tag);
                        kernel_neon_end();
 
-                       err = skcipher_walk_done(&walk,
-                                       walk.nbytes % (2 * AES_BLOCK_SIZE));
+                       if (unlikely(!nbytes))
+                               break;
 
-                       rk = ctx->aes_key.key_enc;
-               } while (walk.nbytes >= 2 * AES_BLOCK_SIZE);
-       } else {
-               aes_encrypt(&ctx->aes_key, tag, iv);
-               put_unaligned_be32(2, iv + GCM_IV_SIZE);
+                       if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
+                               memcpy(walk.dst.virt.addr,
+                                      buf + sizeof(buf) - nbytes, nbytes);
 
-               while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) {
-                       int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
+                       err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+               } while (walk.nbytes);
+       } else {
+               while (walk.nbytes >= AES_BLOCK_SIZE) {
+                       int blocks = walk.nbytes / AES_BLOCK_SIZE;
+                       const u8 *src = walk.src.virt.addr;
                        u8 *dst = walk.dst.virt.addr;
-                       u8 *src = walk.src.virt.addr;
 
                        ghash_do_update(blocks, dg, walk.src.virt.addr,
-                                       &ctx->ghash_key, NULL,
-                                       pmull_ghash_update_p64);
+                                       &ctx->ghash_key, NULL, NULL);
 
                        do {
                                aes_encrypt(&ctx->aes_key, buf, iv);
                                crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
                                crypto_inc(iv, AES_BLOCK_SIZE);

                                dst += AES_BLOCK_SIZE;
                                src += AES_BLOCK_SIZE;
                        } while (--blocks > 0);
 
                        err = skcipher_walk_done(&walk,
-                                                walk.nbytes % (2 * AES_BLOCK_SIZE));
+                                                walk.nbytes % AES_BLOCK_SIZE);
                }
-               if (walk.nbytes) {
-                       if (walk.nbytes > AES_BLOCK_SIZE) {
-                               u8 *iv2 = iv + AES_BLOCK_SIZE;
-
-                               memcpy(iv2, iv, AES_BLOCK_SIZE);
-                               crypto_inc(iv2, AES_BLOCK_SIZE);
 
-                               aes_encrypt(&ctx->aes_key, iv2, iv2);
-                       }
-                       aes_encrypt(&ctx->aes_key, iv, iv);
+               /* handle the tail */
+               if (walk.nbytes) {
+                       memcpy(buf, walk.src.virt.addr, walk.nbytes);
+                       memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
                }
-       }
 
-       /* handle the tail */
-       if (walk.nbytes) {
-               const u8 *src = walk.src.virt.addr;
-               const u8 *head = NULL;
-               unsigned int nbytes = walk.nbytes;
+               tag = (u8 *)&lengths;
+               ghash_do_update(1, dg, tag, &ctx->ghash_key,
+                               walk.nbytes ? buf : NULL, NULL);
 
-               if (walk.nbytes > GHASH_BLOCK_SIZE) {
-                       head = src;
-                       src += GHASH_BLOCK_SIZE;
-                       nbytes %= GHASH_BLOCK_SIZE;
-               }
+               if (walk.nbytes) {
+                       aes_encrypt(&ctx->aes_key, buf, iv);
 
-               memcpy(buf, src, nbytes);
-               memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
-               ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head,
-                               pmull_ghash_update_p64);
+                       crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
+                                      buf, walk.nbytes);
 
-               crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv,
-                              walk.nbytes);
+                       err = skcipher_walk_done(&walk, 0);
+               }
 
-               err = skcipher_walk_done(&walk, 0);
+               put_unaligned_be64(dg[1], tag);
+               put_unaligned_be64(dg[0], tag + 8);
+               put_unaligned_be32(1, iv + GCM_IV_SIZE);
+               aes_encrypt(&ctx->aes_key, iv, iv);
+               crypto_xor(tag, iv, AES_BLOCK_SIZE);
        }
 
        if (err)
                return err;
 
-       gcm_final(req, ctx, dg, tag, req->cryptlen - authsize);
-
        /* compare calculated auth tag with the stored one */
        scatterwalk_map_and_copy(buf, req->src,
                                 req->assoclen + req->cryptlen - authsize,
 
 static struct aead_alg gcm_aes_alg = {
        .ivsize                 = GCM_IV_SIZE,
-       .chunksize              = 2 * AES_BLOCK_SIZE,
+       .chunksize              = AES_BLOCK_SIZE,
        .maxauthsize            = AES_BLOCK_SIZE,
        .setkey                 = gcm_setkey,
        .setauthsize            = gcm_setauthsize,