#include <linux/skbuff.h>
 #include <linux/workqueue.h>
 #include <linux/list.h>
+#include <linux/mm.h>
 #include <net/strparser.h>
 #include <net/tcp.h>
 
 struct bpf_stab {
        struct bpf_map map;
        struct sock **sock_map;
+       struct bpf_prog *bpf_tx_msg;
        struct bpf_prog *bpf_parse;
        struct bpf_prog *bpf_verdict;
 };
        int save_off;
        struct sk_buff *save_skb;
 
+       /* datapath variables for tx_msg ULP */
+       struct sock *sk_redir;
+       int apply_bytes;
+       int cork_bytes;
+       int sg_size;
+       int eval;
+       struct sk_msg_buff *cork;
+
        struct strparser strp;
+       struct bpf_prog *bpf_tx_msg;
        struct bpf_prog *bpf_parse;
        struct bpf_prog *bpf_verdict;
        struct list_head maps;
        void (*save_write_space)(struct sock *sk);
 };
 
+static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
+                           int offset, size_t size, int flags);
+
 static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
 {
        return rcu_dereference_sk_user_data(sk);
 
        psock->save_close = sk->sk_prot->close;
        psock->sk_proto = sk->sk_prot;
+
+       if (psock->bpf_tx_msg) {
+               tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
+               tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
+       }
+
        sk->sk_prot = &tcp_bpf_proto;
        rcu_read_unlock();
        return 0;
 }
 
+static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+
 static void bpf_tcp_release(struct sock *sk)
 {
        struct smap_psock *psock;
 
        rcu_read_lock();
        psock = smap_psock_sk(sk);
+       if (unlikely(!psock))
+               goto out;
 
-       if (likely(psock)) {
-               sk->sk_prot = psock->sk_proto;
-               psock->sk_proto = NULL;
+       if (psock->cork) {
+               free_start_sg(psock->sock, psock->cork);
+               kfree(psock->cork);
+               psock->cork = NULL;
        }
+
+       sk->sk_prot = psock->sk_proto;
+       psock->sk_proto = NULL;
+out:
        rcu_read_unlock();
 }
 
-static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-
 static void bpf_tcp_close(struct sock *sk, long timeout)
 {
        void (*close_fun)(struct sock *sk, long timeout);
        __SK_DROP = 0,
        __SK_PASS,
        __SK_REDIRECT,
+       __SK_NONE,
 };
 
 static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
        .release        = bpf_tcp_release,
 };
 
+static int memcopy_from_iter(struct sock *sk,
+                            struct sk_msg_buff *md,
+                            struct iov_iter *from, int bytes)
+{
+       struct scatterlist *sg = md->sg_data;
+       int i = md->sg_curr, rc = -ENOSPC;
+
+       do {
+               int copy;
+               char *to;
+
+               if (md->sg_copybreak >= sg[i].length) {
+                       md->sg_copybreak = 0;
+
+                       if (++i == MAX_SKB_FRAGS)
+                               i = 0;
+
+                       if (i == md->sg_end)
+                               break;
+               }
+
+               copy = sg[i].length - md->sg_copybreak;
+               to = sg_virt(&sg[i]) + md->sg_copybreak;
+               md->sg_copybreak += copy;
+
+               if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
+                       rc = copy_from_iter_nocache(to, copy, from);
+               else
+                       rc = copy_from_iter(to, copy, from);
+
+               if (rc != copy) {
+                       rc = -EFAULT;
+                       goto out;
+               }
+
+               bytes -= copy;
+               if (!bytes)
+                       break;
+
+               md->sg_copybreak = 0;
+               if (++i == MAX_SKB_FRAGS)
+                       i = 0;
+       } while (i != md->sg_end);
+out:
+       md->sg_curr = i;
+       return rc;
+}
+
+static int bpf_tcp_push(struct sock *sk, int apply_bytes,
+                       struct sk_msg_buff *md,
+                       int flags, bool uncharge)
+{
+       bool apply = apply_bytes;
+       struct scatterlist *sg;
+       int offset, ret = 0;
+       struct page *p;
+       size_t size;
+
+       while (1) {
+               sg = md->sg_data + md->sg_start;
+               size = (apply && apply_bytes < sg->length) ?
+                       apply_bytes : sg->length;
+               offset = sg->offset;
+
+               tcp_rate_check_app_limited(sk);
+               p = sg_page(sg);
+retry:
+               ret = do_tcp_sendpages(sk, p, offset, size, flags);
+               if (ret != size) {
+                       if (ret > 0) {
+                               if (apply)
+                                       apply_bytes -= ret;
+                               size -= ret;
+                               offset += ret;
+                               if (uncharge)
+                                       sk_mem_uncharge(sk, ret);
+                               goto retry;
+                       }
+
+                       sg->length = size;
+                       sg->offset = offset;
+                       return ret;
+               }
+
+               if (apply)
+                       apply_bytes -= ret;
+               sg->offset += ret;
+               sg->length -= ret;
+               if (uncharge)
+                       sk_mem_uncharge(sk, ret);
+
+               if (!sg->length) {
+                       put_page(p);
+                       md->sg_start++;
+                       if (md->sg_start == MAX_SKB_FRAGS)
+                               md->sg_start = 0;
+                       memset(sg, 0, sizeof(*sg));
+
+                       if (md->sg_start == md->sg_end)
+                               break;
+               }
+
+               if (apply && !apply_bytes)
+                       break;
+       }
+       return 0;
+}
+
+static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
+{
+       struct scatterlist *sg = md->sg_data + md->sg_start;
+
+       if (md->sg_copy[md->sg_start]) {
+               md->data = md->data_end = 0;
+       } else {
+               md->data = sg_virt(sg);
+               md->data_end = md->data + sg->length;
+       }
+}
+
+static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
+{
+       struct scatterlist *sg = md->sg_data;
+       int i = md->sg_start;
+
+       do {
+               int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
+
+               sk_mem_uncharge(sk, uncharge);
+               bytes -= uncharge;
+               if (!bytes)
+                       break;
+               i++;
+               if (i == MAX_SKB_FRAGS)
+                       i = 0;
+       } while (i != md->sg_end);
+}
+
+static void free_bytes_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
+{
+       struct scatterlist *sg = md->sg_data;
+       int i = md->sg_start, free;
+
+       while (bytes && sg[i].length) {
+               free = sg[i].length;
+               if (bytes < free) {
+                       sg[i].length -= bytes;
+                       sg[i].offset += bytes;
+                       sk_mem_uncharge(sk, bytes);
+                       break;
+               }
+
+               sk_mem_uncharge(sk, sg[i].length);
+               put_page(sg_page(&sg[i]));
+               bytes -= sg[i].length;
+               sg[i].length = 0;
+               sg[i].page_link = 0;
+               sg[i].offset = 0;
+               i++;
+
+               if (i == MAX_SKB_FRAGS)
+                       i = 0;
+       }
+}
+
+static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
+{
+       struct scatterlist *sg = md->sg_data;
+       int i = start, free = 0;
+
+       while (sg[i].length) {
+               free += sg[i].length;
+               sk_mem_uncharge(sk, sg[i].length);
+               put_page(sg_page(&sg[i]));
+               sg[i].length = 0;
+               sg[i].page_link = 0;
+               sg[i].offset = 0;
+               i++;
+
+               if (i == MAX_SKB_FRAGS)
+                       i = 0;
+       }
+
+       return free;
+}
+
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
+{
+       int free = free_sg(sk, md->sg_start, md);
+
+       md->sg_start = md->sg_end;
+       return free;
+}
+
+static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
+{
+       return free_sg(sk, md->sg_curr, md);
+}
+
+static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
+{
+       return ((_rc == SK_PASS) ?
+              (md->map ? __SK_REDIRECT : __SK_PASS) :
+              __SK_DROP);
+}
+
+static unsigned int smap_do_tx_msg(struct sock *sk,
+                                  struct smap_psock *psock,
+                                  struct sk_msg_buff *md)
+{
+       struct bpf_prog *prog;
+       unsigned int rc, _rc;
+
+       preempt_disable();
+       rcu_read_lock();
+
+       /* If the policy was removed mid-send then default to 'accept' */
+       prog = READ_ONCE(psock->bpf_tx_msg);
+       if (unlikely(!prog)) {
+               _rc = SK_PASS;
+               goto verdict;
+       }
+
+       bpf_compute_data_pointers_sg(md);
+       rc = (*prog->bpf_func)(md, prog->insnsi);
+       psock->apply_bytes = md->apply_bytes;
+
+       /* Moving return codes from UAPI namespace into internal namespace */
+       _rc = bpf_map_msg_verdict(rc, md);
+
+       /* The psock has a refcount on the sock but not on the map and because
+        * we need to drop rcu read lock here its possible the map could be
+        * removed between here and when we need it to execute the sock
+        * redirect. So do the map lookup now for future use.
+        */
+       if (_rc == __SK_REDIRECT) {
+               if (psock->sk_redir)
+                       sock_put(psock->sk_redir);
+               psock->sk_redir = do_msg_redirect_map(md);
+               if (!psock->sk_redir) {
+                       _rc = __SK_DROP;
+                       goto verdict;
+               }
+               sock_hold(psock->sk_redir);
+       }
+verdict:
+       rcu_read_unlock();
+       preempt_enable();
+
+       return _rc;
+}
+
+static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
+                                      struct sk_msg_buff *md,
+                                      int flags)
+{
+       struct smap_psock *psock;
+       struct scatterlist *sg;
+       int i, err, free = 0;
+
+       sg = md->sg_data;
+
+       rcu_read_lock();
+       psock = smap_psock_sk(sk);
+       if (unlikely(!psock))
+               goto out_rcu;
+
+       if (!refcount_inc_not_zero(&psock->refcnt))
+               goto out_rcu;
+
+       rcu_read_unlock();
+       lock_sock(sk);
+       err = bpf_tcp_push(sk, send, md, flags, false);
+       release_sock(sk);
+       smap_release_sock(psock, sk);
+       if (unlikely(err))
+               goto out;
+       return 0;
+out_rcu:
+       rcu_read_unlock();
+out:
+       i = md->sg_start;
+       while (sg[i].length) {
+               free += sg[i].length;
+               put_page(sg_page(&sg[i]));
+               sg[i].length = 0;
+               i++;
+               if (i == MAX_SKB_FRAGS)
+                       i = 0;
+       }
+       return free;
+}
+
+static inline void bpf_md_init(struct smap_psock *psock)
+{
+       if (!psock->apply_bytes) {
+               psock->eval =  __SK_NONE;
+               if (psock->sk_redir) {
+                       sock_put(psock->sk_redir);
+                       psock->sk_redir = NULL;
+               }
+       }
+}
+
+static void apply_bytes_dec(struct smap_psock *psock, int i)
+{
+       if (psock->apply_bytes) {
+               if (psock->apply_bytes < i)
+                       psock->apply_bytes = 0;
+               else
+                       psock->apply_bytes -= i;
+       }
+}
+
+static int bpf_exec_tx_verdict(struct smap_psock *psock,
+                              struct sk_msg_buff *m,
+                              struct sock *sk,
+                              int *copied, int flags)
+{
+       bool cork = false, enospc = (m->sg_start == m->sg_end);
+       struct sock *redir;
+       int err = 0;
+       int send;
+
+more_data:
+       if (psock->eval == __SK_NONE)
+               psock->eval = smap_do_tx_msg(sk, psock, m);
+
+       if (m->cork_bytes &&
+           m->cork_bytes > psock->sg_size && !enospc) {
+               psock->cork_bytes = m->cork_bytes - psock->sg_size;
+               if (!psock->cork) {
+                       psock->cork = kcalloc(1,
+                                       sizeof(struct sk_msg_buff),
+                                       GFP_ATOMIC | __GFP_NOWARN);
+
+                       if (!psock->cork) {
+                               err = -ENOMEM;
+                               goto out_err;
+                       }
+               }
+               memcpy(psock->cork, m, sizeof(*m));
+               goto out_err;
+       }
+
+       send = psock->sg_size;
+       if (psock->apply_bytes && psock->apply_bytes < send)
+               send = psock->apply_bytes;
+
+       switch (psock->eval) {
+       case __SK_PASS:
+               err = bpf_tcp_push(sk, send, m, flags, true);
+               if (unlikely(err)) {
+                       *copied -= free_start_sg(sk, m);
+                       break;
+               }
+
+               apply_bytes_dec(psock, send);
+               psock->sg_size -= send;
+               break;
+       case __SK_REDIRECT:
+               redir = psock->sk_redir;
+               apply_bytes_dec(psock, send);
+
+               if (psock->cork) {
+                       cork = true;
+                       psock->cork = NULL;
+               }
+
+               return_mem_sg(sk, send, m);
+               release_sock(sk);
+
+               err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
+               lock_sock(sk);
+
+               if (cork) {
+                       free_start_sg(sk, m);
+                       kfree(m);
+                       m = NULL;
+               }
+               if (unlikely(err))
+                       *copied -= err;
+               else
+                       psock->sg_size -= send;
+               break;
+       case __SK_DROP:
+       default:
+               free_bytes_sg(sk, send, m);
+               apply_bytes_dec(psock, send);
+               *copied -= send;
+               psock->sg_size -= send;
+               err = -EACCES;
+               break;
+       }
+
+       if (likely(!err)) {
+               bpf_md_init(psock);
+               if (m &&
+                   m->sg_data[m->sg_start].page_link &&
+                   m->sg_data[m->sg_start].length)
+                       goto more_data;
+       }
+
+out_err:
+       return err;
+}
+
+static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+       int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
+       struct sk_msg_buff md = {0};
+       unsigned int sg_copy = 0;
+       struct smap_psock *psock;
+       int copied = 0, err = 0;
+       struct scatterlist *sg;
+       long timeo;
+
+       /* Its possible a sock event or user removed the psock _but_ the ops
+        * have not been reprogrammed yet so we get here. In this case fallback
+        * to tcp_sendmsg. Note this only works because we _only_ ever allow
+        * a single ULP there is no hierarchy here.
+        */
+       rcu_read_lock();
+       psock = smap_psock_sk(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               return tcp_sendmsg(sk, msg, size);
+       }
+
+       /* Increment the psock refcnt to ensure its not released while sending a
+        * message. Required because sk lookup and bpf programs are used in
+        * separate rcu critical sections. Its OK if we lose the map entry
+        * but we can't lose the sock reference.
+        */
+       if (!refcount_inc_not_zero(&psock->refcnt)) {
+               rcu_read_unlock();
+               return tcp_sendmsg(sk, msg, size);
+       }
+
+       sg = md.sg_data;
+       sg_init_table(sg, MAX_SKB_FRAGS);
+       rcu_read_unlock();
+
+       lock_sock(sk);
+       timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+
+       while (msg_data_left(msg)) {
+               struct sk_msg_buff *m;
+               bool enospc = false;
+               int copy;
+
+               if (sk->sk_err) {
+                       err = sk->sk_err;
+                       goto out_err;
+               }
+
+               copy = msg_data_left(msg);
+               if (!sk_stream_memory_free(sk))
+                       goto wait_for_sndbuf;
+
+               m = psock->cork_bytes ? psock->cork : &md;
+               m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
+               err = sk_alloc_sg(sk, copy, m->sg_data,
+                                 m->sg_start, &m->sg_end, &sg_copy,
+                                 m->sg_end - 1);
+               if (err) {
+                       if (err != -ENOSPC)
+                               goto wait_for_memory;
+                       enospc = true;
+                       copy = sg_copy;
+               }
+
+               err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
+               if (err < 0) {
+                       free_curr_sg(sk, m);
+                       goto out_err;
+               }
+
+               psock->sg_size += copy;
+               copied += copy;
+               sg_copy = 0;
+
+               /* When bytes are being corked skip running BPF program and
+                * applying verdict unless there is no more buffer space. In
+                * the ENOSPC case simply run BPF prorgram with currently
+                * accumulated data. We don't have much choice at this point
+                * we could try extending the page frags or chaining complex
+                * frags but even in these cases _eventually_ we will hit an
+                * OOM scenario. More complex recovery schemes may be
+                * implemented in the future, but BPF programs must handle
+                * the case where apply_cork requests are not honored. The
+                * canonical method to verify this is to check data length.
+                */
+               if (psock->cork_bytes) {
+                       if (copy > psock->cork_bytes)
+                               psock->cork_bytes = 0;
+                       else
+                               psock->cork_bytes -= copy;
+
+                       if (psock->cork_bytes && !enospc)
+                               goto out_cork;
+
+                       /* All cork bytes accounted for re-run filter */
+                       psock->eval = __SK_NONE;
+                       psock->cork_bytes = 0;
+               }
+
+               err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
+               if (unlikely(err < 0))
+                       goto out_err;
+               continue;
+wait_for_sndbuf:
+               set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+               err = sk_stream_wait_memory(sk, &timeo);
+               if (err)
+                       goto out_err;
+       }
+out_err:
+       if (err < 0)
+               err = sk_stream_error(sk, msg->msg_flags, err);
+out_cork:
+       release_sock(sk);
+       smap_release_sock(psock, sk);
+       return copied ? copied : err;
+}
+
+static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
+                           int offset, size_t size, int flags)
+{
+       struct sk_msg_buff md = {0}, *m = NULL;
+       int err = 0, copied = 0;
+       struct smap_psock *psock;
+       struct scatterlist *sg;
+       bool enospc = false;
+
+       rcu_read_lock();
+       psock = smap_psock_sk(sk);
+       if (unlikely(!psock))
+               goto accept;
+
+       if (!refcount_inc_not_zero(&psock->refcnt))
+               goto accept;
+       rcu_read_unlock();
+
+       lock_sock(sk);
+
+       if (psock->cork_bytes)
+               m = psock->cork;
+       else
+               m = &md;
+
+       /* Catch case where ring is full and sendpage is stalled. */
+       if (unlikely(m->sg_end == m->sg_start &&
+           m->sg_data[m->sg_end].length))
+               goto out_err;
+
+       psock->sg_size += size;
+       sg = &m->sg_data[m->sg_end];
+       sg_set_page(sg, page, size, offset);
+       get_page(page);
+       m->sg_copy[m->sg_end] = true;
+       sk_mem_charge(sk, size);
+       m->sg_end++;
+       copied = size;
+
+       if (m->sg_end == MAX_SKB_FRAGS)
+               m->sg_end = 0;
+
+       if (m->sg_end == m->sg_start)
+               enospc = true;
+
+       if (psock->cork_bytes) {
+               if (size > psock->cork_bytes)
+                       psock->cork_bytes = 0;
+               else
+                       psock->cork_bytes -= size;
+
+               if (psock->cork_bytes && !enospc)
+                       goto out_err;
+
+               /* All cork bytes accounted for re-run filter */
+               psock->eval = __SK_NONE;
+               psock->cork_bytes = 0;
+       }
+
+       err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
+out_err:
+       release_sock(sk);
+       smap_release_sock(psock, sk);
+       return copied ? copied : err;
+accept:
+       rcu_read_unlock();
+       return tcp_sendpage(sk, page, offset, size, flags);
+}
+
+static void bpf_tcp_msg_add(struct smap_psock *psock,
+                           struct sock *sk,
+                           struct bpf_prog *tx_msg)
+{
+       struct bpf_prog *orig_tx_msg;
+
+       orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
+       if (orig_tx_msg)
+               bpf_prog_put(orig_tx_msg);
+}
+
 static int bpf_tcp_ulp_register(void)
 {
        tcp_bpf_proto = tcp_prot;
        tcp_bpf_proto.close = bpf_tcp_close;
+       /* Once BPF TX ULP is registered it is never unregistered. It
+        * will be in the ULP list for the lifetime of the system. Doing
+        * duplicate registers is not a problem.
+        */
        return tcp_register_ulp(&bpf_tcp_ulp_ops);
 }
 
        return rc;
 }
 
-
 static int smap_read_sock_done(struct strparser *strp, int err)
 {
        return err;
                bpf_prog_put(psock->bpf_parse);
        if (psock->bpf_verdict)
                bpf_prog_put(psock->bpf_verdict);
+       if (psock->bpf_tx_msg)
+               bpf_prog_put(psock->bpf_tx_msg);
+
+       if (psock->cork) {
+               free_start_sg(psock->sock, psock->cork);
+               kfree(psock->cork);
+       }
 
        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
                list_del(&e->list);
                kfree(e);
        }
 
+       if (psock->sk_redir)
+               sock_put(psock->sk_redir);
+
        sock_put(psock->sock);
        kfree(psock);
 }
        if (!psock)
                return ERR_PTR(-ENOMEM);
 
+       psock->eval =  __SK_NONE;
        psock->sock = sock;
        skb_queue_head_init(&psock->rxqueue);
        INIT_WORK(&psock->tx_work, smap_tx_work);
 {
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
        struct smap_psock_map_entry *e = NULL;
-       struct bpf_prog *verdict, *parse;
+       struct bpf_prog *verdict, *parse, *tx_msg;
        struct sock *osock, *sock;
        struct smap_psock *psock;
        u32 i = *(u32 *)key;
+       bool new = false;
        int err;
 
        if (unlikely(flags > BPF_EXIST))
         */
        verdict = READ_ONCE(stab->bpf_verdict);
        parse = READ_ONCE(stab->bpf_parse);
+       tx_msg = READ_ONCE(stab->bpf_tx_msg);
 
        if (parse && verdict) {
                /* bpf prog refcnt may be zero if a concurrent attach operation
                }
        }
 
+       if (tx_msg) {
+               tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);
+               if (IS_ERR(tx_msg)) {
+                       if (verdict)
+                               bpf_prog_put(verdict);
+                       if (parse)
+                               bpf_prog_put(parse);
+                       return PTR_ERR(tx_msg);
+               }
+       }
+
        write_lock_bh(&sock->sk_callback_lock);
        psock = smap_psock_sk(sock);
 
                        err = -EBUSY;
                        goto out_progs;
                }
-               refcount_inc(&psock->refcnt);
+               if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
+                       err = -EBUSY;
+                       goto out_progs;
+               }
+               if (!refcount_inc_not_zero(&psock->refcnt)) {
+                       err = -EAGAIN;
+                       goto out_progs;
+               }
        } else {
                psock = smap_init_psock(sock, stab);
                if (IS_ERR(psock)) {
                        goto out_progs;
                }
 
-               err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
-               if (err)
-                       goto out_progs;
-
                set_bit(SMAP_TX_RUNNING, &psock->state);
+               new = true;
        }
 
        e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
        /* 3. At this point we have a reference to a valid psock that is
         * running. Attach any BPF programs needed.
         */
+       if (tx_msg)
+               bpf_tcp_msg_add(psock, sock, tx_msg);
+       if (new) {
+               err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
+               if (err)
+                       goto out_free;
+       }
+
        if (parse && verdict && !psock->strp_enabled) {
                err = smap_init_sock(psock, sock);
                if (err)
                struct smap_psock *opsock = smap_psock_sk(osock);
 
                write_lock_bh(&osock->sk_callback_lock);
-               if (osock != sock && parse)
-                       smap_stop_sock(opsock, osock);
                smap_list_remove(opsock, &stab->sock_map[i]);
                smap_release_sock(opsock, osock);
                write_unlock_bh(&osock->sk_callback_lock);
                bpf_prog_put(verdict);
        if (parse)
                bpf_prog_put(parse);
+       if (tx_msg)
+               bpf_prog_put(tx_msg);
        write_unlock_bh(&sock->sk_callback_lock);
        kfree(e);
        return err;
                return -EINVAL;
 
        switch (type) {
+       case BPF_SK_MSG_VERDICT:
+               orig = xchg(&stab->bpf_tx_msg, prog);
+               break;
        case BPF_SK_SKB_STREAM_PARSER:
                orig = xchg(&stab->bpf_parse, prog);
                break;
        orig = xchg(&stab->bpf_verdict, NULL);
        if (orig)
                bpf_prog_put(orig);
+
+       orig = xchg(&stab->bpf_tx_msg, NULL);
+       if (orig)
+               bpf_prog_put(orig);
 }
 
 const struct bpf_map_ops sock_map_ops = {
 
        .arg4_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
+          struct bpf_map *, map, u32, key, u64, flags)
+{
+       /* If user passes invalid input drop the packet. */
+       if (unlikely(flags))
+               return SK_DROP;
+
+       msg->key = key;
+       msg->flags = flags;
+       msg->map = map;
+
+       return SK_PASS;
+}
+
+struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
+{
+       struct sock *sk = NULL;
+
+       if (msg->map) {
+               sk = __sock_map_lookup_elem(msg->map, msg->key);
+
+               msg->key = 0;
+               msg->map = NULL;
+       }
+
+       return sk;
+}
+
+static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
+       .func           = bpf_msg_redirect_map,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_CONST_MAP_PTR,
+       .arg3_type      = ARG_ANYTHING,
+       .arg4_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 {
        return task_get_classid(skb);
        }
 }
 
+static const struct bpf_func_proto *sk_msg_func_proto(enum bpf_func_id func_id)
+{
+       switch (func_id) {
+       case BPF_FUNC_msg_redirect_map:
+               return &bpf_msg_redirect_map_proto;
+       default:
+               return bpf_base_func_proto(func_id);
+       }
+}
+
 static const struct bpf_func_proto *sk_skb_func_proto(enum bpf_func_id func_id)
 {
        switch (func_id) {
        return bpf_skb_is_valid_access(off, size, type, info);
 }
 
+static bool sk_msg_is_valid_access(int off, int size,
+                                  enum bpf_access_type type,
+                                  struct bpf_insn_access_aux *info)
+{
+       if (type == BPF_WRITE)
+               return false;
+
+       switch (off) {
+       case offsetof(struct sk_msg_md, data):
+               info->reg_type = PTR_TO_PACKET;
+               break;
+       case offsetof(struct sk_msg_md, data_end):
+               info->reg_type = PTR_TO_PACKET_END;
+               break;
+       }
+
+       if (off < 0 || off >= sizeof(struct sk_msg_md))
+               return false;
+       if (off % size != 0)
+               return false;
+       if (size != sizeof(__u64))
+               return false;
+
+       return true;
+}
+
 static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                                  const struct bpf_insn *si,
                                  struct bpf_insn *insn_buf,
        return insn - insn_buf;
 }
 
+static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
+                                    const struct bpf_insn *si,
+                                    struct bpf_insn *insn_buf,
+                                    struct bpf_prog *prog, u32 *target_size)
+{
+       struct bpf_insn *insn = insn_buf;
+
+       switch (si->off) {
+       case offsetof(struct sk_msg_md, data):
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_msg_buff, data));
+               break;
+       case offsetof(struct sk_msg_md, data_end):
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_msg_buff, data_end));
+               break;
+       }
+
+       return insn - insn_buf;
+}
+
 const struct bpf_verifier_ops sk_filter_verifier_ops = {
        .get_func_proto         = sk_filter_func_proto,
        .is_valid_access        = sk_filter_is_valid_access,
 const struct bpf_prog_ops sk_skb_prog_ops = {
 };
 
+const struct bpf_verifier_ops sk_msg_verifier_ops = {
+       .get_func_proto         = sk_msg_func_proto,
+       .is_valid_access        = sk_msg_is_valid_access,
+       .convert_ctx_access     = sk_msg_convert_ctx_access,
+};
+
+const struct bpf_prog_ops sk_msg_prog_ops = {
+};
+
 int sk_detach_filter(struct sock *sk)
 {
        int ret = -ENOENT;