static inline void flush_slab(struct kmem_cache *s, struct kmem_cache_cpu *c)
 {
-       void *freelist = c->freelist;
-       struct page *page = c->page;
+       unsigned long flags;
+       struct page *page;
+       void *freelist;
+
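+       /*
+        * Detach the cpu slab and its freelist with local IRQs disabled,
+        * then deactivate the detached slab outside the IRQ-off section.
+        */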
+       local_irq_save(flags);
+
+       page = c->page;
+       freelist = c->freelist;
 
        c->page = NULL;
        c->freelist = NULL;
        c->tid = next_tid(c->tid);
 
-       deactivate_slab(s, page, freelist);
+       local_irq_restore(flags);
 
-       stat(s, CPUSLAB_FLUSH);
+       if (page) {
+               deactivate_slab(s, page, freelist);
+               stat(s, CPUSLAB_FLUSH);
+       }
 }
 
 static inline void __flush_cpu_slab(struct kmem_cache *s, int cpu)
        unfreeze_partials_cpu(s, c);
 }
 
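+/*
+ * Per-CPU work item used to flush a CPU's slab cache from workqueue
+ * context; ->skip marks CPUs that have nothing to flush so
+ * flush_all_cpus_locked() does not queue or wait for their work.
+ */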
+struct slub_flush_work {
+       struct work_struct work;
+       struct kmem_cache *s;
+       bool skip;
+};
+
 /*
  * Flush cpu slab.
  *
- * Called from IPI handler with interrupts disabled.
+ * Called from CPU work handler with migration disabled.
  */
-static void flush_cpu_slab(void *d)
+static void flush_cpu_slab(struct work_struct *w)
 {
-       struct kmem_cache *s = d;
-       struct kmem_cache_cpu *c = this_cpu_ptr(s->cpu_slab);
+       struct kmem_cache *s;
+       struct kmem_cache_cpu *c;
+       struct slub_flush_work *sfw;
+
+       sfw = container_of(w, struct slub_flush_work, work);
+
+       s = sfw->s;
+       c = this_cpu_ptr(s->cpu_slab);
 
        if (c->page)
                flush_slab(s, c);
        unfreeze_partials(s);
 }
 
-static bool has_cpu_slab(int cpu, void *info)
+static bool has_cpu_slab(int cpu, struct kmem_cache *s)
 {
-       struct kmem_cache *s = info;
        struct kmem_cache_cpu *c = per_cpu_ptr(s->cpu_slab, cpu);
 
        return c->page || slub_percpu_partial(c);
 }
 
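+/*
+ * flush_lock serializes concurrent flushers so that the static per-CPU
+ * slub_flush work items are not reused while still pending.
+ */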
+static DEFINE_MUTEX(flush_lock);
+static DEFINE_PER_CPU(struct slub_flush_work, slub_flush);
+
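+/*
+ * Flush the cpu slab of every online CPU that holds a cpu slab or percpu
+ * partial slabs of @s.  The caller must hold cpus_read_lock() (asserted
+ * below) so the set of online CPUs stays stable while the flush work is
+ * queued and waited for.
+ */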
+static void flush_all_cpus_locked(struct kmem_cache *s)
+{
+       struct slub_flush_work *sfw;
+       unsigned int cpu;
+
+       lockdep_assert_cpus_held();
+       mutex_lock(&flush_lock);
+
+       for_each_online_cpu(cpu) {
+               sfw = &per_cpu(slub_flush, cpu);
+               if (!has_cpu_slab(cpu, s)) {
+                       sfw->skip = true;
+                       continue;
+               }
+               INIT_WORK(&sfw->work, flush_cpu_slab);
+               sfw->skip = false;
+               sfw->s = s;
+               schedule_work_on(cpu, &sfw->work);
+       }
+
+       for_each_online_cpu(cpu) {
+               sfw = &per_cpu(slub_flush, cpu);
+               if (sfw->skip)
+                       continue;
+               flush_work(&sfw->work);
+       }
+
+       mutex_unlock(&flush_lock);
+}
+
 static void flush_all(struct kmem_cache *s)
 {
-       on_each_cpu_cond(has_cpu_slab, flush_cpu_slab, s, 1);
+       cpus_read_lock();
+       flush_all_cpus_locked(s);
+       cpus_read_unlock();
 }
 
 /*
        int node;
        struct kmem_cache_node *n;
 
-       flush_all(s);
+       flush_all_cpus_locked(s);
        /* Attempt to free all objects */
        for_each_kmem_cache_node(s, node, n) {
                free_partial(s, n);
  * being allocated from last increasing the chance that the last objects
  * are freed in them.
  */
-int __kmem_cache_shrink(struct kmem_cache *s)
+static int __kmem_cache_do_shrink(struct kmem_cache *s)
 {
        int node;
        int i;
        unsigned long flags;
        int ret = 0;
 
-       flush_all(s);
        for_each_kmem_cache_node(s, node, n) {
                INIT_LIST_HEAD(&discard);
                for (i = 0; i < SHRINK_PROMOTE_MAX; i++)
        return ret;
 }
 
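+/* Flush cpu slabs (taking cpus_read_lock() via flush_all()) before shrinking. */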
+int __kmem_cache_shrink(struct kmem_cache *s)
+{
+       flush_all(s);
+       return __kmem_cache_do_shrink(s);
+}
+
 static int slab_mem_going_offline_callback(void *arg)
 {
        struct kmem_cache *s;
 
        mutex_lock(&slab_mutex);
-       list_for_each_entry(s, &slab_caches, list)
-               __kmem_cache_shrink(s);
+       list_for_each_entry(s, &slab_caches, list) {
+               flush_all_cpus_locked(s);
+               __kmem_cache_do_shrink(s);
+       }
        mutex_unlock(&slab_mutex);
 
        return 0;