#include <net/netlink.h>
 #include <net/sock.h>
 
+static const unsigned int mctp_message_maxlen = 64 * 1024;
+
 /* route output callbacks */
 static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb)
 {
        return ret;
 }
 
+static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
+                                         mctp_eid_t local, mctp_eid_t peer,
+                                         u8 tag, gfp_t gfp)
+{
+       struct mctp_sk_key *key;
+
+       key = kzalloc(sizeof(*key), gfp);
+       if (!key)
+               return NULL;
+
+       key->peer_addr = peer;
+       key->local_addr = local;
+       key->tag = tag;
+       key->sk = &msk->sk;
+       spin_lock_init(&key->reasm_lock);
+
+       return key;
+}
+
+static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
+{
+       struct net *net = sock_net(&msk->sk);
+       struct mctp_sk_key *tmp;
+       unsigned long flags;
+       int rc = 0;
+
+       spin_lock_irqsave(&net->mctp.keys_lock, flags);
+
+       hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
+               if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
+                                  key->tag)) {
+                       rc = -EEXIST;
+                       break;
+               }
+       }
+
+       if (!rc) {
+               hlist_add_head(&key->hlist, &net->mctp.keys);
+               hlist_add_head(&key->sklist, &msk->keys);
+       }
+
+       spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+
+       return rc;
+}
+
+/* Must be called with key->reasm_lock, which it will release. Will schedule
+ * the key for an RCU free.
+ */
+static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
+                                  unsigned long flags)
+       __releases(&key->reasm_lock)
+{
+       struct sk_buff *skb;
+
+       skb = key->reasm_head;
+       key->reasm_head = NULL;
+       key->reasm_dead = true;
+       spin_unlock_irqrestore(&key->reasm_lock, flags);
+
+       spin_lock_irqsave(&net->mctp.keys_lock, flags);
+       hlist_del_rcu(&key->hlist);
+       hlist_del_rcu(&key->sklist);
+       spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+       kfree_rcu(key, rcu);
+
+       if (skb)
+               kfree_skb(skb);
+}
+
+static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
+{
+       struct mctp_hdr *hdr = mctp_hdr(skb);
+       u8 exp_seq, this_seq;
+
+       this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT)
+               & MCTP_HDR_SEQ_MASK;
+
+       if (!key->reasm_head) {
+               key->reasm_head = skb;
+               key->reasm_tailp = &(skb_shinfo(skb)->frag_list);
+               key->last_seq = this_seq;
+               return 0;
+       }
+
+       exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK;
+
+       if (this_seq != exp_seq)
+               return -EINVAL;
+
+       if (key->reasm_head->len + skb->len > mctp_message_maxlen)
+               return -EINVAL;
+
+       skb->next = NULL;
+       skb->sk = NULL;
+       *key->reasm_tailp = skb;
+       key->reasm_tailp = &skb->next;
+
+       key->last_seq = this_seq;
+
+       key->reasm_head->data_len += skb->len;
+       key->reasm_head->len += skb->len;
+       key->reasm_head->truesize += skb->truesize;
+
+       return 0;
+}
+
 static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 {
        struct net *net = dev_net(skb->dev);
        struct mctp_sk_key *key;
        struct mctp_sock *msk;
        struct mctp_hdr *mh;
+       unsigned long f;
+       u8 tag, flags;
+       int rc;
 
        msk = NULL;
+       rc = -EINVAL;
 
        /* we may be receiving a locally-routed packet; drop source sk
         * accounting
 
        /* ensure we have enough data for a header and a type */
        if (skb->len < sizeof(struct mctp_hdr) + 1)
-               goto drop;
+               goto out;
 
        /* grab header, advance data ptr */
        mh = mctp_hdr(skb);
        skb_pull(skb, sizeof(struct mctp_hdr));
 
        if (mh->ver != 1)
-               goto drop;
+               goto out;
 
-       /* TODO: reassembly */
-       if ((mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM))
-                               != (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM))
-               goto drop;
+       flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM);
+       tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
 
        rcu_read_lock();
-       /* 1. lookup socket matching (src,dest,tag) */
+
+       /* lookup socket / reasm context, exactly matching (src,dest,tag) */
        key = mctp_lookup_key(net, skb, mh->src);
 
-       /* 2. lookup socket macthing (BCAST,dest,tag) */
-       if (!key)
-               key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
+       if (flags & MCTP_HDR_FLAG_SOM) {
+               if (key) {
+                       msk = container_of(key->sk, struct mctp_sock, sk);
+               } else {
+                       /* first response to a broadcast? do a more general
+                        * key lookup to find the socket, but don't use this
+                        * key for reassembly - we'll create a more specific
+                        * one for future packets if required (ie, !EOM).
+                        */
+                       key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
+                       if (key) {
+                               msk = container_of(key->sk,
+                                                  struct mctp_sock, sk);
+                               key = NULL;
+                       }
+               }
 
-       /* 3. SOM? -> lookup bound socket, conditionally (!EOM) create
-        * mapping for future (1)/(2).
-        */
-       if (key)
-               msk = container_of(key->sk, struct mctp_sock, sk);
-       else if (!msk && (mh->flags_seq_tag & MCTP_HDR_FLAG_SOM))
-               msk = mctp_lookup_bind(net, skb);
+               if (!key && !msk && (tag & MCTP_HDR_FLAG_TO))
+                       msk = mctp_lookup_bind(net, skb);
 
-       if (!msk)
-               goto unlock_drop;
+               if (!msk) {
+                       rc = -ENOENT;
+                       goto out_unlock;
+               }
 
-       sock_queue_rcv_skb(&msk->sk, skb);
+               /* single-packet message? deliver to socket, clean up any
+                * pending key.
+                */
+               if (flags & MCTP_HDR_FLAG_EOM) {
+                       sock_queue_rcv_skb(&msk->sk, skb);
+                       if (key) {
+                               spin_lock_irqsave(&key->reasm_lock, f);
+                               /* we've hit a pending reassembly; not much we
+                                * can do but drop it
+                                */
+                               __mctp_key_unlock_drop(key, net, f);
+                       }
+                       rc = 0;
+                       goto out_unlock;
+               }
 
-       rcu_read_unlock();
+               /* broadcast response or a bind() - create a key for further
+                * packets for this message
+                */
+               if (!key) {
+                       key = mctp_key_alloc(msk, mh->dest, mh->src,
+                                            tag, GFP_ATOMIC);
+                       if (!key) {
+                               rc = -ENOMEM;
+                               goto out_unlock;
+                       }
 
-       return 0;
+                       /* we can queue without the reasm lock here, as the
+                        * key isn't observable yet
+                        */
+                       mctp_frag_queue(key, skb);
+
+                       /* if the key_add fails, we've raced with another
+                        * SOM packet with the same src, dest and tag. There's
+                        * no way to distinguish future packets, so all we
+                        * can do is drop; we'll free the skb on exit from
+                        * this function.
+                        */
+                       rc = mctp_key_add(key, msk);
+                       if (rc)
+                               kfree(key);
+
+               } else {
+                       /* existing key: start reassembly */
+                       spin_lock_irqsave(&key->reasm_lock, f);
+
+                       if (key->reasm_head || key->reasm_dead) {
+                               /* duplicate start? drop everything */
+                               __mctp_key_unlock_drop(key, net, f);
+                               rc = -EEXIST;
+                       } else {
+                               rc = mctp_frag_queue(key, skb);
+                               spin_unlock_irqrestore(&key->reasm_lock, f);
+                       }
+               }
+
+       } else if (key) {
+               /* this packet continues a previous message; reassemble
+                * using the message-specific key
+                */
+
+               spin_lock_irqsave(&key->reasm_lock, f);
+
+               /* we need to be continuing an existing reassembly... */
+               if (!key->reasm_head)
+                       rc = -EINVAL;
+               else
+                       rc = mctp_frag_queue(key, skb);
+
+               /* end of message? deliver to socket, and we're done with
+                * the reassembly/response key
+                */
+               if (!rc && flags & MCTP_HDR_FLAG_EOM) {
+                       sock_queue_rcv_skb(key->sk, key->reasm_head);
+                       key->reasm_head = NULL;
+                       __mctp_key_unlock_drop(key, net, f);
+               } else {
+                       spin_unlock_irqrestore(&key->reasm_lock, f);
+               }
+
+       } else {
+               /* not a start, no matching key */
+               rc = -ENOENT;
+       }
 
-unlock_drop:
+out_unlock:
        rcu_read_unlock();
-drop:
-       kfree_skb(skb);
-       return 0;
+out:
+       if (rc)
+               kfree_skb(skb);
+       return rc;
+}
+
+static unsigned int mctp_route_mtu(struct mctp_route *rt)
+{
+       return rt->mtu ?: READ_ONCE(rt->dev->dev->mtu);
 }
 
 static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
 
        lockdep_assert_held(&mns->keys_lock);
 
-       key->sk = &msk->sk;
-
        /* we hold the net->key_lock here, allowing updates to both
         * then net and sk
         */
        u8 tagbits;
 
        /* be optimistic, alloc now */
-       key = kzalloc(sizeof(*key), GFP_KERNEL);
+       key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL);
        if (!key)
                return -ENOMEM;
-       key->local_addr = saddr;
-       key->peer_addr = daddr;
 
        /* 8 possible tag values */
        tagbits = 0xff;
        return rc;
 }
 
+static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
+                                 unsigned int mtu, u8 tag)
+{
+       const unsigned int hlen = sizeof(struct mctp_hdr);
+       struct mctp_hdr *hdr, *hdr2;
+       unsigned int pos, size;
+       struct sk_buff *skb2;
+       int rc;
+       u8 seq;
+
+       hdr = mctp_hdr(skb);
+       seq = 0;
+       rc = 0;
+
+       if (mtu < hlen + 1) {
+               kfree_skb(skb);
+               return -EMSGSIZE;
+       }
+
+       /* we've got the header */
+       skb_pull(skb, hlen);
+
+       for (pos = 0; pos < skb->len;) {
+               /* size of message payload */
+               size = min(mtu - hlen, skb->len - pos);
+
+               skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL);
+               if (!skb2) {
+                       rc = -ENOMEM;
+                       break;
+               }
+
+               /* generic skb copy */
+               skb2->protocol = skb->protocol;
+               skb2->priority = skb->priority;
+               skb2->dev = skb->dev;
+               memcpy(skb2->cb, skb->cb, sizeof(skb2->cb));
+
+               if (skb->sk)
+                       skb_set_owner_w(skb2, skb->sk);
+
+               /* establish packet */
+               skb_reserve(skb2, MCTP_HEADER_MAXLEN);
+               skb_reset_network_header(skb2);
+               skb_put(skb2, hlen + size);
+               skb2->transport_header = skb2->network_header + hlen;
+
+               /* copy header fields, calculate SOM/EOM flags & seq */
+               hdr2 = mctp_hdr(skb2);
+               hdr2->ver = hdr->ver;
+               hdr2->dest = hdr->dest;
+               hdr2->src = hdr->src;
+               hdr2->flags_seq_tag = tag &
+                       (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
+
+               if (pos == 0)
+                       hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM;
+
+               if (pos + size == skb->len)
+                       hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM;
+
+               hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT;
+
+               /* copy message payload */
+               skb_copy_bits(skb, pos, skb_transport_header(skb2), size);
+
+               /* do route, but don't drop the rt reference */
+               rc = rt->output(rt, skb2);
+               if (rc)
+                       break;
+
+               seq = (seq + 1) & MCTP_HDR_SEQ_MASK;
+               pos += size;
+       }
+
+       mctp_route_release(rt);
+       consume_skb(skb);
+       return rc;
+}
+
 int mctp_local_output(struct sock *sk, struct mctp_route *rt,
                      struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag)
 {
        struct mctp_skb_cb *cb = mctp_cb(skb);
        struct mctp_hdr *hdr;
        unsigned long flags;
+       unsigned int mtu;
        mctp_eid_t saddr;
        int rc;
        u8 tag;
                tag = req_tag;
        }
 
-       /* TODO: we have the route MTU here; packetise */
 
+       skb->protocol = htons(ETH_P_MCTP);
+       skb->priority = 0;
        skb_reset_transport_header(skb);
        skb_push(skb, sizeof(struct mctp_hdr));
        skb_reset_network_header(skb);
+       skb->dev = rt->dev->dev;
+
+       /* cb->net will have been set on initial ingress */
+       cb->src = saddr;
+
+       /* set up common header fields */
        hdr = mctp_hdr(skb);
        hdr->ver = 1;
        hdr->dest = daddr;
        hdr->src = saddr;
-       hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM | /* TODO */
-               tag;
 
-       skb->dev = rt->dev->dev;
-       skb->protocol = htons(ETH_P_MCTP);
-       skb->priority = 0;
+       mtu = mctp_route_mtu(rt);
 
-       /* cb->net will have been set on initial ingress */
-       cb->src = saddr;
-
-       return mctp_do_route(rt, skb);
+       if (skb->len + sizeof(struct mctp_hdr) <= mtu) {
+               hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM |
+                       tag;
+               return mctp_do_route(rt, skb);
+       } else {
+               return mctp_do_fragment_route(rt, skb, mtu, tag);
+       }
 }
 
 /* route management */