struct net_device *dev = NULL;
        struct neighbour *neigh;
        void *dst, *lladdr;
+       u8 protocol = 0;
        int err;
 
        ASSERT_RTNL();
        dst = nla_data(tb[NDA_DST]);
        lladdr = tb[NDA_LLADDR] ? nla_data(tb[NDA_LLADDR]) : NULL;
 
+       if (tb[NDA_PROTOCOL]) {
+               if (nla_len(tb[NDA_PROTOCOL]) != sizeof(u8)) {
+                       NL_SET_ERR_MSG(extack, "Invalid protocol attribute");
+                       goto out;
+               }
+               protocol = nla_get_u8(tb[NDA_PROTOCOL]);
+       }
+
        if (ndm->ndm_flags & NTF_PROXY) {
                struct pneigh_entry *pn;
 
                pn = pneigh_lookup(tbl, net, dst, dev, 1);
                if (pn) {
                        pn->flags = ndm->ndm_flags;
+                       if (protocol)
+                               pn->protocol = protocol;
                        err = 0;
                }
                goto out;
        } else
                err = __neigh_update(neigh, lladdr, ndm->ndm_state, flags,
                                     NETLINK_CB(skb).portid, extack);
+
+       if (protocol)
+               neigh->protocol = protocol;
+
        neigh_release(neigh);
 
 out:
            nla_put(skb, NDA_CACHEINFO, sizeof(ci), &ci))
                goto nla_put_failure;
 
+       if (neigh->protocol && nla_put_u8(skb, NDA_PROTOCOL, neigh->protocol))
+               goto nla_put_failure;
+
        nlmsg_end(skb, nlh);
        return 0;
 
        if (nla_put(skb, NDA_DST, tbl->key_len, pn->key))
                goto nla_put_failure;
 
+       if (pn->protocol && nla_put_u8(skb, NDA_PROTOCOL, pn->protocol))
+               goto nla_put_failure;
+
        nlmsg_end(skb, nlh);
        return 0;
 
               + nla_total_size(MAX_ADDR_LEN) /* NDA_DST */
               + nla_total_size(MAX_ADDR_LEN) /* NDA_LLADDR */
               + nla_total_size(sizeof(struct nda_cacheinfo))
-              + nla_total_size(4); /* NDA_PROBES */
+              + nla_total_size(4)  /* NDA_PROBES */
+              + nla_total_size(1); /* NDA_PROTOCOL */
 }
 
 static void __neigh_notify(struct neighbour *n, int type, int flags,