void bpf_warn_invalid_xdp_action(u32 act);
 void bpf_warn_invalid_xdp_redirect(u32 ifindex);
 
-struct sock *do_sk_redirect_map(void);
+struct sock *do_sk_redirect_map(struct sk_buff *skb);
 
 #ifdef CONFIG_BPF_JIT
 extern int bpf_jit_enable;
 
                        struct inet6_skb_parm   h6;
 #endif
                } header;       /* For incoming skbs */
+               struct {
+                       __u32 key;
+                       __u32 flags;
+                       struct bpf_map *map;
+               } bpf;
        };
 };
 
 
 #include <linux/workqueue.h>
 #include <linux/list.h>
 #include <net/strparser.h>
+#include <net/tcp.h>
 
 struct bpf_stab {
        struct bpf_map map;
                return SK_DROP;
 
        skb_orphan(skb);
+       /* We need to ensure that BPF metadata for maps is also cleared
+        * when we orphan the skb so that we don't have the possibility
+        * to reference a stale map.
+        */
+       TCP_SKB_CB(skb)->bpf.map = NULL;
        skb->sk = psock->sock;
        bpf_compute_data_end(skb);
+       preempt_disable();
        rc = (*prog->bpf_func)(skb, prog->insnsi);
+       preempt_enable();
        skb->sk = NULL;
 
        return rc;
        struct sock *sk;
        int rc;
 
-       /* Because we use per cpu values to feed input from sock redirect
-        * in BPF program to do_sk_redirect_map() call we need to ensure we
-        * are not preempted. RCU read lock is not sufficient in this case
-        * with CONFIG_PREEMPT_RCU enabled so we must be explicit here.
-        */
-       preempt_disable();
        rc = smap_verdict_func(psock, skb);
        switch (rc) {
        case SK_REDIRECT:
-               sk = do_sk_redirect_map();
-               preempt_enable();
+               sk = do_sk_redirect_map(skb);
                if (likely(sk)) {
                        struct smap_psock *peer = smap_psock_sk(sk);
 
        /* Fall through and free skb otherwise */
        case SK_DROP:
        default:
-               if (rc != SK_REDIRECT)
-                       preempt_enable();
                kfree_skb(skb);
        }
 }
 
        .arg2_type      = ARG_ANYTHING,
 };
 
-BPF_CALL_3(bpf_sk_redirect_map, struct bpf_map *, map, u32, key, u64, flags)
+BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
+          struct bpf_map *, map, u32, key, u64, flags)
 {
-       struct redirect_info *ri = this_cpu_ptr(&redirect_info);
+       struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
 
        if (unlikely(flags))
                return SK_ABORTED;
 
-       ri->ifindex = key;
-       ri->flags = flags;
-       ri->map = map;
+       tcb->bpf.key = key;
+       tcb->bpf.flags = flags;
+       tcb->bpf.map = map;
 
        return SK_REDIRECT;
 }
 
-struct sock *do_sk_redirect_map(void)
+struct sock *do_sk_redirect_map(struct sk_buff *skb)
 {
-       struct redirect_info *ri = this_cpu_ptr(&redirect_info);
+       struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
        struct sock *sk = NULL;
 
-       if (ri->map) {
-               sk = __sock_map_lookup_elem(ri->map, ri->ifindex);
+       if (tcb->bpf.map) {
+               sk = __sock_map_lookup_elem(tcb->bpf.map, tcb->bpf.key);
 
-               ri->ifindex = 0;
-               ri->map = NULL;
-               /* we do not clear flags for future lookup */
+               tcb->bpf.key = 0;
+               tcb->bpf.map = NULL;
        }
 
        return sk;
        .func           = bpf_sk_redirect_map,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
-       .arg1_type      = ARG_CONST_MAP_PTR,
-       .arg2_type      = ARG_ANYTHING,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_ANYTHING,
+       .arg4_type      = ARG_ANYTHING,
 };
 
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 
                ret = 1;
 
        bpf_printk("sockmap: %d -> %d @ %d\n", lport, bpf_ntohl(rport), ret);
-       return bpf_sk_redirect_map(&sock_map, ret, 0);
+       return bpf_sk_redirect_map(skb, &sock_map, ret, 0);
 }
 
 SEC("sockops")
 
  *     @flags: reserved for future use
  *     Return: 0 on success or negative error code
  *
- * int bpf_sk_redirect_map(map, key, flags)
+ * int bpf_sk_redirect_map(skb, map, key, flags)
  *     Redirect skb to a sock in map using key as a lookup key for the
  *     sock in map.
+ *     @skb: pointer to skb
  *     @map: pointer to sockmap
  *     @key: key to lookup sock in map
  *     @flags: reserved for future use
 
 static int (*bpf_setsockopt)(void *ctx, int level, int optname, void *optval,
                             int optlen) =
        (void *) BPF_FUNC_setsockopt;
-static int (*bpf_sk_redirect_map)(void *map, int key, int flags) =
+static int (*bpf_sk_redirect_map)(void *ctx, void *map, int key, int flags) =
        (void *) BPF_FUNC_sk_redirect_map;
 static int (*bpf_sock_map_update)(void *map, void *key, void *value,
                                  unsigned long long flags) =
 
        bpf_printk("verdict: data[0] = redir(%u:%u)\n", map, sk);
 
        if (!map)
-               return bpf_sk_redirect_map(&sock_map_rx, sk, 0);
-       return bpf_sk_redirect_map(&sock_map_tx, sk, 0);
+               return bpf_sk_redirect_map(skb, &sock_map_rx, sk, 0);
+       return bpf_sk_redirect_map(skb, &sock_map_tx, sk, 0);
 }
 
 char _license[] SEC("license") = "GPL";