return percpu_counter_sum_positive(prot->sockets_allocated);
 }
 
+static inline void sk_update_clone(const struct sock *sk, struct sock *newsk)
+{
+       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
+               sock_update_memcg(newsk);
+}
+
 static inline int
 proto_sockets_allocated_sum_positive(struct proto *prot)
 {
 
 static bool mem_cgroup_is_root(struct mem_cgroup *memcg);
 void sock_update_memcg(struct sock *sk)
 {
-       /* A socket spends its whole life in the same cgroup */
-       if (sk->sk_cgrp) {
-               WARN_ON(1);
-               return;
-       }
        if (static_branch(&memcg_socket_limit_enabled)) {
                struct mem_cgroup *memcg;
 
                BUG_ON(!sk->sk_prot->proto_cgroup);
 
+               /* Socket cloning can throw us here with sk_cgrp already
+                * filled. It won't however, necessarily happen from
+                * process context. So the test for root memcg given
+                * the current task's memcg won't help us in this case.
+                *
+                * Respecting the original socket's memcg is a better
+                * decision in this case.
+                */
+               if (sk->sk_cgrp) {
+                       BUG_ON(mem_cgroup_is_root(sk->sk_cgrp->memcg));
+                       mem_cgroup_get(sk->sk_cgrp->memcg);
+                       return;
+               }
+
                rcu_read_lock();
                memcg = mem_cgroup_from_task(current);
                if (!mem_cgroup_is_root(memcg)) {
 
                sk_set_socket(newsk, NULL);
                newsk->sk_wq = NULL;
 
+               sk_update_clone(sk, newsk);
+
                if (newsk->sk_prot->sockets_allocated)
                        sk_sockets_allocated_inc(newsk);