  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
  *
  * This software is available to you under a choice of one of two
  * licenses.  You may choose to be licensed under the terms of the GNU
        return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
 }
 
-static void tls_free_open_rec(struct sock *sk)
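+/* Allocate a fresh TX record. The allocation is sized to also hold the AEAD
+ * request context (crypto_aead_reqsize()), and sg_aead_in/out carry the AAD
+ * in slot 0 with slot 1 left unterminated so payload entries can follow.
+ */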
+static struct tls_rec *tls_get_rec(struct sock *sk)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-       struct tls_rec *rec = ctx->open_rec;
+       struct sk_msg *msg_pl, *msg_en;
+       struct tls_rec *rec;
+       int mem_size;
 
-       /* Return if there is no open record */
+       mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+
+       rec = kzalloc(mem_size, sk->sk_allocation);
        if (!rec)
-               return;
+               return NULL;
 
+       msg_pl = &rec->msg_plaintext;
+       msg_en = &rec->msg_encrypted;
+
+       sk_msg_init(msg_pl);
+       sk_msg_init(msg_en);
+
+       sg_init_table(rec->sg_aead_in, 2);
+       sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
+                  sizeof(rec->aad_space));
+       sg_unmark_end(&rec->sg_aead_in[1]);
+
+       sg_init_table(rec->sg_aead_out, 2);
+       sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
+                  sizeof(rec->aad_space));
+       sg_unmark_end(&rec->sg_aead_out[1]);
+
+       return rec;
+}
+
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
+{
        sk_msg_free(sk, &rec->msg_encrypted);
        sk_msg_free(sk, &rec->msg_plaintext);
        kfree(rec);
 }
 
+static void tls_free_open_rec(struct sock *sk)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+       struct tls_rec *rec = ctx->open_rec;
+
+       if (rec) {
+               tls_free_rec(sk, rec);
+               ctx->open_rec = NULL;
+       }
+}
+
 int tls_tx_records(struct sock *sk, int flags)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        return rc;
 }
 
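+/* Split the open record at split_point bytes of plaintext (the current
+ * apply_bytes window): 'from' keeps the first part, the remainder moves to a
+ * newly allocated record returned in *to. A page straddling the split point
+ * is shared between both records by taking an extra page reference, and
+ * *orig_end remembers the original sg end so tls_merge_open_record() can
+ * undo the split.
+ */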
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
+                                struct tls_rec **to, struct sk_msg *msg_opl,
+                                struct sk_msg *msg_oen, u32 split_point,
+                                u32 tx_overhead_size, u32 *orig_end)
+{
+       u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
+       struct scatterlist *sge, *osge, *nsge;
+       u32 orig_size = msg_opl->sg.size;
+       struct scatterlist tmp = { };
+       struct sk_msg *msg_npl;
+       struct tls_rec *new;
+       int ret;
+
+       new = tls_get_rec(sk);
+       if (!new)
+               return -ENOMEM;
+       ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
+                          tx_overhead_size, 0);
+       if (ret < 0) {
+               tls_free_rec(sk, new);
+               return ret;
+       }
+
+       *orig_end = msg_opl->sg.end;
+       i = msg_opl->sg.start;
+       sge = sk_msg_elem(msg_opl, i);
+       while (apply && sge->length) {
+               if (sge->length > apply) {
+                       u32 len = sge->length - apply;
+
+                       get_page(sg_page(sge));
+                       sg_set_page(&tmp, sg_page(sge), len,
+                                   sge->offset + apply);
+                       sge->length = apply;
+                       bytes += apply;
+                       apply = 0;
+               } else {
+                       apply -= sge->length;
+                       bytes += sge->length;
+               }
+
+               sk_msg_iter_var_next(i);
+               if (i == msg_opl->sg.end)
+                       break;
+               sge = sk_msg_elem(msg_opl, i);
+       }
+
+       msg_opl->sg.end = i;
+       msg_opl->sg.curr = i;
+       msg_opl->sg.copybreak = 0;
+       msg_opl->apply_bytes = 0;
+       msg_opl->sg.size = bytes;
+
+       msg_npl = &new->msg_plaintext;
+       msg_npl->apply_bytes = apply;
+       msg_npl->sg.size = orig_size - bytes;
+
+       j = msg_npl->sg.start;
+       nsge = sk_msg_elem(msg_npl, j);
+       if (tmp.length) {
+               memcpy(nsge, &tmp, sizeof(*nsge));
+               sk_msg_iter_var_next(j);
+               nsge = sk_msg_elem(msg_npl, j);
+       }
+
+       osge = sk_msg_elem(msg_opl, i);
+       while (osge->length) {
+               memcpy(nsge, osge, sizeof(*nsge));
+               sg_unmark_end(nsge);
+               sk_msg_iter_var_next(i);
+               sk_msg_iter_var_next(j);
+               if (i == *orig_end)
+                       break;
+               osge = sk_msg_elem(msg_opl, i);
+               nsge = sk_msg_elem(msg_npl, j);
+       }
+
+       msg_npl->sg.end = j;
+       msg_npl->sg.curr = j;
+       msg_npl->sg.copybreak = 0;
+
+       *to = new;
+       return 0;
+}
+
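+/* Undo a split done by tls_split_open_record(): fold the plaintext of 'from'
+ * back into 'to' (re-joining a page shared across the split), restore the
+ * original sg end, and take over 'from's encrypted buffer.
+ */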
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
+                                 struct tls_rec *from, u32 orig_end)
+{
+       struct sk_msg *msg_npl = &from->msg_plaintext;
+       struct sk_msg *msg_opl = &to->msg_plaintext;
+       struct scatterlist *osge, *nsge;
+       u32 i, j;
+
+       i = msg_opl->sg.end;
+       sk_msg_iter_var_prev(i);
+       j = msg_npl->sg.start;
+
+       osge = sk_msg_elem(msg_opl, i);
+       nsge = sk_msg_elem(msg_npl, j);
+
+       if (sg_page(osge) == sg_page(nsge) &&
+           osge->offset + osge->length == nsge->offset) {
+               osge->length += nsge->length;
+               put_page(sg_page(nsge));
+       }
+
+       msg_opl->sg.end = orig_end;
+       msg_opl->sg.curr = orig_end;
+       msg_opl->sg.copybreak = 0;
+       msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
+       msg_opl->sg.size += msg_npl->sg.size;
+
+       sk_msg_free(sk, &to->msg_encrypted);
+       sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
+
+       kfree(from);
+}
+
 static int tls_push_record(struct sock *sk, int flags,
                           unsigned char record_type)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-       struct tls_rec *rec = ctx->open_rec;
+       struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
+       u32 i, split_point, uninitialized_var(orig_end);
        struct sk_msg *msg_pl, *msg_en;
        struct aead_request *req;
+       bool split;
        int rc;
-       u32 i;
 
        if (!rec)
                return 0;
        msg_pl = &rec->msg_plaintext;
        msg_en = &rec->msg_encrypted;
 
+       split_point = msg_pl->apply_bytes;
+       split = split_point && split_point < msg_pl->sg.size;
+       if (split) {
+               rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
+                                          split_point, tls_ctx->tx.overhead_size,
+                                          &orig_end);
+               if (rc < 0)
+                       return rc;
+               sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+                           tls_ctx->tx.overhead_size);
+       }
+
        rec->tx_flags = flags;
        req = &rec->aead_req;
 
 
        rc = tls_do_encryption(sk, tls_ctx, ctx, req, msg_pl->sg.size, i);
        if (rc < 0) {
-               if (rc != -EINPROGRESS)
+               if (rc != -EINPROGRESS) {
                        tls_err_abort(sk, EBADMSG);
+                       if (split) {
+                               tls_ctx->pending_open_record_frags = true;
+                               tls_merge_open_record(sk, rec, tmp, orig_end);
+                       }
+               }
                return rc;
+       } else if (split) {
+               msg_pl = &tmp->msg_plaintext;
+               msg_en = &tmp->msg_encrypted;
+               sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+                           tls_ctx->tx.overhead_size);
+               tls_ctx->pending_open_record_frags = true;
+               ctx->open_rec = tmp;
        }
 
        return tls_tx_records(sk, flags);
 }
 
-static int tls_sw_push_pending_record(struct sock *sk, int flags)
-{
-       return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
-}
-
-static struct tls_rec *get_rec(struct sock *sk)
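+/* Run the sk_msg BPF verdict program over the pending plaintext and act on
+ * the result: __SK_PASS encrypts and transmits the record on this socket,
+ * __SK_REDIRECT hands the plaintext to another socket via
+ * tcp_bpf_sendmsg_redir(), and __SK_DROP frees it and returns -EACCES.
+ * Without a psock attached the record is pushed out as before.
+ */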
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
+                              bool full_record, u8 record_type,
+                              size_t *copied, int flags)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-       struct sk_msg *msg_pl, *msg_en;
+       struct sk_msg msg_redir = { };
+       struct sk_psock *psock;
+       struct sock *sk_redir;
        struct tls_rec *rec;
-       int mem_size;
+       int err = 0, send;
+       bool enospc;
+
+       psock = sk_psock_get(sk);
+       if (!psock)
+               return tls_push_record(sk, flags, record_type);
+more_data:
+       enospc = sk_msg_full(msg);
+       if (psock->eval == __SK_NONE)
+               psock->eval = sk_psock_msg_verdict(sk, psock, msg);
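+       /* The BPF program asked to cork: keep the record open until
+        * cork_bytes worth of data has been collected, unless the msg or the
+        * record is already full.
+        */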
+       if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
+           !enospc && !full_record) {
+               err = -ENOSPC;
+               goto out_err;
+       }
+       msg->cork_bytes = 0;
+       send = msg->sg.size;
+       if (msg->apply_bytes && msg->apply_bytes < send)
+               send = msg->apply_bytes;
+
+       switch (psock->eval) {
+       case __SK_PASS:
+               err = tls_push_record(sk, flags, record_type);
+               if (err < 0) {
+                       *copied -= sk_msg_free(sk, msg);
+                       tls_free_open_rec(sk);
+                       goto out_err;
+               }
+               break;
+       case __SK_REDIRECT:
+               sk_redir = psock->sk_redir;
+               memcpy(&msg_redir, msg, sizeof(*msg));
+               if (msg->apply_bytes < send)
+                       msg->apply_bytes = 0;
+               else
+                       msg->apply_bytes -= send;
+               sk_msg_return_zero(sk, msg, send);
+               msg->sg.size -= send;
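+               /* The redirect target is a different socket, so drop this
+                * socket's lock around the send and re-take it afterwards.
+                */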
+               release_sock(sk);
+               err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+               lock_sock(sk);
+               if (err < 0) {
+                       *copied -= sk_msg_free_nocharge(sk, &msg_redir);
+                       msg->sg.size = 0;
+               }
+               if (msg->sg.size == 0)
+                       tls_free_open_rec(sk);
+               break;
+       case __SK_DROP:
+       default:
+               sk_msg_free_partial(sk, msg, send);
+               if (msg->apply_bytes < send)
+                       msg->apply_bytes = 0;
+               else
+                       msg->apply_bytes -= send;
+               if (msg->sg.size == 0)
+                       tls_free_open_rec(sk);
+               *copied -= send;
+               err = -EACCES;
+       }
 
-       /* Return if we already have an open record */
-       if (ctx->open_rec)
-               return ctx->open_rec;
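+       /* On success, keep looping while a record is still open; clear the
+        * verdict and any redirect target once the current apply_bytes window
+        * has been consumed so the next chunk gets a fresh verdict.
+        */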
+       if (likely(!err)) {
+               bool reset_eval = !ctx->open_rec;
 
-       mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+               rec = ctx->open_rec;
+               if (rec) {
+                       msg = &rec->msg_plaintext;
+                       if (!msg->apply_bytes)
+                               reset_eval = true;
+               }
+               if (reset_eval) {
+                       psock->eval = __SK_NONE;
+                       if (psock->sk_redir) {
+                               sock_put(psock->sk_redir);
+                               psock->sk_redir = NULL;
+                       }
+               }
+               if (rec)
+                       goto more_data;
+       }
+ out_err:
+       sk_psock_put(sk, psock);
+       return err;
+}
+
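+/* Push any open record through the BPF TX verdict path as application data. */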
+static int tls_sw_push_pending_record(struct sock *sk, int flags)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+       struct tls_rec *rec = ctx->open_rec;
+       struct sk_msg *msg_pl;
+       size_t copied;
 
-       rec = kzalloc(mem_size, sk->sk_allocation);
        if (!rec)
-               return NULL;
+               return 0;
 
        msg_pl = &rec->msg_plaintext;
-       msg_en = &rec->msg_encrypted;
-
-       sk_msg_init(msg_pl);
-       sk_msg_init(msg_en);
-
-       sg_init_table(rec->sg_aead_in, 2);
-       sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
-                  sizeof(rec->aad_space));
-       sg_unmark_end(&rec->sg_aead_in[1]);
-
-       sg_init_table(rec->sg_aead_out, 2);
-       sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
-                  sizeof(rec->aad_space));
-       sg_unmark_end(&rec->sg_aead_out[1]);
-
-       ctx->open_rec = rec;
-       rec->inplace_crypto = 1;
+       copied = msg_pl->sg.size;
+       if (!copied)
+               return 0;
 
-       return rec;
+       return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
+                                  &copied, flags);
 }
 
 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
                        goto send_end;
                }
 
-               rec = get_rec(sk);
+               if (ctx->open_rec)
+                       rec = ctx->open_rec;
+               else
+                       rec = ctx->open_rec = tls_get_rec(sk);
                if (!rec) {
                        ret = -ENOMEM;
                        goto send_end;
                }
 
                if (!is_kvec && (full_record || eor) && !async_capable) {
+                       u32 first = msg_pl->sg.end;
+
                        ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
                                                        msg_pl, try_to_copy);
                        if (ret)
 
                        num_zc++;
                        copied += try_to_copy;
-                       ret = tls_push_record(sk, msg->msg_flags, record_type);
+
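+                       /* Mark the entries added by the zerocopy mapping
+                        * above with the sg copy flag before handing the msg
+                        * to the BPF TX verdict.
+                        */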
+                       sk_msg_sg_copy_set(msg_pl, first);
+                       ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+                                                 record_type, &copied,
+                                                 msg->msg_flags);
                        if (ret) {
                                if (ret == -EINPROGRESS)
                                        num_async++;
+                               else if (ret == -ENOMEM)
+                                       goto wait_for_memory;
+                               else if (ret == -ENOSPC)
+                                       goto rollback_iter;
                                else if (ret != -EAGAIN)
                                        goto send_end;
                        }
                        continue;
-
+rollback_iter:
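+                       /* The verdict layer wants more data before pushing
+                        * (-ENOSPC): undo the zerocopy mapping and fall back
+                        * to the copy path.
+                        */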
+                       copied -= try_to_copy;
+                       sk_msg_sg_copy_clear(msg_pl, first);
+                       iov_iter_revert(&msg->msg_iter,
+                                       msg_pl->sg.size - orig_size);
 fallback_to_reg_send:
                        sk_msg_trim(sk, msg_pl, orig_size);
                }
                tls_ctx->pending_open_record_frags = true;
                copied += try_to_copy;
                if (full_record || eor) {
-                       ret = tls_push_record(sk, msg->msg_flags, record_type);
+                       ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+                                                 record_type, &copied,
+                                                 msg->msg_flags);
                        if (ret) {
                                if (ret == -EINPROGRESS)
                                        num_async++;
-                               else if (ret != -EAGAIN)
+                               else if (ret == -ENOMEM)
+                                       goto wait_for_memory;
+                               else if (ret != -EAGAIN) {
+                                       if (ret == -ENOSPC)
+                                               ret = 0;
                                        goto send_end;
+                               }
                        }
                }
 
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
-       size_t orig_size = size;
        struct sk_msg *msg_pl;
        struct tls_rec *rec;
        int num_async = 0;
+       size_t copied = 0;
        bool full_record;
        int record_room;
        int ret = 0;
                        goto sendpage_end;
                }
 
-               rec = get_rec(sk);
+               if (ctx->open_rec)
+                       rec = ctx->open_rec;
+               else
+                       rec = ctx->open_rec = tls_get_rec(sk);
                if (!rec) {
                        ret = -ENOMEM;
                        goto sendpage_end;
 
                full_record = false;
                record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+               copied = 0;
                copy = size;
                if (copy >= record_room) {
                        copy = record_room;
 
                offset += copy;
                size -= copy;
+               copied += copy;
 
                tls_ctx->pending_open_record_frags = true;
                if (full_record || eor || sk_msg_full(msg_pl)) {
                        rec->inplace_crypto = 0;
-                       ret = tls_push_record(sk, flags, record_type);
+                       ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+                                                 record_type, &copied, flags);
                        if (ret) {
                                if (ret == -EINPROGRESS)
                                        num_async++;
-                               else if (ret != -EAGAIN)
+                               else if (ret == -ENOMEM)
+                                       goto wait_for_memory;
+                               else if (ret != -EAGAIN) {
+                                       if (ret == -ENOSPC)
+                                               ret = 0;
                                        goto sendpage_end;
+                               }
                        }
                }
                continue;
                }
        }
 sendpage_end:
-       if (orig_size > size)
-               ret = orig_size - size;
-       else
-               ret = sk_stream_error(sk, flags, ret);
-
+       ret = sk_stream_error(sk, flags, ret);
        release_sock(sk);
-       return ret;
+       return copied ? copied : ret;
 }
 
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
-                                    long timeo, int *err)
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
+                                    int flags, long timeo, int *err)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct sk_buff *skb;
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
 
-       while (!(skb = ctx->recv_pkt)) {
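+       /* Wait until either a parsed TLS record or BPF-redirected psock
+        * ingress data is available.
+        */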
+       while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
                if (sk->sk_err) {
                        *err = sock_error(sk);
                        return NULL;
 
                add_wait_queue(sk_sleep(sk), &wait);
                sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-               sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
+               sk_wait_event(sk, &timeo,
+                             ctx->recv_pkt != skb ||
+                             !sk_psock_queue_empty(psock),
+                             &wait);
                sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
                remove_wait_queue(sk_sleep(sk), &wait);
 
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+       struct sk_psock *psock;
        unsigned char control;
        struct strp_msg *rxm;
        struct sk_buff *skb;
        if (unlikely(flags & MSG_ERRQUEUE))
                return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
 
+       psock = sk_psock_get(sk);
        lock_sock(sk);
 
        target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
                bool async = false;
                int chunk = 0;
 
-               skb = tls_wait_data(sk, flags, timeo, &err);
-               if (!skb)
+               skb = tls_wait_data(sk, psock, flags, timeo, &err);
+               if (!skb) {
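+                       /* No TLS record is ready; try any data a BPF program
+                        * queued on this socket's psock ingress list instead.
+                        */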
+                       if (psock) {
+                               int ret = __tcp_bpf_recvmsg(sk, psock, msg, len);
+
+                               if (ret > 0) {
+                                       copied += ret;
+                                       len -= ret;
+                                       continue;
+                               }
+                       }
                        goto recv_end;
+               }
 
                rxm = strp_msg(skb);
 
        }
 
        release_sock(sk);
+       if (psock)
+               sk_psock_put(sk, psock);
        return copied ? : err;
 }
 
 
        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
-       skb = tls_wait_data(sk, flags, timeo, &err);
+       skb = tls_wait_data(sk, NULL, flags, timeo, &err);
        if (!skb)
                goto splice_read_end;
 
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+       bool ingress_empty = true;
+       struct sk_psock *psock;
 
-       if (ctx->recv_pkt)
-               return true;
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (psock)
+               ingress_empty = list_empty(&psock->ingress_msg);
+       rcu_read_unlock();
 
-       return false;
+       return !ingress_empty || ctx->recv_pkt;
 }
 
 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+       struct sk_psock *psock;
 
        strp_data_ready(&ctx->strp);
+
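+       /* If BPF queued data on the psock ingress list, also run the saved
+        * data_ready callback so readers waiting on it are woken.
+        */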
+       psock = sk_psock_get(sk);
+       if (psock && !list_empty(&psock->ingress_msg)) {
+               ctx->saved_data_ready(sk);
+               sk_psock_put(sk, psock);
+       }
 }
 
 void tls_sw_free_resources_tx(struct sock *sk)