static u8 cake_handle_diffserv(struct sk_buff *skb, u16 wash)
 {
-       int wlen = skb_network_offset(skb);
+       const int offset = skb_network_offset(skb);
+       u16 *buf, buf_;
        u8 dscp;
 
        switch (tc_skb_protocol(skb)) {
        case htons(ETH_P_IP):
-               wlen += sizeof(struct iphdr);
-               if (!pskb_may_pull(skb, wlen) ||
-                   skb_try_make_writable(skb, wlen))
+               buf = skb_header_pointer(skb, offset, sizeof(buf_), &buf_);
+               if (unlikely(!buf))
                        return 0;
 
-               dscp = ipv4_get_dsfield(ip_hdr(skb)) >> 2;
-               if (wash && dscp)
+               /* ToS is in the second byte of iphdr */
+               dscp = ipv4_get_dsfield((struct iphdr *)buf) >> 2;
+
+               if (wash && dscp) {
+                       const int wlen = offset + sizeof(struct iphdr);
+
+                       if (!pskb_may_pull(skb, wlen) ||
+                           skb_try_make_writable(skb, wlen))
+                               return 0;
+
                        ipv4_change_dsfield(ip_hdr(skb), INET_ECN_MASK, 0);
+               }
+
                return dscp;
 
        case htons(ETH_P_IPV6):
-               wlen += sizeof(struct ipv6hdr);
-               if (!pskb_may_pull(skb, wlen) ||
-                   skb_try_make_writable(skb, wlen))
+               buf = skb_header_pointer(skb, offset, sizeof(buf_), &buf_);
+               if (unlikely(!buf))
                        return 0;
 
-               dscp = ipv6_get_dsfield(ipv6_hdr(skb)) >> 2;
-               if (wash && dscp)
+               /* Traffic class is in the first and second bytes of ipv6hdr */
+               dscp = ipv6_get_dsfield((struct ipv6hdr *)buf) >> 2;
+
+               if (wash && dscp) {
+                       const int wlen = offset + sizeof(struct ipv6hdr);
+
+                       if (!pskb_may_pull(skb, wlen) ||
+                           skb_try_make_writable(skb, wlen))
+                               return 0;
+
                        ipv6_change_dsfield(ipv6_hdr(skb), INET_ECN_MASK, 0);
+               }
+
                return dscp;
 
        case htons(ETH_P_ARP):