 #include <linux/list_nulls.h>
 #include <linux/timer.h>
 #include <linux/cache.h>
+#include <linux/bitops.h>
 #include <linux/lockdep.h>
 #include <linux/netdevice.h>
 #include <linux/skbuff.h>      /* struct sk_buff */
 #endif
 };
 
+/*
+ * Bits in struct cg_proto.flags
+ */
+enum cg_proto_flags {
+       /* Currently active: new sockets should be assigned to this cgroup */
+       MEMCG_SOCK_ACTIVE,
+       /* Was activated at least once; disarm the static key on destruction */
+       MEMCG_SOCK_ACTIVATED,
+};
+
 struct cg_proto {
        void                    (*enter_memory_pressure)(struct sock *sk);
        struct res_counter      *memory_allocated;      /* Current allocated memory. */
        struct percpu_counter   *sockets_allocated;     /* Current number of sockets. */
        int                     *memory_pressure;
        long                    *sysctl_mem;
+       unsigned long           flags;
        /*
         * memcg field is used to find which memcg we belong directly
         * Each memcg struct can hold more than one cg_proto, so container_of
 extern int proto_register(struct proto *prot, int alloc_slab);
 extern void proto_unregister(struct proto *prot);
 
+static inline bool memcg_proto_active(struct cg_proto *cg_proto)
+{
+       return test_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+}
+
+static inline bool memcg_proto_activated(struct cg_proto *cg_proto)
+{
+       return test_bit(MEMCG_SOCK_ACTIVATED, &cg_proto->flags);
+}
+
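A quick usage sketch for the two helpers above (illustrative only, not part of the
patch; the example_* function names are invented, the other identifiers are the ones
this series introduces): the charging side gates on the ACTIVE bit, while teardown
gates the one-time static key decrement on the ACTIVATED bit.

	static inline void example_charge_new_socket(struct sock *sk,
						     struct cg_proto *cg_proto)
	{
		/* Only sockets created while a limit is in place get accounted. */
		if (memcg_proto_active(cg_proto))
			sk->sk_cgrp = cg_proto;
	}

	static inline void example_memcg_teardown(struct cg_proto *cg_proto)
	{
		/* The key was incremented at most once; decrement it at most once. */
		if (memcg_proto_activated(cg_proto))
			static_key_slow_dec(&memcg_socket_limit_enabled);
	}
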
 #ifdef SOCK_REFCNT_DEBUG
 static inline void sk_refcnt_debug_inc(struct sock *sk)
 {
 
 {
        if (mem_cgroup_sockets_enabled) {
                struct mem_cgroup *memcg;
+               struct cg_proto *cg_proto;
 
                BUG_ON(!sk->sk_prot->proto_cgroup);
 
 
                rcu_read_lock();
                memcg = mem_cgroup_from_task(current);
-               if (!mem_cgroup_is_root(memcg)) {
+               cg_proto = sk->sk_prot->proto_cgroup(memcg);
+               if (!mem_cgroup_is_root(memcg) && memcg_proto_active(cg_proto)) {
                        mem_cgroup_get(memcg);
-                       sk->sk_cgrp = sk->sk_prot->proto_cgroup(memcg);
+                       sk->sk_cgrp = cg_proto;
                }
                rcu_read_unlock();
        }
 #endif /* CONFIG_INET */
 #endif /* CONFIG_CGROUP_MEM_RES_CTLR_KMEM */
 
+#if defined(CONFIG_INET) && defined(CONFIG_CGROUP_MEM_RES_CTLR_KMEM)
+static void disarm_sock_keys(struct mem_cgroup *memcg)
+{
+       if (!memcg_proto_activated(&memcg->tcp_mem.cg_proto))
+               return;
+       static_key_slow_dec(&memcg_socket_limit_enabled);
+}
+#else
+static void disarm_sock_keys(struct mem_cgroup *memcg)
+{
+}
+#endif
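The empty !CONFIG stub keeps the caller free of #ifdefs. A rough sketch of the call
path this is written for (free_work() and work_freeing appear in the next hunk; the
steps above schedule_work() are assumed and not shown in this excerpt):

	/*
	 *  final mem_cgroup free                  may nest inside cgroup_lock
	 *    -> schedule_work(&memcg->work_freeing)
	 *         free_work()                      workqueue context, no cgroup_lock
	 *           -> disarm_sock_keys(memcg)     static_key_slow_dec() can end up
	 *                                          in get_online_cpus()
	 *           -> kfree()/vfree(memcg)
	 */
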
+
 static void drain_all_stock_async(struct mem_cgroup *memcg);
 
 static struct mem_cgroup_per_zone *
        int size = sizeof(struct mem_cgroup);
 
        memcg = container_of(work, struct mem_cgroup, work_freeing);
+       /*
+        * We need to make sure that (at least for now) the jump label
+        * destruction code runs outside of the cgroup lock. This is because
+        * get_online_cpus(), which is called from the static_branch update,
+        * can't be called inside the cgroup_lock. cpusets are the ones
+        * enforcing this dependency, so if they ever change, this may change too.
+        *
+        * schedule_work() will guarantee this happens. Be careful if you need
+        * to move this code around, and make sure it is outside
+        * the cgroup_lock.
+        */
+       disarm_sock_keys(memcg);
        if (size < PAGE_SIZE)
                kfree(memcg);
        else
 
        percpu_counter_destroy(&tcp->tcp_sockets_allocated);
 
        val = res_counter_read_u64(&tcp->tcp_memory_allocated, RES_LIMIT);
-
-       if (val != RESOURCE_MAX)
-               static_key_slow_dec(&memcg_socket_limit_enabled);
 }
 EXPORT_SYMBOL(tcp_destroy_cgroup);
 
                tcp->tcp_prot_mem[i] = min_t(long, val >> PAGE_SHIFT,
                                             net->ipv4.sysctl_tcp_mem[i]);
 
-       if (val == RESOURCE_MAX && old_lim != RESOURCE_MAX)
-               static_key_slow_dec(&memcg_socket_limit_enabled);
-       else if (old_lim == RESOURCE_MAX && val != RESOURCE_MAX)
-               static_key_slow_inc(&memcg_socket_limit_enabled);
+       if (val == RESOURCE_MAX)
+               clear_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+       else {
+               /*
+                * The active bit needs to be written after the static_key
+                * update. This is what guarantees that the socket activation
+                * function is the last one to run. See sock_update_memcg() for
+                * details, and note that we don't mark any socket as belonging
+                * to this memcg until that flag is up.
+                *
+                * We need to do this, because static_keys will span multiple
+                * sites, but we can't control their order. If we mark a socket
+                * as accounted, but the accounting functions are not patched in
+                * yet, we'll lose accounting.
+                *
+                * We never race with the readers in sock_update_memcg(),
+                * because when this value changes, the code to process it is not
+                * patched in yet.
+                *
+                * The activated bit is used to guarantee that no two writers
+                * will do the static_key update for the same memcg. Without
+                * that, we can't properly shut down the static key.
+                */
+               if (!test_and_set_bit(MEMCG_SOCK_ACTIVATED, &cg_proto->flags))
+                       static_key_slow_inc(&memcg_socket_limit_enabled);
+               set_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+       }
 
        return 0;
 }
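
To make the ordering argument in the comment above concrete, a sketch of the
writer/reader pairing (identifiers are the ones used in this patch; sock_update_memcg()
is the reader it names, and the layout is illustrative only):

	/*
	 *  writer (this function)                 reader (sock_update_memcg())
	 *
	 *  if (!test_and_set_bit(ACTIVATED, ..))
	 *          static_key_slow_inc(key);       sees !ACTIVE, leaves sk_cgrp unset
	 *  set_bit(ACTIVE, &cg_proto->flags);      if (memcg_proto_active(cg_proto))
	 *                                                  sk->sk_cgrp = cg_proto;
	 *
	 * A reader can only observe MEMCG_SOCK_ACTIVE after the static key has been
	 * incremented, so any socket it marks as accounted will actually run the
	 * patched accounting code.  MEMCG_SOCK_ACTIVATED is never cleared, so the
	 * inc here and the dec in disarm_sock_keys() each happen at most once per
	 * memcg.
	 */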