]> www.infradead.org Git - linux-platform-drivers-x86.git/commitdiff
skmsg: Pass psock pointer to ->psock_update_sk_prot()
authorCong Wang <cong.wang@bytedance.com>
Wed, 7 Apr 2021 03:21:11 +0000 (20:21 -0700)
committerDaniel Borkmann <daniel@iogearbox.net>
Mon, 12 Apr 2021 15:34:27 +0000 (17:34 +0200)
Using sk_psock() to retrieve psock pointer from sock requires
RCU read lock, but we already get psock pointer before calling
->psock_update_sk_prot() in both cases, so we can just pass it
without bothering sk_psock().

Fixes: 8a59f9d1e3d4 ("sock: Introduce sk->sk_prot->psock_update_sk_prot()")
Reported-by: syzbot+320a3bc8d80f478c37e4@syzkaller.appspotmail.com
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Tested-by: syzbot+320a3bc8d80f478c37e4@syzkaller.appspotmail.com
Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20210407032111.33398-1-xiyou.wangcong@gmail.com
include/linux/skmsg.h
include/net/sock.h
include/net/tcp.h
include/net/udp.h
net/core/sock_map.c
net/ipv4/tcp_bpf.c
net/ipv4/udp_bpf.c

index f78e90a04a69b733707fb94fcd157d398c46b143..e2fb0a5a101e983f7f47ef00aa9dc40947c90772 100644 (file)
@@ -99,7 +99,8 @@ struct sk_psock {
        void (*saved_close)(struct sock *sk, long timeout);
        void (*saved_write_space)(struct sock *sk);
        void (*saved_data_ready)(struct sock *sk);
-       int  (*psock_update_sk_prot)(struct sock *sk, bool restore);
+       int  (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
+                                    bool restore);
        struct proto                    *sk_proto;
        struct mutex                    work_mutex;
        struct sk_psock_work_state      work_state;
@@ -405,7 +406,7 @@ static inline void sk_psock_restore_proto(struct sock *sk,
 {
        sk->sk_prot->unhash = psock->saved_unhash;
        if (psock->psock_update_sk_prot)
-               psock->psock_update_sk_prot(sk, true);
+               psock->psock_update_sk_prot(sk, psock, true);
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
index 8b4155e756c20320bc0ea5f427eb00a84fa4ff64..c4bbdcd83f4d8da24e75f3ae4b08cca1ca75b880 100644 (file)
@@ -1114,6 +1114,7 @@ struct inet_hashinfo;
 struct raw_hashinfo;
 struct smc_hashinfo;
 struct module;
+struct sk_psock;
 
 /*
  * caches using SLAB_TYPESAFE_BY_RCU should let .next pointer from nulls nodes
@@ -1185,7 +1186,9 @@ struct proto {
        void                    (*rehash)(struct sock *sk);
        int                     (*get_port)(struct sock *sk, unsigned short snum);
 #ifdef CONFIG_BPF_SYSCALL
-       int                     (*psock_update_sk_prot)(struct sock *sk, bool restore);
+       int                     (*psock_update_sk_prot)(struct sock *sk,
+                                                       struct sk_psock *psock,
+                                                       bool restore);
 #endif
 
        /* Keeping track of sockets in use */
index eaea43afcc97ba4399022f250dadb830d799e4fd..d05193cb0d990adaffa97d274ff99d45f734db56 100644 (file)
@@ -2215,7 +2215,7 @@ struct sk_psock;
 
 #ifdef CONFIG_BPF_SYSCALL
 struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
-int tcp_bpf_update_proto(struct sock *sk, bool restore);
+int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
 void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 #endif /* CONFIG_BPF_SYSCALL */
 
index f55aaeef7e915b9845300ef7a3e956b938e55ef2..360df454356cbd956c9dd19655d1f667a0e621a9 100644 (file)
@@ -543,7 +543,7 @@ static inline void udp_post_segment_fix_csum(struct sk_buff *skb)
 #ifdef CONFIG_BPF_SYSCALL
 struct sk_psock;
 struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
-int udp_bpf_update_proto(struct sock *sk, bool restore);
+int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
 #endif
 
 #endif /* _UDP_H */
index 3d190d22b0d8b78ca221390ce2263829bf41107b..f473c51cbc4b76d49e6462f350b33c1edeac9c53 100644 (file)
@@ -188,7 +188,7 @@ static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
        if (!sk->sk_prot->psock_update_sk_prot)
                return -EINVAL;
        psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
-       return sk->sk_prot->psock_update_sk_prot(sk, false);
+       return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
index 3d622a0d075339d05090110b2673fa447a97f2cd..4930bc8ab47ef5ed46ebc840518f2cceb08afc61 100644 (file)
@@ -499,9 +499,8 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-int tcp_bpf_update_proto(struct sock *sk, bool restore)
+int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 {
-       struct sk_psock *psock = sk_psock(sk);
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
index 4a7e38c5d842008bbf0a3bbe12227e8de376c16e..954c4591a6fd690166fd5b9ae86123f13b26c05f 100644 (file)
@@ -103,10 +103,9 @@ static int __init udp_bpf_v4_build_proto(void)
 }
 core_initcall(udp_bpf_v4_build_proto);
 
-int udp_bpf_update_proto(struct sock *sk, bool restore)
+int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
 {
        int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
-       struct sk_psock *psock = sk_psock(sk);
 
        if (restore) {
                sk->sk_write_space = psock->saved_write_space;