#include "smp.h"
 
+#define SMP_ALLOW_CMD(smp, code)       set_bit(code, &smp->allow_cmd)
+#define SMP_DISALLOW_CMD(smp, code)    clear_bit(code, &smp->allow_cmd)
+
 #define SMP_TIMEOUT    msecs_to_jiffies(30000)
 
 #define AUTH_REQ_MASK   0x07
 struct smp_chan {
        struct l2cap_conn       *conn;
        struct delayed_work     security_timer;
+       unsigned long           allow_cmd; /* Bitmask of allowed commands */
 
        u8              preq[7]; /* SMP Pairing Request */
        u8              prsp[7]; /* SMP Pairing Response */
 
        smp_send_cmd(smp->conn, SMP_CMD_PAIRING_CONFIRM, sizeof(cp), &cp);
 
+       if (conn->hcon->out)
+               SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM);
+       else
+               SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM);
+
        return 0;
 }
 
        }
 }
 
+static void smp_allow_key_dist(struct smp_chan *smp)
+{
+       /* Allow the first expected phase 3 PDU. The rest of the PDUs
+        * will be allowed in each PDU handler to ensure we receive
+        * them in the correct order.
+        */
+       if (smp->remote_key_dist & SMP_DIST_ENC_KEY)
+               SMP_ALLOW_CMD(smp, SMP_CMD_ENCRYPT_INFO);
+       else if (smp->remote_key_dist & SMP_DIST_ID_KEY)
+               SMP_ALLOW_CMD(smp, SMP_CMD_IDENT_INFO);
+       else if (smp->remote_key_dist & SMP_DIST_SIGN)
+               SMP_ALLOW_CMD(smp, SMP_CMD_SIGN_INFO);
+}
+
 static void smp_distribute_keys(struct smp_chan *smp)
 {
        struct smp_cmd_pairing *req, *rsp;
        rsp = (void *) &smp->prsp[1];
 
        /* The responder sends its keys first */
-       if (hcon->out && (smp->remote_key_dist & KEY_DIST_MASK))
+       if (hcon->out && (smp->remote_key_dist & KEY_DIST_MASK)) {
+               smp_allow_key_dist(smp);
                return;
+       }
 
        req = (void *) &smp->preq[1];
 
        }
 
        /* If there are still keys to be received wait for them */
-       if (smp->remote_key_dist & KEY_DIST_MASK)
+       if (smp->remote_key_dist & KEY_DIST_MASK) {
+               smp_allow_key_dist(smp);
                return;
+       }
 
        set_bit(SMP_FLAG_COMPLETE, &smp->flags);
        smp_notify_keys(conn);
        smp->conn = conn;
        chan->data = smp;
 
+       SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_FAIL);
+
        INIT_DELAYED_WORK(&smp->security_timer, smp_timeout);
 
        hci_conn_hold(conn->hcon);
            (req->auth_req & SMP_AUTH_BONDING))
                return SMP_PAIRING_NOTSUPP;
 
+       SMP_DISALLOW_CMD(smp, SMP_CMD_PAIRING_REQ);
+
        smp->preq[0] = SMP_CMD_PAIRING_REQ;
        memcpy(&smp->preq[1], req, sizeof(*req));
        skb_pull(skb, sizeof(*req));
        memcpy(&smp->prsp[1], &rsp, sizeof(rsp));
 
        smp_send_cmd(conn, SMP_CMD_PAIRING_RSP, sizeof(rsp), &rsp);
+       SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM);
 
        /* Request setup of TK */
        ret = tk_request(conn, 0, auth, rsp.io_capability, req->io_capability);
        if (conn->hcon->role != HCI_ROLE_MASTER)
                return SMP_CMD_NOTSUPP;
 
+       SMP_DISALLOW_CMD(smp, SMP_CMD_PAIRING_RSP);
+
        skb_pull(skb, sizeof(*rsp));
 
        req = (void *) &smp->preq[1];
        if (skb->len < sizeof(smp->pcnf))
                return SMP_INVALID_PARAMS;
 
+       SMP_DISALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM);
+
        memcpy(smp->pcnf, skb->data, sizeof(smp->pcnf));
        skb_pull(skb, sizeof(smp->pcnf));
 
-       if (conn->hcon->out)
+       if (conn->hcon->out) {
                smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, sizeof(smp->prnd),
                             smp->prnd);
-       else if (test_bit(SMP_FLAG_TK_VALID, &smp->flags))
+               SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM);
+               return 0;
+       }
+
+       if (test_bit(SMP_FLAG_TK_VALID, &smp->flags))
                return smp_confirm(smp);
        else
                set_bit(SMP_FLAG_CFM_PENDING, &smp->flags);
        if (skb->len < sizeof(smp->rrnd))
                return SMP_INVALID_PARAMS;
 
+       SMP_DISALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM);
+
        memcpy(smp->rrnd, skb->data, sizeof(smp->rrnd));
        skb_pull(skb, sizeof(smp->rrnd));
 
        struct smp_cmd_security_req *rp = (void *) skb->data;
        struct smp_cmd_pairing cp;
        struct hci_conn *hcon = conn->hcon;
-       struct l2cap_chan *chan = conn->smp;
        struct smp_chan *smp;
        u8 sec_level;
 
        if (smp_ltk_encrypt(conn, hcon->pending_sec_level))
                return 0;
 
-       /* If SMP is already in progress ignore this request */
-       if (chan->data)
-               return 0;
-
        smp = smp_chan_create(conn);
        if (!smp)
                return SMP_UNSPECIFIED;
        memcpy(&smp->preq[1], &cp, sizeof(cp));
 
        smp_send_cmd(conn, SMP_CMD_PAIRING_REQ, sizeof(cp), &cp);
+       SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RSP);
 
        return 0;
 }
                memcpy(&smp->preq[1], &cp, sizeof(cp));
 
                smp_send_cmd(conn, SMP_CMD_PAIRING_REQ, sizeof(cp), &cp);
+               SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RSP);
        } else {
                struct smp_cmd_security_req cp;
                cp.auth_req = authreq;
                smp_send_cmd(conn, SMP_CMD_SECURITY_REQ, sizeof(cp), &cp);
+               SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_REQ);
        }
 
        set_bit(SMP_FLAG_INITIATOR, &smp->flags);
        if (skb->len < sizeof(*rp))
                return SMP_INVALID_PARAMS;
 
-       /* Ignore this PDU if it wasn't requested */
-       if (!(smp->remote_key_dist & SMP_DIST_ENC_KEY))
-               return 0;
+       SMP_DISALLOW_CMD(smp, SMP_CMD_ENCRYPT_INFO);
+       SMP_ALLOW_CMD(smp, SMP_CMD_MASTER_IDENT);
 
        skb_pull(skb, sizeof(*rp));
 
        if (skb->len < sizeof(*rp))
                return SMP_INVALID_PARAMS;
 
-       /* Ignore this PDU if it wasn't requested */
-       if (!(smp->remote_key_dist & SMP_DIST_ENC_KEY))
-               return 0;
-
        /* Mark the information as received */
        smp->remote_key_dist &= ~SMP_DIST_ENC_KEY;
 
+       SMP_DISALLOW_CMD(smp, SMP_CMD_MASTER_IDENT);
+       if (smp->remote_key_dist & SMP_DIST_ID_KEY)
+               SMP_ALLOW_CMD(smp, SMP_CMD_IDENT_INFO);
+
        skb_pull(skb, sizeof(*rp));
 
        hci_dev_lock(hdev);
        if (skb->len < sizeof(*info))
                return SMP_INVALID_PARAMS;
 
-       /* Ignore this PDU if it wasn't requested */
-       if (!(smp->remote_key_dist & SMP_DIST_ID_KEY))
-               return 0;
+       SMP_DISALLOW_CMD(smp, SMP_CMD_IDENT_INFO);
+       SMP_ALLOW_CMD(smp, SMP_CMD_IDENT_ADDR_INFO);
 
        skb_pull(skb, sizeof(*info));
 
        if (skb->len < sizeof(*info))
                return SMP_INVALID_PARAMS;
 
-       /* Ignore this PDU if it wasn't requested */
-       if (!(smp->remote_key_dist & SMP_DIST_ID_KEY))
-               return 0;
-
        /* Mark the information as received */
        smp->remote_key_dist &= ~SMP_DIST_ID_KEY;
 
+       SMP_DISALLOW_CMD(smp, SMP_CMD_IDENT_ADDR_INFO);
+       if (smp->remote_key_dist & SMP_DIST_SIGN)
+               SMP_ALLOW_CMD(smp, SMP_CMD_SIGN_INFO);
+
        skb_pull(skb, sizeof(*info));
 
        hci_dev_lock(hcon->hdev);
        if (skb->len < sizeof(*rp))
                return SMP_INVALID_PARAMS;
 
-       /* Ignore this PDU if it wasn't requested */
-       if (!(smp->remote_key_dist & SMP_DIST_SIGN))
-               return 0;
-
        /* Mark the information as received */
        smp->remote_key_dist &= ~SMP_DIST_SIGN;
 
+       SMP_DISALLOW_CMD(smp, SMP_CMD_SIGN_INFO);
+
        skb_pull(skb, sizeof(*rp));
 
        hci_dev_lock(hdev);
 {
        struct l2cap_conn *conn = chan->conn;
        struct hci_conn *hcon = conn->hcon;
+       struct smp_chan *smp;
        __u8 code, reason;
        int err = 0;
 
        code = skb->data[0];
        skb_pull(skb, sizeof(code));
 
-       /*
-        * The SMP context must be initialized for all other PDUs except
-        * pairing and security requests. If we get any other PDU when
-        * not initialized simply disconnect (done if this function
-        * returns an error).
+       smp = chan->data;
+
+       if (code > SMP_CMD_MAX)
+               goto drop;
+
+       if (smp && !test_bit(code, &smp->allow_cmd))
+               goto drop;
+
+       /* If we don't have a context the only allowed commands are
+        * pairing request and security request.
         */
-       if (code != SMP_CMD_PAIRING_REQ && code != SMP_CMD_SECURITY_REQ &&
-           !chan->data) {
-               BT_ERR("Unexpected SMP command 0x%02x. Disconnecting.", code);
-               err = -EOPNOTSUPP;
-               goto done;
-       }
+       if (!smp && code != SMP_CMD_PAIRING_REQ && code != SMP_CMD_SECURITY_REQ)
+               goto drop;
 
        switch (code) {
        case SMP_CMD_PAIRING_REQ:
        }
 
        return err;
+
+drop:
+       BT_ERR("%s unexpected SMP command 0x%02x from %pMR", hcon->hdev->name,
+              code, &hcon->dst);
+       kfree_skb(skb);
+       return 0;
 }
 
 static void smp_teardown_cb(struct l2cap_chan *chan, int err)