#ifdef CONFIG_INET
 static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
-                             struct sk_buff *skb, u8 family, u8 proto)
+                             int dif, int sdif, u8 family, u8 proto)
 {
        bool refcounted = false;
        struct sock *sk = NULL;
-       int dif = 0;
-
-       if (skb->dev)
-               dif = skb->dev->ifindex;
 
        if (family == AF_INET) {
                __be32 src4 = tuple->ipv4.saddr;
                __be32 dst4 = tuple->ipv4.daddr;
-               int sdif = inet_sdif(skb);
 
                if (proto == IPPROTO_TCP)
-                       sk = __inet_lookup(net, &tcp_hashinfo, skb, 0,
+                       sk = __inet_lookup(net, &tcp_hashinfo, NULL, 0,
                                           src4, tuple->ipv4.sport,
                                           dst4, tuple->ipv4.dport,
                                           dif, sdif, &refcounted);
                else
                        sk = __udp4_lib_lookup(net, src4, tuple->ipv4.sport,
                                               dst4, tuple->ipv4.dport,
-                                              dif, sdif, &udp_table, skb);
+                                              dif, sdif, &udp_table, NULL);
 #if IS_ENABLED(CONFIG_IPV6)
        } else {
                struct in6_addr *src6 = (struct in6_addr *)&tuple->ipv6.saddr;
                struct in6_addr *dst6 = (struct in6_addr *)&tuple->ipv6.daddr;
                u16 hnum = ntohs(tuple->ipv6.dport);
-               int sdif = inet6_sdif(skb);
 
                if (proto == IPPROTO_TCP)
-                       sk = __inet6_lookup(net, &tcp_hashinfo, skb, 0,
+                       sk = __inet6_lookup(net, &tcp_hashinfo, NULL, 0,
                                            src6, tuple->ipv6.sport,
                                            dst6, hnum,
                                            dif, sdif, &refcounted);
                                                            src6, tuple->ipv6.sport,
                                                            dst6, hnum,
                                                            dif, sdif,
-                                                           &udp_table, skb);
+                                                           &udp_table, NULL);
 #endif
        }
 
  * callers to satisfy BPF_CALL declarations.
  */
 static unsigned long
-bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
-             u8 proto, u64 netns_id, u64 flags)
+__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+               struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
+               u64 flags)
 {
-       struct net *caller_net;
        struct sock *sk = NULL;
        u8 family = AF_UNSPEC;
        struct net *net;
+       int sdif;
 
        family = len == sizeof(tuple->ipv4) ? AF_INET : AF_INET6;
        if (unlikely(family == AF_UNSPEC || netns_id > U32_MAX || flags))
                goto out;
 
-       if (skb->dev)
-               caller_net = dev_net(skb->dev);
+       if (family == AF_INET)
+               sdif = inet_sdif(skb);
        else
-               caller_net = sock_net(skb->sk);
+               sdif = inet6_sdif(skb);
+
        if (netns_id) {
                net = get_net_ns_by_id(caller_net, netns_id);
                if (unlikely(!net))
                        goto out;
-               sk = sk_lookup(net, tuple, skb, family, proto);
+               sk = sk_lookup(net, tuple, ifindex, sdif, family, proto);
                put_net(net);
        } else {
                net = caller_net;
-               sk = sk_lookup(net, tuple, skb, family, proto);
+               sk = sk_lookup(net, tuple, ifindex, sdif, family, proto);
        }
 
        if (sk)
        return (unsigned long) sk;
 }
 
+static unsigned long
+bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+             u8 proto, u64 netns_id, u64 flags)
+{
+       struct net *caller_net;
+       int ifindex;
+
+       if (skb->dev) {
+               caller_net = dev_net(skb->dev);
+               ifindex = skb->dev->ifindex;
+       } else {
+               caller_net = sock_net(skb->sk);
+               ifindex = 0;
+       }
+
+       return __bpf_sk_lookup(skb, tuple, len, caller_net, ifindex,
+                             proto, netns_id, flags);
+}
+
 BPF_CALL_5(bpf_sk_lookup_tcp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_SOCKET,
 };
+
+BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
+{
+       struct net *caller_net = dev_net(ctx->rxq->dev);
+       int ifindex = ctx->rxq->dev->ifindex;
+
+       return __bpf_sk_lookup(NULL, tuple, len, caller_net, ifindex,
+                             IPPROTO_UDP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_xdp_sk_lookup_udp_proto = {
+       .func           = bpf_xdp_sk_lookup_udp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
+BPF_CALL_5(bpf_xdp_sk_lookup_tcp, struct xdp_buff *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
+{
+       struct net *caller_net = dev_net(ctx->rxq->dev);
+       int ifindex = ctx->rxq->dev->ifindex;
+
+       return __bpf_sk_lookup(NULL, tuple, len, caller_net, ifindex,
+                             IPPROTO_TCP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_xdp_sk_lookup_tcp_proto = {
+       .func           = bpf_xdp_sk_lookup_tcp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
 #endif /* CONFIG_INET */
 
 bool bpf_helper_changes_pkt_data(void *func)
                return &bpf_xdp_adjust_tail_proto;
        case BPF_FUNC_fib_lookup:
                return &bpf_xdp_fib_lookup_proto;
+#ifdef CONFIG_INET
+       case BPF_FUNC_sk_lookup_udp:
+               return &bpf_xdp_sk_lookup_udp_proto;
+       case BPF_FUNC_sk_lookup_tcp:
+               return &bpf_xdp_sk_lookup_tcp_proto;
+       case BPF_FUNC_sk_release:
+               return &bpf_sk_release_proto;
+#endif
        default:
                return bpf_base_func_proto(func_id);
        }