The socket memcg feature is enabled by a static key and
only works for non-root cgroup.
We check both conditions in many places.
Let's factorise it as a helper function.
Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Acked-by: Roman Gushchin <roman.gushchin@linux.dev>
Acked-by: Shakeel Butt <shakeel.butt@linux.dev>
Link: https://patch.msgid.link/20250815201712.1745332-8-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
        if (!sk->sk_prot->memory_pressure)
                return false;
 
-       if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
+       if (mem_cgroup_sk_enabled(sk) &&
            mem_cgroup_under_socket_pressure(sk->sk_memcg))
                return true;
 
 
 {
        return sk->sk_memcg;
 }
+
+static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
+{
+       return mem_cgroup_sockets_enabled && mem_cgroup_from_sk(sk);
+}
 #else
 static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
 {
        return NULL;
 }
+
+static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
+{
+       return false;
+}
 #endif
 
 static inline long sock_rcvtimeo(const struct sock *sk, bool noblock)
 
 /* optimized version of sk_under_memory_pressure() for TCP sockets */
 static inline bool tcp_under_memory_pressure(const struct sock *sk)
 {
-       if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
+       if (mem_cgroup_sk_enabled(sk) &&
            mem_cgroup_under_socket_pressure(sk->sk_memcg))
                return true;
 
 
        bool charged;
        int pages;
 
-       if (!mem_cgroup_sockets_enabled || !sk->sk_memcg || !sk_has_account(sk))
+       if (!mem_cgroup_sk_enabled(sk) || !sk_has_account(sk))
                return -EOPNOTSUPP;
 
        if (!bytes)
        sk_memory_allocated_add(sk, amt);
        allocated = sk_memory_allocated(sk);
 
-       if (mem_cgroup_sockets_enabled && sk->sk_memcg) {
+       if (mem_cgroup_sk_enabled(sk)) {
                memcg = sk->sk_memcg;
                charged = mem_cgroup_charge_skmem(memcg, amt, gfp_memcg_charge());
                if (!charged)
 {
        sk_memory_allocated_sub(sk, amount);
 
-       if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+       if (mem_cgroup_sk_enabled(sk))
                mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
 
        if (sk_under_global_memory_pressure(sk) &&
 
        sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
        sk_memory_allocated_add(sk, amt);
 
-       if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+       if (mem_cgroup_sk_enabled(sk))
                mem_cgroup_charge_skmem(sk->sk_memcg, amt,
                                        gfp_memcg_charge() | __GFP_NOFAIL);
 }