 	*need_flush = true;
 }
 
+static void load_new_mm_cr3(pgd_t *pgdir, u16 new_asid, bool need_flush)
+{
+       unsigned long new_mm_cr3;
+
+       if (need_flush) {
+               new_mm_cr3 = build_cr3(pgdir, new_asid);
+       } else {
+               new_mm_cr3 = build_cr3_noflush(pgdir, new_asid);
+       }
+
+       /*
+        * Caution: many callers of this function expect
+        * that load_cr3() is serializing and orders TLB
+        * fills with respect to the mm_cpumask writes.
+        */
+       write_cr3(new_mm_cr3);
+}
+
 void leave_mm(int cpu)
 {
        struct mm_struct *loaded_mm = this_cpu_read(cpu_tlbstate.loaded_mm);
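
For reference, the two builders the new helper chooses between differ only in
bit 63 of the resulting CR3 value: build_cr3() encodes the PGD's physical
address plus the hardware PCID, and build_cr3_noflush() additionally sets the
CR3 no-flush bit so the CPU keeps the TLB entries already tagged with that
PCID. A minimal user-space sketch of the encoding (the constant and function
names here are stand-ins modelled on the kernel's, not its own definitions):

#include <stdint.h>
#include <stdio.h>

/* Stand-ins for the kernel's constants (names here are hypothetical). */
#define CR3_NOFLUSH	(1ULL << 63)	/* ask the CPU to keep this PCID's TLB entries */
#define CR3_PCID_MASK	0xFFFULL	/* low 12 bits of CR3 carry the PCID */

/* Flushing variant: PGD physical address plus the hardware PCID. */
static uint64_t sketch_build_cr3(uint64_t pgd_pa, uint16_t asid)
{
	/*
	 * The kernel maps its 0-based ASIDs to hardware PCID asid + 1,
	 * keeping PCID 0 for the pre-PCID boot state.
	 */
	return pgd_pa | (((uint64_t)asid + 1) & CR3_PCID_MASK);
}

/* Non-flushing variant: the same value with bit 63 set. */
static uint64_t sketch_build_cr3_noflush(uint64_t pgd_pa, uint16_t asid)
{
	return sketch_build_cr3(pgd_pa, asid) | CR3_NOFLUSH;
}

int main(void)
{
	uint64_t pgd_pa = 0x1234000;	/* hypothetical page-aligned PGD */

	printf("flush:   %#llx\n", (unsigned long long)sketch_build_cr3(pgd_pa, 2));
	printf("noflush: %#llx\n", (unsigned long long)sketch_build_cr3_noflush(pgd_pa, 2));
	return 0;
}

Loading the first value flushes everything the TLB holds for that PCID;
loading the second tells the CPU the cached translations are still valid.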
@@ ... @@ void switch_mm_irqs_off(struct mm_struct *prev, struct mm_struct *next,
 		if (need_flush) {
                        this_cpu_write(cpu_tlbstate.ctxs[new_asid].ctx_id, next->context.ctx_id);
                        this_cpu_write(cpu_tlbstate.ctxs[new_asid].tlb_gen, next_tlb_gen);
-                       write_cr3(build_cr3(next->pgd, new_asid));
+                       load_new_mm_cr3(next->pgd, new_asid, true);
 
                        /*
			 * NB: This gets called via leave_mm() in the idle path
			 * where RCU functions differently.  Tracing normally
			 * uses RCU, so we have to call the tracepoint
			 * specially here.
			 */
			trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
                } else {
                        /* The new ASID is already up to date. */
-                       write_cr3(build_cr3_noflush(next->pgd, new_asid));
+                       load_new_mm_cr3(next->pgd, new_asid, false);
 
                        /* See above wrt _rcuidle. */
                        trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, 0);
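
The "Caution" comment in the new helper is about the instruction itself: MOV
to CR3 is architecturally serializing, which is what orders the mm_cpumask
writes before any TLB fill through the new page tables. A sketch of what the
write_cr3() call boils down to (modelled on the kernel's native_write_cr3();
the exact asm constraints are an assumption):

static inline void write_cr3_sketch(unsigned long val)
{
	/*
	 * MOV to CR3 is a serializing instruction: earlier stores
	 * (e.g. the mm_cpumask update) become globally visible before
	 * any TLB fill through the new page tables.  The "memory"
	 * clobber also stops the compiler from reordering around it.
	 */
	asm volatile("mov %0, %%cr3" : : "r" (val) : "memory");
}

Folding both call sites into load_new_mm_cr3() also leaves a single place to
change if the way CR3 values are built or written ever needs to differ.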