net: fix SO_DEVMEM_DONTNEED looping too long
Author:     Mina Almasry <almasrymina@google.com>
AuthorDate: Thu, 7 Nov 2024 21:03:30 +0000 (21:03 +0000)
Commit:     Jakub Kicinski <kuba@kernel.org>
CommitDate: Tue, 12 Nov 2024 02:11:46 +0000 (18:11 -0800)
Exit early if we're freeing more than 1024 frags, to prevent
looping too long.

Also minor code cleanups:
- Flip checks to reduce indentation.
- Use sizeof(*tokens) everywhere for consistency.

Cc: Yi Lai <yi1.lai@linux.intel.com>
Signed-off-by: Mina Almasry <almasrymina@google.com>
Acked-by: Stanislav Fomichev <sdf@fomichev.me>
Link: https://patch.msgid.link/20241107210331.3044434-1-almasrymina@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
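
[Editor's note] For context, a minimal userspace sketch of the call this patch hardens, assuming a recent kernel's uapi headers (struct dmabuf_token from <linux/uio.h>, SO_DEVMEM_DONTNEED from <asm-generic/socket.h>) and a TCP socket that has received devmem payloads; the token values are placeholders for whatever recvmsg() reported via SCM_DEVMEM_DMABUF control messages:

#include <linux/types.h>
#include <linux/uio.h>          /* struct dmabuf_token (recent kernels) */
#include <sys/socket.h>

#ifndef SO_DEVMEM_DONTNEED
#define SO_DEVMEM_DONTNEED 80   /* asm-generic/socket.h (recent kernels) */
#endif

/* Hand a range of devmem frag tokens back to the kernel. On success,
 * setsockopt() returns the number of frags actually freed rather than 0,
 * so the caller can tell whether everything it asked for was released.
 */
static int devmem_dontneed(int sock_fd, __u32 start, __u32 count)
{
	struct dmabuf_token token = {
		.token_start = start,	/* first token from recvmsg() */
		.token_count = count,	/* number of consecutive tokens */
	};

	return setsockopt(sock_fd, SOL_SOCKET, SO_DEVMEM_DONTNEED,
			  &token, sizeof(token));
}

A single call may carry up to MAX_DONTNEED_TOKENS (128) such structs back to back; optlen must be a multiple of sizeof(struct dmabuf_token), which is exactly what the sizeof(*tokens) checks in the diff below enforce.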
net/core/sock.c

index 039be95c40cf6fa429d33e0f42ee606188045992..da50df485090ff4cd6a91d2f3c2a734a19123958 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1052,32 +1052,34 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
 
 #ifdef CONFIG_PAGE_POOL
 
-/* This is the number of tokens that the user can SO_DEVMEM_DONTNEED in
- * 1 syscall. The limit exists to limit the amount of memory the kernel
- * allocates to copy these tokens.
+/* This is the number of tokens and frags that the user can SO_DEVMEM_DONTNEED
+ * in 1 syscall. The limit exists to limit the amount of memory the kernel
+ * allocates to copy these tokens, and to prevent looping over the frags for
+ * too long.
  */
 #define MAX_DONTNEED_TOKENS 128
+#define MAX_DONTNEED_FRAGS 1024
 
 static noinline_for_stack int
 sock_devmem_dontneed(struct sock *sk, sockptr_t optval, unsigned int optlen)
 {
        unsigned int num_tokens, i, j, k, netmem_num = 0;
        struct dmabuf_token *tokens;
+       int ret = 0, num_frags = 0;
        netmem_ref netmems[16];
-       int ret = 0;
 
        if (!sk_is_tcp(sk))
                return -EBADF;
 
-       if (optlen % sizeof(struct dmabuf_token) ||
+       if (optlen % sizeof(*tokens) ||
            optlen > sizeof(*tokens) * MAX_DONTNEED_TOKENS)
                return -EINVAL;
 
-       tokens = kvmalloc_array(optlen, sizeof(*tokens), GFP_KERNEL);
+       num_tokens = optlen / sizeof(*tokens);
+       tokens = kvmalloc_array(num_tokens, sizeof(*tokens), GFP_KERNEL);
        if (!tokens)
                return -ENOMEM;
 
-       num_tokens = optlen / sizeof(struct dmabuf_token);
        if (copy_from_sockptr(tokens, optval, optlen)) {
                kvfree(tokens);
                return -EFAULT;
@@ -1086,24 +1088,28 @@ sock_devmem_dontneed(struct sock *sk, sockptr_t optval, unsigned int optlen)
        xa_lock_bh(&sk->sk_user_frags);
        for (i = 0; i < num_tokens; i++) {
                for (j = 0; j < tokens[i].token_count; j++) {
+                       if (++num_frags > MAX_DONTNEED_FRAGS)
+                               goto frag_limit_reached;
+
                        netmem_ref netmem = (__force netmem_ref)__xa_erase(
                                &sk->sk_user_frags, tokens[i].token_start + j);
 
-                       if (netmem &&
-                           !WARN_ON_ONCE(!netmem_is_net_iov(netmem))) {
-                               netmems[netmem_num++] = netmem;
-                               if (netmem_num == ARRAY_SIZE(netmems)) {
-                                       xa_unlock_bh(&sk->sk_user_frags);
-                                       for (k = 0; k < netmem_num; k++)
-                                               WARN_ON_ONCE(!napi_pp_put_page(netmems[k]));
-                                       netmem_num = 0;
-                                       xa_lock_bh(&sk->sk_user_frags);
-                               }
-                               ret++;
+                       if (!netmem || WARN_ON_ONCE(!netmem_is_net_iov(netmem)))
+                               continue;
+
+                       netmems[netmem_num++] = netmem;
+                       if (netmem_num == ARRAY_SIZE(netmems)) {
+                               xa_unlock_bh(&sk->sk_user_frags);
+                               for (k = 0; k < netmem_num; k++)
+                                       WARN_ON_ONCE(!napi_pp_put_page(netmems[k]));
+                               netmem_num = 0;
+                               xa_lock_bh(&sk->sk_user_frags);
                        }
+                       ret++;
                }
        }
 
+frag_limit_reached:
        xa_unlock_bh(&sk->sk_user_frags);
        for (k = 0; k < netmem_num; k++)
                WARN_ON_ONCE(!napi_pp_put_page(netmems[k]));
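
[Editor's note] The practical consequence of the new 1024-frag cap is that a single call can now return before covering every token it was handed, reporting only the frags freed so far. Below is a hypothetical caller-side sketch (the helper name and chunking policy are illustrative, not part of the patch) that splits a large token array into chunks sized to stay within both kernel limits, so each syscall completes without tripping the early exit:

#include <stddef.h>
#include <linux/types.h>
#include <linux/uio.h>		/* struct dmabuf_token (recent kernels) */
#include <sys/socket.h>

#ifndef SO_DEVMEM_DONTNEED
#define SO_DEVMEM_DONTNEED 80	/* asm-generic/socket.h (recent kernels) */
#endif

#define DONTNEED_MAX_TOKENS	128	/* MAX_DONTNEED_TOKENS in the kernel */
#define DONTNEED_MAX_FRAGS	1024	/* MAX_DONTNEED_FRAGS in the kernel */

/* Free 'n' tokens in chunks sized so that neither the per-call token
 * limit nor the per-call frag limit is exceeded. Returns the total
 * number of frags freed, or -1 on the first setsockopt() failure.
 */
static long devmem_dontneed_all(int sock_fd, struct dmabuf_token *tokens,
				size_t n)
{
	long total = 0;
	size_t i = 0;

	while (i < n) {
		unsigned long frags = 0;
		size_t chunk = 0;
		int ret;

		/* Grow the chunk while both kernel limits still hold. */
		while (i + chunk < n &&
		       chunk < DONTNEED_MAX_TOKENS &&
		       frags + tokens[i + chunk].token_count <=
							DONTNEED_MAX_FRAGS) {
			frags += tokens[i + chunk].token_count;
			chunk++;
		}
		if (!chunk)	/* a single token exceeds the frag cap */
			return -1;

		ret = setsockopt(sock_fd, SOL_SOCKET, SO_DEVMEM_DONTNEED,
				 &tokens[i], chunk * sizeof(*tokens));
		if (ret < 0)
			return -1;

		total += ret;
		i += chunk;
	}
	return total;
}

Even if a call does hit the cap, resubmitting the same tokens is safe: frags already returned have been erased from sk_user_frags, so __xa_erase() finds nothing for them and the kernel loop simply skips ahead.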