}
 }
 
+/* A getter for the SKB protocol field which will handle VLAN tags consistently
+ * whether VLAN acceleration is enabled or not.
+ */
+static inline __be16 skb_protocol(const struct sk_buff *skb, bool skip_vlan)
+{
+       unsigned int offset = skb_mac_offset(skb) + sizeof(struct ethhdr);
+       __be16 proto = skb->protocol;
+
+       if (!skip_vlan)
+               /* VLAN acceleration strips the VLAN header from the skb and
+                * moves it to skb->vlan_proto
+                */
+               return skb_vlan_tag_present(skb) ? skb->vlan_proto : proto;
+
+       while (eth_type_vlan(proto)) {
+               struct vlan_hdr vhdr, *vh;
+
+               vh = skb_header_pointer(skb, offset, sizeof(vhdr), &vhdr);
+               if (!vh)
+                       break;
+
+               proto = vh->h_vlan_encapsulated_proto;
+               offset += sizeof(vhdr);
+       }
+
+       return proto;
+}
+
 static inline bool vlan_hw_offload_capable(netdev_features_t features,
                                           __be16 proto)
 {
 
 
 #include <linux/ip.h>
 #include <linux/skbuff.h>
+#include <linux/if_vlan.h>
 
 #include <net/inet_sock.h>
 #include <net/dsfield.h>
 
 static inline int INET_ECN_set_ce(struct sk_buff *skb)
 {
-       switch (skb->protocol) {
+       switch (skb_protocol(skb, true)) {
        case cpu_to_be16(ETH_P_IP):
                if (skb_network_header(skb) + sizeof(struct iphdr) <=
                    skb_tail_pointer(skb))
 
 static inline int INET_ECN_set_ect1(struct sk_buff *skb)
 {
-       switch (skb->protocol) {
+       switch (skb_protocol(skb, true)) {
        case cpu_to_be16(ETH_P_IP):
                if (skb_network_header(skb) + sizeof(struct iphdr) <=
                    skb_tail_pointer(skb))
 {
        __u8 inner;
 
-       if (skb->protocol == htons(ETH_P_IP))
+       switch (skb_protocol(skb, true)) {
+       case htons(ETH_P_IP):
                inner = ip_hdr(skb)->tos;
-       else if (skb->protocol == htons(ETH_P_IPV6))
+               break;
+       case htons(ETH_P_IPV6):
                inner = ipv6_get_dsfield(ipv6_hdr(skb));
-       else
+               break;
+       default:
                return 0;
+       }
 
        return INET_ECN_decapsulate(skb, oiph->tos, inner);
 }
 {
        __u8 inner;
 
-       if (skb->protocol == htons(ETH_P_IP))
+       switch (skb_protocol(skb, true)) {
+       case htons(ETH_P_IP):
                inner = ip_hdr(skb)->tos;
-       else if (skb->protocol == htons(ETH_P_IPV6))
+               break;
+       case htons(ETH_P_IPV6):
                inner = ipv6_get_dsfield(ipv6_hdr(skb));
-       else
+               break;
+       default:
                return 0;
+       }
 
        return INET_ECN_decapsulate(skb, ipv6_get_dsfield(oipv6h), inner);
 }
 
        }
 }
 
-static inline __be16 tc_skb_protocol(const struct sk_buff *skb)
-{
-       /* We need to take extra care in case the skb came via
-        * vlan accelerated path. In that case, use skb->vlan_proto
-        * as the original vlan header was already stripped.
-        */
-       if (skb_vlan_tag_present(skb))
-               return skb->vlan_proto;
-       return skb->protocol;
-}
-
 /* Calculate maximal size of packet seen by hard_start_xmit
    routine of this device.
  */
 
 {
        unsigned int iphdr_len;
 
-       if (skb->protocol == cpu_to_be16(ETH_P_IP))
+       switch (skb_protocol(skb, true)) {
+       case cpu_to_be16(ETH_P_IP):
                iphdr_len = sizeof(struct iphdr);
-       else if (skb->protocol == cpu_to_be16(ETH_P_IPV6))
+               break;
+       case cpu_to_be16(ETH_P_IPV6):
                iphdr_len = sizeof(struct ipv6hdr);
-       else
+               break;
+       default:
                return 0;
+       }
 
        if (skb_headlen(skb) < iphdr_len)
                return 0;
 
        tcf_lastuse_update(&ca->tcf_tm);
        bstats_update(&ca->tcf_bstats, skb);
 
-       if (skb->protocol == htons(ETH_P_IP)) {
+       switch (skb_protocol(skb, true)) {
+       case htons(ETH_P_IP):
                if (skb->len < sizeof(struct iphdr))
                        goto out;
 
                proto = NFPROTO_IPV4;
-       } else if (skb->protocol == htons(ETH_P_IPV6)) {
+               break;
+       case htons(ETH_P_IPV6):
                if (skb->len < sizeof(struct ipv6hdr))
                        goto out;
 
                proto = NFPROTO_IPV6;
-       } else {
+               break;
+       default:
                goto out;
        }
 
 
                goto drop;
 
        update_flags = params->update_flags;
-       protocol = tc_skb_protocol(skb);
+       protocol = skb_protocol(skb, false);
 again:
        switch (protocol) {
        case cpu_to_be16(ETH_P_IP):
 
 {
        u8 family = NFPROTO_UNSPEC;
 
-       switch (skb->protocol) {
+       switch (skb_protocol(skb, true)) {
        case htons(ETH_P_IP):
                family = NFPROTO_IPV4;
                break;
                          const struct nf_nat_range2 *range,
                          enum nf_nat_manip_type maniptype)
 {
+       __be16 proto = skb_protocol(skb, true);
        int hooknum, err = NF_ACCEPT;
 
        /* See HOOK2MANIP(). */
        switch (ctinfo) {
        case IP_CT_RELATED:
        case IP_CT_RELATED_REPLY:
-               if (skb->protocol == htons(ETH_P_IP) &&
+               if (proto == htons(ETH_P_IP) &&
                    ip_hdr(skb)->protocol == IPPROTO_ICMP) {
                        if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo,
                                                           hooknum))
                                err = NF_DROP;
                        goto out;
-               } else if (IS_ENABLED(CONFIG_IPV6) &&
-                          skb->protocol == htons(ETH_P_IPV6)) {
+               } else if (IS_ENABLED(CONFIG_IPV6) && proto == htons(ETH_P_IPV6)) {
                        __be16 frag_off;
                        u8 nexthdr = ipv6_hdr(skb)->nexthdr;
                        int hdrlen = ipv6_skip_exthdr(skb,
 MODULE_AUTHOR("Marcelo Ricardo Leitner <marcelo.leitner@gmail.com>");
 MODULE_DESCRIPTION("Connection tracking action");
 MODULE_LICENSE("GPL v2");
-
 
        action = READ_ONCE(ca->tcf_action);
 
        wlen = skb_network_offset(skb);
-       if (tc_skb_protocol(skb) == htons(ETH_P_IP)) {
+       switch (skb_protocol(skb, true)) {
+       case htons(ETH_P_IP):
                wlen += sizeof(struct iphdr);
                if (!pskb_may_pull(skb, wlen))
                        goto out;
 
                proto = NFPROTO_IPV4;
-       } else if (tc_skb_protocol(skb) == htons(ETH_P_IPV6)) {
+               break;
+       case htons(ETH_P_IPV6):
                wlen += sizeof(struct ipv6hdr);
                if (!pskb_may_pull(skb, wlen))
                        goto out;
 
                proto = NFPROTO_IPV6;
-       } else {
+               break;
+       default:
                goto out;
        }
 
 
                        goto drop;
                break;
        case TCA_MPLS_ACT_PUSH:
-               new_lse = tcf_mpls_get_lse(NULL, p, !eth_p_mpls(skb->protocol));
+               new_lse = tcf_mpls_get_lse(NULL, p, !eth_p_mpls(skb_protocol(skb, true)));
                if (skb_mpls_push(skb, new_lse, p->tcfm_proto, mac_len,
                                  skb->dev && skb->dev->type == ARPHRD_ETHER))
                        goto drop;
 
        if (params->flags & SKBEDIT_F_INHERITDSFIELD) {
                int wlen = skb_network_offset(skb);
 
-               switch (tc_skb_protocol(skb)) {
+               switch (skb_protocol(skb, true)) {
                case htons(ETH_P_IP):
                        wlen += sizeof(struct iphdr);
                        if (!pskb_may_pull(skb, wlen))
 
 reclassify:
 #endif
        for (; tp; tp = rcu_dereference_bh(tp->next)) {
-               __be16 protocol = tc_skb_protocol(skb);
+               __be16 protocol = skb_protocol(skb, false);
                int err;
 
                if (tp->protocol != protocol &&
 
        if (dst)
                return ntohl(dst);
 
-       return addr_fold(skb_dst(skb)) ^ (__force u16) tc_skb_protocol(skb);
+       return addr_fold(skb_dst(skb)) ^ (__force u16)skb_protocol(skb, true);
 }
 
 static u32 flow_get_proto(const struct sk_buff *skb,
        if (flow->ports.ports)
                return ntohs(flow->ports.dst);
 
-       return addr_fold(skb_dst(skb)) ^ (__force u16) tc_skb_protocol(skb);
+       return addr_fold(skb_dst(skb)) ^ (__force u16)skb_protocol(skb, true);
 }
 
 static u32 flow_get_iif(const struct sk_buff *skb)
 static u32 flow_get_nfct_src(const struct sk_buff *skb,
                             const struct flow_keys *flow)
 {
-       switch (tc_skb_protocol(skb)) {
+       switch (skb_protocol(skb, true)) {
        case htons(ETH_P_IP):
                return ntohl(CTTUPLE(skb, src.u3.ip));
        case htons(ETH_P_IPV6):
 static u32 flow_get_nfct_dst(const struct sk_buff *skb,
                             const struct flow_keys *flow)
 {
-       switch (tc_skb_protocol(skb)) {
+       switch (skb_protocol(skb, true)) {
        case htons(ETH_P_IP):
                return ntohl(CTTUPLE(skb, dst.u3.ip));
        case htons(ETH_P_IPV6):
 
                /* skb_flow_dissect() does not set n_proto in case an unknown
                 * protocol, so do it rather here.
                 */
-               skb_key.basic.n_proto = skb->protocol;
+               skb_key.basic.n_proto = skb_protocol(skb, false);
                skb_flow_dissect_tunnel_info(skb, &mask->dissector, &skb_key);
                skb_flow_dissect_ct(skb, &mask->dissector, &skb_key,
                                    fl_ct_info_to_flower_map,
 
        };
        int ret, network_offset;
 
-       switch (tc_skb_protocol(skb)) {
+       switch (skb_protocol(skb, true)) {
        case htons(ETH_P_IP):
                state.pf = NFPROTO_IPV4;
                if (!pskb_network_may_pull(skb, sizeof(struct iphdr)))
 
        struct nf_hook_state state;
        int ret;
 
-       switch (tc_skb_protocol(skb)) {
+       switch (skb_protocol(skb, true)) {
        case htons(ETH_P_IP):
                if (!pskb_network_may_pull(skb, sizeof(struct iphdr)))
                        return 0;
 
 META_COLLECTOR(int_protocol)
 {
        /* Let userspace take care of the byte ordering */
-       dst->value = tc_skb_protocol(skb);
+       dst->value = skb_protocol(skb, false);
 }
 
 META_COLLECTOR(int_pkttype)
 
        bool rev = !skb->_nfct, upd = false;
        __be32 ip;
 
-       if (tc_skb_protocol(skb) != htons(ETH_P_IP))
+       if (skb_protocol(skb, true) != htons(ETH_P_IP))
                return false;
 
        if (!nf_ct_get_tuple_skb(&tuple, skb))
        u16 *buf, buf_;
        u8 dscp;
 
-       switch (tc_skb_protocol(skb)) {
+       switch (skb_protocol(skb, true)) {
        case htons(ETH_P_IP):
                buf = skb_header_pointer(skb, offset, sizeof(buf_), &buf_);
                if (unlikely(!buf))
 
        if (p->set_tc_index) {
                int wlen = skb_network_offset(skb);
 
-               switch (tc_skb_protocol(skb)) {
+               switch (skb_protocol(skb, true)) {
                case htons(ETH_P_IP):
                        wlen += sizeof(struct iphdr);
                        if (!pskb_may_pull(skb, wlen) ||
        index = skb->tc_index & (p->indices - 1);
        pr_debug("index %d->%d\n", skb->tc_index, index);
 
-       switch (tc_skb_protocol(skb)) {
+       switch (skb_protocol(skb, true)) {
        case htons(ETH_P_IP):
                ipv4_change_dsfield(ip_hdr(skb), p->mv[index].mask,
                                    p->mv[index].value);
                 */
                if (p->mv[index].mask != 0xff || p->mv[index].value)
                        pr_warn("%s: unsupported protocol %d\n",
-                               __func__, ntohs(tc_skb_protocol(skb)));
+                               __func__, ntohs(skb_protocol(skb, true)));
                break;
        }
 
 
                char haddr[MAX_ADDR_LEN];
 
                neigh_ha_snapshot(haddr, n, dev);
-               err = dev_hard_header(skb, dev, ntohs(tc_skb_protocol(skb)),
+               err = dev_hard_header(skb, dev, ntohs(skb_protocol(skb, false)),
                                      haddr, NULL, skb->len);
 
                if (err < 0)