#endif /* CONFIG_CGROUP_WRITEBACK */
 
 struct sock;
-bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
-                            gfp_t gfp_mask);
-void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages);
 #ifdef CONFIG_MEMCG
 extern struct static_key_false memcg_sockets_enabled_key;
 #define mem_cgroup_sockets_enabled static_branch_unlikely(&memcg_sockets_enabled_key)
+
 void mem_cgroup_sk_alloc(struct sock *sk);
 void mem_cgroup_sk_free(struct sock *sk);
 void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk);
+bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages,
+                         gfp_t gfp_mask);
+void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages);
 
 #if BITS_PER_LONG < 64
 static inline void mem_cgroup_set_socket_pressure(struct mem_cgroup *memcg)
 void reparent_shrinker_deferred(struct mem_cgroup *memcg);
 #else
 #define mem_cgroup_sockets_enabled 0
-static inline void mem_cgroup_sk_alloc(struct sock *sk) { };
-static inline void mem_cgroup_sk_free(struct sock *sk) { };
+
+static inline void mem_cgroup_sk_alloc(struct sock *sk)
+{
+}
+
+static inline void mem_cgroup_sk_free(struct sock *sk)
+{
+}
 
 static inline void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk)
 {
 }
 
+static inline bool mem_cgroup_sk_charge(const struct sock *sk,
+                                       unsigned int nr_pages,
+                                       gfp_t gfp_mask)
+{
+       return false;
+}
+
+static inline void mem_cgroup_sk_uncharge(const struct sock *sk,
+                                         unsigned int nr_pages)
+{
+}
+
 static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg)
 {
        return false;
 
 }
 
 /**
- * mem_cgroup_charge_skmem - charge socket memory
- * @memcg: memcg to charge
+ * mem_cgroup_sk_charge - charge socket memory
+ * @sk: socket in memcg to charge
  * @nr_pages: number of pages to charge
  * @gfp_mask: reclaim mode
  *
  * Charges @nr_pages to @memcg. Returns %true if the charge fit within
  * @memcg's configured limit, %false if it doesn't.
  */
-bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages,
-                            gfp_t gfp_mask)
+bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages,
+                         gfp_t gfp_mask)
 {
+       struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
+
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
                return memcg1_charge_skmem(memcg, nr_pages, gfp_mask);
 
 }
 
 /**
- * mem_cgroup_uncharge_skmem - uncharge socket memory
- * @memcg: memcg to uncharge
+ * mem_cgroup_sk_uncharge - uncharge socket memory
+ * @sk: socket in memcg to uncharge
  * @nr_pages: number of pages to uncharge
  */
-void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
+void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages)
 {
+       struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
+
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
                memcg1_uncharge_skmem(memcg, nr_pages);
                return;
 
        pages = sk_mem_pages(bytes);
 
        /* pre-charge to memcg */
-       charged = mem_cgroup_charge_skmem(sk->sk_memcg, pages,
-                                         GFP_KERNEL | __GFP_RETRY_MAYFAIL);
+       charged = mem_cgroup_sk_charge(sk, pages,
+                                      GFP_KERNEL | __GFP_RETRY_MAYFAIL);
        if (!charged)
                return -ENOMEM;
 
         */
        if (allocated > sk_prot_mem_limits(sk, 1)) {
                sk_memory_allocated_sub(sk, pages);
-               mem_cgroup_uncharge_skmem(sk->sk_memcg, pages);
+               mem_cgroup_sk_uncharge(sk, pages);
                return -ENOMEM;
        }
        sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
  */
 int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
 {
+       bool memcg_enabled = false, charged = false;
        struct proto *prot = sk->sk_prot;
-       struct mem_cgroup *memcg = NULL;
-       bool charged = false;
        long allocated;
 
        sk_memory_allocated_add(sk, amt);
        allocated = sk_memory_allocated(sk);
 
        if (mem_cgroup_sk_enabled(sk)) {
-               memcg = sk->sk_memcg;
-               charged = mem_cgroup_charge_skmem(memcg, amt, gfp_memcg_charge());
+               memcg_enabled = true;
+               charged = mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge());
                if (!charged)
                        goto suppress_allocation;
        }
                 */
                if (sk->sk_wmem_queued + size >= sk->sk_sndbuf) {
                        /* Force charge with __GFP_NOFAIL */
-                       if (memcg && !charged) {
-                               mem_cgroup_charge_skmem(memcg, amt,
-                                       gfp_memcg_charge() | __GFP_NOFAIL);
-                       }
+                       if (memcg_enabled && !charged)
+                               mem_cgroup_sk_charge(sk, amt,
+                                                    gfp_memcg_charge() | __GFP_NOFAIL);
                        return 1;
                }
        }
        sk_memory_allocated_sub(sk, amt);
 
        if (charged)
-               mem_cgroup_uncharge_skmem(memcg, amt);
+               mem_cgroup_sk_uncharge(sk, amt);
 
        return 0;
 }
        sk_memory_allocated_sub(sk, amount);
 
        if (mem_cgroup_sk_enabled(sk))
-               mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
+               mem_cgroup_sk_uncharge(sk, amount);
 
        if (sk_under_global_memory_pressure(sk) &&
            (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
 
                }
 
                if (amt)
-                       mem_cgroup_charge_skmem(newsk->sk_memcg, amt, gfp);
+                       mem_cgroup_sk_charge(newsk, amt, gfp);
                kmem_cache_charge(newsk, gfp);
 
                release_sock(newsk);
 
        sk_memory_allocated_add(sk, amt);
 
        if (mem_cgroup_sk_enabled(sk))
-               mem_cgroup_charge_skmem(sk->sk_memcg, amt,
-                                       gfp_memcg_charge() | __GFP_NOFAIL);
+               mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL);
 }
 
 /* Send a FIN. The caller locks the socket for us.