return memcmp(&a1, &a2, sizeof(a1));
 }
 
-static int __tcp_ao_key_cmp(const struct tcp_ao_key *key,
+static int __tcp_ao_key_cmp(const struct tcp_ao_key *key, int l3index,
                            const union tcp_ao_addr *addr, u8 prefixlen,
                            int family, int sndid, int rcvid)
 {
                return (key->sndid > sndid) ? 1 : -1;
        if (rcvid >= 0 && key->rcvid != rcvid)
                return (key->rcvid > rcvid) ? 1 : -1;
+       if (l3index >= 0 && (key->keyflags & TCP_AO_KEYF_IFINDEX)) {
+               if (key->l3index != l3index)
+                       return (key->l3index > l3index) ? 1 : -1;
+       }
 
        if (family == AF_UNSPEC)
                return 0;
        return -1;
 }
 
-static int tcp_ao_key_cmp(const struct tcp_ao_key *key,
+static int tcp_ao_key_cmp(const struct tcp_ao_key *key, int l3index,
                          const union tcp_ao_addr *addr, u8 prefixlen,
                          int family, int sndid, int rcvid)
 {
        if (family == AF_INET6 && ipv6_addr_v4mapped(&addr->a6)) {
                __be32 addr4 = addr->a6.s6_addr32[3];
 
-               return __tcp_ao_key_cmp(key, (union tcp_ao_addr *)&addr4,
+               return __tcp_ao_key_cmp(key, l3index,
+                                       (union tcp_ao_addr *)&addr4,
                                        prefixlen, AF_INET, sndid, rcvid);
        }
 #endif
-       return __tcp_ao_key_cmp(key, addr, prefixlen, family, sndid, rcvid);
+       return __tcp_ao_key_cmp(key, l3index, addr,
+                               prefixlen, family, sndid, rcvid);
 }
 
-static struct tcp_ao_key *__tcp_ao_do_lookup(const struct sock *sk,
+static struct tcp_ao_key *__tcp_ao_do_lookup(const struct sock *sk, int l3index,
                const union tcp_ao_addr *addr, int family, u8 prefix,
                int sndid, int rcvid)
 {
        hlist_for_each_entry_rcu(key, &ao->head, node) {
                u8 prefixlen = min(prefix, key->prefixlen);
 
-               if (!tcp_ao_key_cmp(key, addr, prefixlen, family, sndid, rcvid))
+               if (!tcp_ao_key_cmp(key, l3index, addr, prefixlen,
+                                   family, sndid, rcvid))
                        return key;
        }
        return NULL;
 }
 
-struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk,
+struct tcp_ao_key *tcp_ao_do_lookup(const struct sock *sk, int l3index,
                                    const union tcp_ao_addr *addr,
                                    int family, int sndid, int rcvid)
 {
-       return __tcp_ao_do_lookup(sk, addr, family, U8_MAX, sndid, rcvid);
+       return __tcp_ao_do_lookup(sk, l3index, addr, family, U8_MAX, sndid, rcvid);
 }
 
 static struct tcp_ao_info *tcp_ao_alloc_info(gfp_t flags)
                                        struct request_sock *req,
                                        int sndid, int rcvid)
 {
-       union tcp_ao_addr *addr =
-                       (union tcp_ao_addr *)&inet_rsk(req)->ir_rmt_addr;
+       struct inet_request_sock *ireq = inet_rsk(req);
+       union tcp_ao_addr *addr = (union tcp_ao_addr *)&ireq->ir_rmt_addr;
+       int l3index;
 
-       return tcp_ao_do_lookup(sk, addr, AF_INET, sndid, rcvid);
+       l3index = l3mdev_master_ifindex_by_index(sock_net(sk), ireq->ir_iif);
+       return tcp_ao_do_lookup(sk, l3index, addr, AF_INET, sndid, rcvid);
 }
 
 struct tcp_ao_key *tcp_v4_ao_lookup(const struct sock *sk, struct sock *addr_sk,
                                    int sndid, int rcvid)
 {
+       int l3index = l3mdev_master_ifindex_by_index(sock_net(sk),
+                                                    addr_sk->sk_bound_dev_if);
        union tcp_ao_addr *addr = (union tcp_ao_addr *)&addr_sk->sk_daddr;
 
-       return tcp_ao_do_lookup(sk, addr, AF_INET, sndid, rcvid);
+       return tcp_ao_do_lookup(sk, l3index, addr, AF_INET, sndid, rcvid);
 }
 
 int tcp_ao_prepare_reset(const struct sock *sk, struct sk_buff *skb,
                ao_info = rcu_dereference(tcp_sk(sk)->ao_info);
                if (!ao_info)
                        return -ENOENT;
-               *key = tcp_ao_do_lookup(sk, addr, family, -1, aoh->rnext_keyid);
+               *key = tcp_ao_do_lookup(sk, l3index, addr, family,
+                                       -1, aoh->rnext_keyid);
                if (!*key)
                        return -ENOENT;
                *traffic_key = kmalloc(tcp_ao_digest_size(*key), GFP_ATOMIC);
 
 static struct tcp_ao_key *tcp_ao_inbound_lookup(unsigned short int family,
                const struct sock *sk, const struct sk_buff *skb,
-               int sndid, int rcvid)
+               int sndid, int rcvid, int l3index)
 {
        if (family == AF_INET) {
                const struct iphdr *iph = ip_hdr(skb);
 
-               return tcp_ao_do_lookup(sk, (union tcp_ao_addr *)&iph->saddr,
-                               AF_INET, sndid, rcvid);
+               return tcp_ao_do_lookup(sk, l3index,
+                                       (union tcp_ao_addr *)&iph->saddr,
+                                       AF_INET, sndid, rcvid);
        } else {
                const struct ipv6hdr *iph = ipv6_hdr(skb);
 
-               return tcp_ao_do_lookup(sk, (union tcp_ao_addr *)&iph->saddr,
-                               AF_INET6, sndid, rcvid);
+               return tcp_ao_do_lookup(sk, l3index,
+                                       (union tcp_ao_addr *)&iph->saddr,
+                                       AF_INET6, sndid, rcvid);
        }
 }
 
 void tcp_ao_syncookie(struct sock *sk, const struct sk_buff *skb,
                      struct tcp_request_sock *treq,
-                     unsigned short int family)
+                     unsigned short int family, int l3index)
 {
        const struct tcphdr *th = tcp_hdr(skb);
        const struct tcp_ao_hdr *aoh;
        if (tcp_parse_auth_options(th, NULL, &aoh) || !aoh)
                return;
 
-       key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid);
+       key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid, l3index);
        if (!key)
                /* Key not found, continue without TCP-AO */
                return;
 tcp_ao_verify_hash(const struct sock *sk, const struct sk_buff *skb,
                   unsigned short int family, struct tcp_ao_info *info,
                   const struct tcp_ao_hdr *aoh, struct tcp_ao_key *key,
-                  u8 *traffic_key, u8 *phash, u32 sne)
+                  u8 *traffic_key, u8 *phash, u32 sne, int l3index)
 {
        u8 maclen = aoh->length - sizeof(struct tcp_ao_hdr);
        const struct tcphdr *th = tcp_hdr(skb);
                atomic64_inc(&info->counters.pkt_bad);
                atomic64_inc(&key->pkt_bad);
                tcp_hash_fail("AO hash wrong length", family, skb,
-                             "%u != %d", maclen, tcp_ao_maclen(key));
+                             "%u != %d L3index: %d", maclen,
+                             tcp_ao_maclen(key), l3index);
                return SKB_DROP_REASON_TCP_AOFAILURE;
        }
 
                NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOBAD);
                atomic64_inc(&info->counters.pkt_bad);
                atomic64_inc(&key->pkt_bad);
-               tcp_hash_fail("AO hash mismatch", family, skb, "");
+               tcp_hash_fail("AO hash mismatch", family, skb,
+                             "L3index: %d", l3index);
                kfree(hash_buf);
                return SKB_DROP_REASON_TCP_AOFAILURE;
        }
 enum skb_drop_reason
 tcp_inbound_ao_hash(struct sock *sk, const struct sk_buff *skb,
                    unsigned short int family, const struct request_sock *req,
-                   const struct tcp_ao_hdr *aoh)
+                   int l3index, const struct tcp_ao_hdr *aoh)
 {
        const struct tcphdr *th = tcp_hdr(skb);
        u8 *phash = (u8 *)(aoh + 1); /* hash goes just after the header */
        if (!info) {
                NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOKEYNOTFOUND);
                tcp_hash_fail("AO key not found", family, skb,
-                             "keyid: %u", aoh->keyid);
+                             "keyid: %u L3index: %d", aoh->keyid, l3index);
                return SKB_DROP_REASON_TCP_AOUNEXPECTED;
        }
 
                /* Established socket, traffic key are cached */
                traffic_key = rcv_other_key(key);
                err = tcp_ao_verify_hash(sk, skb, family, info, aoh, key,
-                                        traffic_key, phash, sne);
+                                        traffic_key, phash, sne, l3index);
                if (err)
                        return err;
                current_key = READ_ONCE(info->current_key);
         * - request sockets would race on those key pointers
         * - tcp_ao_del_cmd() allows async key removal
         */
-       key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid);
+       key = tcp_ao_inbound_lookup(family, sk, skb, -1, aoh->keyid, l3index);
        if (!key)
                goto key_not_found;
 
                return SKB_DROP_REASON_NOT_SPECIFIED;
        tcp_ao_calc_key_skb(key, traffic_key, skb, sisn, disn, family);
        ret = tcp_ao_verify_hash(sk, skb, family, info, aoh, key,
-                                traffic_key, phash, sne);
+                                traffic_key, phash, sne, l3index);
        kfree(traffic_key);
        return ret;
 
        NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAOKEYNOTFOUND);
        atomic64_inc(&info->counters.key_not_found);
        tcp_hash_fail("Requested by the peer AO key id not found",
-                     family, skb, "");
+                     family, skb, "L3index: %d", l3index);
        return SKB_DROP_REASON_TCP_AOKEYNOTFOUND;
 }
 
        struct tcp_ao_info *ao_info;
        union tcp_ao_addr *addr;
        struct tcp_ao_key *key;
-       int family;
+       int family, l3index;
 
        ao_info = rcu_dereference_protected(tp->ao_info,
                                            lockdep_sock_is_held(sk));
 #endif
        else
                return;
+       l3index = l3mdev_master_ifindex_by_index(sock_net(sk),
+                                                sk->sk_bound_dev_if);
 
        hlist_for_each_entry_rcu(key, &ao_info->head, node) {
-               if (!tcp_ao_key_cmp(key, addr, key->prefixlen, family, -1, -1))
+               if (!tcp_ao_key_cmp(key, l3index, addr, key->prefixlen, family, -1, -1))
                        continue;
 
                if (key == ao_info->current_key)
        struct tcp_ao_key *key, *new_key, *first_key;
        struct tcp_ao_info *new_ao, *ao;
        struct hlist_node *key_head;
+       int l3index, ret = -ENOMEM;
        union tcp_ao_addr *addr;
        bool match = false;
-       int ret = -ENOMEM;
 
        ao = rcu_dereference(tcp_sk(sk)->ao_info);
        if (!ao)
                ret = -EAFNOSUPPORT;
                goto free_ao;
        }
+       l3index = l3mdev_master_ifindex_by_index(sock_net(newsk),
+                                                newsk->sk_bound_dev_if);
 
        hlist_for_each_entry_rcu(key, &ao->head, node) {
-               if (tcp_ao_key_cmp(key, addr, key->prefixlen, family, -1, -1))
+               if (tcp_ao_key_cmp(key, l3index, addr, key->prefixlen, family, -1, -1))
                        continue;
 
                new_key = tcp_ao_copy_key(newsk, key);
        return ERR_PTR(-ESOCKTNOSUPPORT);
 }
 
-#define TCP_AO_KEYF_ALL                (TCP_AO_KEYF_EXCLUDE_OPT)
+#define TCP_AO_KEYF_ALL (TCP_AO_KEYF_IFINDEX | TCP_AO_KEYF_EXCLUDE_OPT)
+#define TCP_AO_GET_KEYF_VALID  (TCP_AO_KEYF_IFINDEX)
 
 static struct tcp_ao_key *tcp_ao_key_alloc(struct sock *sk,
                                           struct tcp_ao_add *cmd)
        union tcp_ao_addr *addr;
        struct tcp_ao_key *key;
        struct tcp_ao_add cmd;
+       int ret, l3index = 0;
        bool first = false;
-       int ret;
 
        if (optlen < sizeof(cmd))
                return -EINVAL;
                        return -EINVAL;
        }
 
+       if (cmd.ifindex && !(cmd.keyflags & TCP_AO_KEYF_IFINDEX))
+               return -EINVAL;
+
+       /* For cmd.tcp_ifindex = 0 the key will apply to the default VRF */
+       if (cmd.keyflags & TCP_AO_KEYF_IFINDEX && cmd.ifindex) {
+               int bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
+               struct net_device *dev;
+
+               rcu_read_lock();
+               dev = dev_get_by_index_rcu(sock_net(sk), cmd.ifindex);
+               if (dev && netif_is_l3_master(dev))
+                       l3index = dev->ifindex;
+               rcu_read_unlock();
+
+               if (!dev || !l3index)
+                       return -EINVAL;
+
+               /* It's still possible to bind after adding keys or even
+                * re-bind to a different dev (with CAP_NET_RAW).
+                * So, no reason to return error here, rather try to be
+                * nice and warn the user.
+                */
+               if (bound_dev_if && bound_dev_if != cmd.ifindex)
+                       net_warn_ratelimited("AO key ifindex %d != sk bound ifindex %d\n",
+                                            cmd.ifindex, bound_dev_if);
+       }
+
        /* Don't allow keys for peers that have a matching TCP-MD5 key */
-       if (tcp_md5_do_lookup_any_l3index(sk, addr, family))
-               return -EKEYREJECTED;
+       if (cmd.keyflags & TCP_AO_KEYF_IFINDEX) {
+               /* Non-_exact version of tcp_md5_do_lookup() will
+                * as well match keys that aren't bound to a specific VRF
+                * (that will make them match AO key with
+                * sysctl_tcp_l3dev_accept = 1
+                */
+               if (tcp_md5_do_lookup(sk, l3index, addr, family))
+                       return -EKEYREJECTED;
+       } else {
+               if (tcp_md5_do_lookup_any_l3index(sk, addr, family))
+                       return -EKEYREJECTED;
+       }
 
        ao_info = setsockopt_ao_info(sk);
        if (IS_ERR(ao_info))
                 * > The IDs of MKTs MUST NOT overlap where their
                 * > TCP connection identifiers overlap.
                 */
-               if (__tcp_ao_do_lookup(sk, addr, family,
-                                      cmd.prefix, -1, cmd.rcvid))
+               if (__tcp_ao_do_lookup(sk, l3index, addr, family, cmd.prefix, -1, cmd.rcvid))
                        return -EEXIST;
-               if (__tcp_ao_do_lookup(sk, addr, family,
+               if (__tcp_ao_do_lookup(sk, l3index, addr, family,
                                       cmd.prefix, cmd.sndid, -1))
                        return -EEXIST;
        }
        key->keyflags   = cmd.keyflags;
        key->sndid      = cmd.sndid;
        key->rcvid      = cmd.rcvid;
+       key->l3index    = l3index;
        atomic64_set(&key->pkt_good, 0);
        atomic64_set(&key->pkt_bad, 0);
 
        return err;
 }
 
+#define TCP_AO_DEL_KEYF_ALL (TCP_AO_KEYF_IFINDEX)
 static int tcp_ao_del_cmd(struct sock *sk, unsigned short int family,
                          sockptr_t optval, int optlen)
 {
        struct tcp_ao_key *key, *new_current = NULL, *new_rnext = NULL;
+       int err, addr_len, l3index = 0;
        struct tcp_ao_info *ao_info;
        union tcp_ao_addr *addr;
        struct tcp_ao_del cmd;
-       int addr_len;
        __u8 prefix;
        u16 port;
-       int err;
 
        if (optlen < sizeof(cmd))
                return -EINVAL;
                        return -EINVAL;
        }
 
+       if (cmd.keyflags & ~TCP_AO_DEL_KEYF_ALL)
+               return -EINVAL;
+
+       /* No sanity check for TCP_AO_KEYF_IFINDEX as if a VRF
+        * was destroyed, there still should be a way to delete keys,
+        * that were bound to that l3intf. So, fail late at lookup stage
+        * if there is no key for that ifindex.
+        */
+       if (cmd.ifindex && !(cmd.keyflags & TCP_AO_KEYF_IFINDEX))
+               return -EINVAL;
+
        ao_info = setsockopt_ao_info(sk);
        if (IS_ERR(ao_info))
                return PTR_ERR(ao_info);
                    memcmp(addr, &key->addr, addr_len))
                        continue;
 
+               if ((cmd.keyflags & TCP_AO_KEYF_IFINDEX) !=
+                   (key->keyflags & TCP_AO_KEYF_IFINDEX))
+                       continue;
+
+               if (key->l3index != l3index)
+                       continue;
+
                if (key == new_current || key == new_rnext)
                        continue;
 
        struct tcp_ao_key *key, *current_key;
        bool do_address_matching = true;
        union tcp_ao_addr *addr = NULL;
+       int err, l3index, user_len;
        unsigned int max_keys;  /* maximum number of keys to copy to user */
        size_t out_offset = 0;
        size_t bytes_to_write;  /* number of bytes to write to user level */
-       int err, user_len;
        u32 matched_keys;       /* keys from ao_info matched so far */
        int optlen_out;
        __be16 port = 0;
 
        if (opt_in.pkt_good || opt_in.pkt_bad)
                return -EINVAL;
+       if (opt_in.keyflags & ~TCP_AO_GET_KEYF_VALID)
+               return -EINVAL;
+       if (opt_in.ifindex && !(opt_in.keyflags & TCP_AO_KEYF_IFINDEX))
+               return -EINVAL;
 
        if (opt_in.reserved != 0)
                return -EINVAL;
 
        max_keys = opt_in.nkeys;
+       l3index = (opt_in.keyflags & TCP_AO_KEYF_IFINDEX) ? opt_in.ifindex : -1;
 
        if (opt_in.get_all || opt_in.is_current || opt_in.is_rnext) {
                if (opt_in.get_all && (opt_in.is_current || opt_in.is_rnext))
                        continue;
                }
 
-               if (tcp_ao_key_cmp(key, addr, opt_in.prefix,
+               if (tcp_ao_key_cmp(key, l3index, addr, opt_in.prefix,
                                   opt_in.addr.ss_family,
                                   opt_in.sndid, opt_in.rcvid) != 0)
                        continue;
                opt_out.nkeys = 0;
                opt_out.maclen = key->maclen;
                opt_out.keylen = key->keylen;
+               opt_out.ifindex = key->l3index;
                opt_out.pkt_good = atomic64_read(&key->pkt_good);
                opt_out.pkt_bad = atomic64_read(&key->pkt_bad);
                memcpy(&opt_out.key, key->key, key->keylen);