extern int use_cop(unsigned long acop, struct mm_struct *mm);
 extern void drop_cop(unsigned long acop, struct mm_struct *mm);
 
+#ifdef CONFIG_PPC_BOOK3S_64
+static inline void inc_mm_active_cpus(struct mm_struct *mm)
+{
+       atomic_inc(&mm->context.active_cpus);
+}
+
+static inline void dec_mm_active_cpus(struct mm_struct *mm)
+{
+       atomic_dec(&mm->context.active_cpus);
+}
+
+static inline void mm_context_add_copro(struct mm_struct *mm)
+{
+       /*
+        * On hash, should only be called once over the lifetime of
+        * the context, as we can't decrement the active cpus count
+        * and flush properly for the time being.
+        */
+       inc_mm_active_cpus(mm);
+}
+
+static inline void mm_context_remove_copro(struct mm_struct *mm)
+{
+       /*
+        * Need to broadcast a global flush of the full mm before
+        * decrementing active_cpus count, as the next TLBI may be
+        * local and the nMMU and/or PSL need to be cleaned up.
+        * Should be rare enough so that it's acceptable.
+        *
+        * Skip on hash, as we don't know how to do the proper flush
+        * for the time being. Invalidations will remain global if
+        * used on hash.
+        */
+       if (radix_enabled()) {
+               flush_all_mm(mm);
+               dec_mm_active_cpus(mm);
+       }
+}
+#else
+static inline void inc_mm_active_cpus(struct mm_struct *mm) { }
+static inline void dec_mm_active_cpus(struct mm_struct *mm) { }
+static inline void mm_context_add_copro(struct mm_struct *mm) { }
+static inline void mm_context_remove_copro(struct mm_struct *mm) { }
+#endif
+
+
 extern void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
                               struct task_struct *tsk);
 
 
                                   struct mm_struct *mm) { }
 #endif
 
-#ifdef CONFIG_PPC_BOOK3S_64
-static inline void inc_mm_active_cpus(struct mm_struct *mm)
-{
-       atomic_inc(&mm->context.active_cpus);
-}
-#else
-static inline void inc_mm_active_cpus(struct mm_struct *mm) { }
-#endif
-
 void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
                        struct task_struct *tsk)
 {
 
 #include <linux/module.h>
 #include <linux/mount.h>
 #include <linux/sched/mm.h>
+#include <linux/mmu_context.h>
 
 #include "cxl.h"
 
                /* ensure this mm_struct can't be freed */
                cxl_context_mm_count_get(ctx);
 
-               /* decrement the use count */
-               if (ctx->mm)
+               if (ctx->mm) {
+                       /* decrement the use count from above */
                        mmput(ctx->mm);
+                       /* make TLBIs for this context global */
+                       mm_context_add_copro(ctx->mm);
+               }
        }
 
        /*
         */
        cxl_ctx_get();
 
+       /* See the comment in afu_ioctl_start_work() */
+       smp_mb();
+
        if ((rc = cxl_ops->attach_process(ctx, kernel, wed, 0))) {
                put_pid(ctx->pid);
                ctx->pid = NULL;
                cxl_adapter_context_put(ctx->afu->adapter);
                cxl_ctx_put();
-               if (task)
+               if (task) {
                        cxl_context_mm_count_put(ctx);
+                       if (ctx->mm)
+                               mm_context_remove_copro(ctx->mm);
+               }
                goto out;
        }
 
 
 #include <linux/slab.h>
 #include <linux/idr.h>
 #include <linux/sched/mm.h>
+#include <linux/mmu_context.h>
 #include <asm/cputable.h>
 #include <asm/current.h>
 #include <asm/copro.h>
 
        /* Decrease the mm count on the context */
        cxl_context_mm_count_put(ctx);
+       if (ctx->mm)
+               mm_context_remove_copro(ctx->mm);
        ctx->mm = NULL;
 
        return 0;
 
 #include <linux/mm.h>
 #include <linux/slab.h>
 #include <linux/sched/mm.h>
+#include <linux/mmu_context.h>
 #include <asm/cputable.h>
 #include <asm/current.h>
 #include <asm/copro.h>
        /* ensure this mm_struct can't be freed */
        cxl_context_mm_count_get(ctx);
 
-       /* decrement the use count */
-       if (ctx->mm)
+       if (ctx->mm) {
+               /* decrement the use count from above */
                mmput(ctx->mm);
+               /* make TLBIs for this context global */
+               mm_context_add_copro(ctx->mm);
+       }
 
        /*
         * Increment driver use count. Enables global TLBIs for hash
         */
        cxl_ctx_get();
 
+       /*
+        * A barrier is needed to make sure all TLBIs are global
+        * before we attach and the context starts being used by the
+        * adapter.
+        *
+        * Needed after mm_context_add_copro() for radix and
+        * cxl_ctx_get() for hash/p8.
+        *
+        * The barrier should really be mb(), since it involves a
+        * device. However, it's only useful when we have local
+        * vs. global TLBIs, i.e SMP=y. So keep smp_mb().
+        */
+       smp_mb();
+
        trace_cxl_attach(ctx, work.work_element_descriptor, work.num_interrupts, amr);
 
        if ((rc = cxl_ops->attach_process(ctx, false, work.work_element_descriptor,
                ctx->pid = NULL;
                cxl_ctx_put();
                cxl_context_mm_count_put(ctx);
+               if (ctx->mm)
+                       mm_context_remove_copro(ctx->mm);
                goto out;
        }