return false;
 }
 
-static u64 expand_ack(u64 old_ack, u64 cur_ack, bool use_64bit)
+u64 __mptcp_expand_seq(u64 old_seq, u64 cur_seq)
 {
-       u32 old_ack32, cur_ack32;
-
-       if (use_64bit)
-               return cur_ack;
-
-       old_ack32 = (u32)old_ack;
-       cur_ack32 = (u32)cur_ack;
-       cur_ack = (old_ack & GENMASK_ULL(63, 32)) + cur_ack32;
-       if (unlikely(before(cur_ack32, old_ack32)))
-               return cur_ack + (1LL << 32);
-       return cur_ack;
+       u32 old_seq32, cur_seq32;
+
+       old_seq32 = (u32)old_seq;
+       cur_seq32 = (u32)cur_seq;
+       cur_seq = (old_seq & GENMASK_ULL(63, 32)) + cur_seq32;
+       if (unlikely(cur_seq32 < old_seq32 && before(old_seq32, cur_seq32)))
+               return cur_seq + (1LL << 32);
+
+       /* reverse wrap could happen, too */
+       if (unlikely(cur_seq32 > old_seq32 && after(old_seq32, cur_seq32)))
+               return cur_seq - (1LL << 32);
+       return cur_seq;
 }
 
 static void ack_update_msk(struct mptcp_sock *msk,
         * more dangerous than missing an ack
         */
        old_snd_una = msk->snd_una;
-       new_snd_una = expand_ack(old_snd_una, mp_opt->data_ack, mp_opt->ack64);
+       new_snd_una = mptcp_expand_seq(old_snd_una, mp_opt->data_ack, mp_opt->ack64);
 
        /* ACK for data not even sent yet? Ignore. */
        if (after64(new_snd_una, snd_nxt))
                return false;
 
        WRITE_ONCE(msk->rcv_data_fin_seq,
-                  expand_ack(READ_ONCE(msk->ack_seq), data_fin_seq, use_64bit));
+                  mptcp_expand_seq(READ_ONCE(msk->ack_seq), data_fin_seq, use_64bit));
        WRITE_ONCE(msk->rcv_data_fin, 1);
 
        return true;
 
 int mptcp_getsockopt(struct sock *sk, int level, int optname,
                     char __user *optval, int __user *option);
 
+u64 __mptcp_expand_seq(u64 old_seq, u64 cur_seq);
+static inline u64 mptcp_expand_seq(u64 old_seq, u64 cur_seq, bool use_64bit)
+{
+       if (use_64bit)
+               return cur_seq;
+
+       return __mptcp_expand_seq(old_seq, cur_seq);
+}
 void __mptcp_check_push(struct sock *sk, struct sock *ssk);
 void __mptcp_data_acked(struct sock *sk);
 void __mptcp_error_report(struct sock *sk);