tl;dr: There is a window in the mm switching code where the new CR3 is
set and the CPU should be getting TLB flushes for the new mm.  But
should_flush_tlb() has a bug and suppresses the flush.  Fix it by
widening the window where should_flush_tlb() sends an IPI.
Long Version:
=== History ===
There were a few things leading up to this.
First, updating mm_cpumask() was observed to be too expensive, so it was
made lazier.  But being lazy caused too many unnecessary IPIs to CPUs
due to the now-lazy mm_cpumask().  So code was added to cull
mm_cpumask() periodically[2].  But that culling was a bit too aggressive
and skipped sending TLB flushes to CPUs that need them.  So here we are
again.
=== Problem ===
The too-aggressive code in should_flush_tlb() strikes in this window:
	// Turn on IPIs for this CPU/mm combination, but only
	// if should_flush_tlb() agrees:
	cpumask_set_cpu(cpu, mm_cpumask(next));
	next_tlb_gen = atomic64_read(&next->context.tlb_gen);
	choose_new_asid(next, next_tlb_gen, &new_asid, &need_flush);
	load_new_mm_cr3(need_flush);
	// ^ After 'need_flush' is set to false, IPIs *MUST*
	// be sent to this CPU and not be ignored.
        this_cpu_write(cpu_tlbstate.loaded_mm, next);
	// ^ Not until this point does should_flush_tlb()
	// become true!
should_flush_tlb() will suppress TLB flushes between load_new_mm_cr3()
and writing to 'loaded_mm', which is a window where they should not be
suppressed.  Whoops.
=== Solution ===
Thankfully, the fuzzy "just about to write CR3" window is already marked
with loaded_mm==LOADED_MM_SWITCHING.  Simply checking for that state in
should_flush_tlb() is sufficient to ensure that the CPU is targeted with
an IPI.
This will cause more TLB flush IPIs.  But the window is relatively small
and I do not expect this to cause any kind of measurable performance
impact.
Update the comment where LOADED_MM_SWITCHING is written since it grew
yet another user.
Peter Z also raised a concern that should_flush_tlb() might not observe
'loaded_mm' and 'is_lazy' in the same order that switch_mm_irqs_off()
writes them.  Add a barrier to ensure that they are observed in the
order they are written.
Signed-off-by: Dave Hansen <dave.hansen@linux.intel.com>
Acked-by: Rik van Riel <riel@surriel.com>
Link: https://lore.kernel.org/oe-lkp/202411282207.6bd28eae-lkp@intel.com/
Fixes: 6db2526c1d69 ("x86/mm/tlb: Only trim the mm_cpumask once a second") [2]
Reported-by: Stephen Dolan <sdolan@janestreet.com>
Cc: stable@vger.kernel.org
Acked-by: Ingo Molnar <mingo@kernel.org>
Acked-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
                cond_mitigation(tsk);
 
                /*
-                * Let nmi_uaccess_okay() and finish_asid_transition()
-                * know that CR3 is changing.
+                * Indicate that CR3 is about to change. nmi_uaccess_okay()
+                * and others are sensitive to the window where mm_cpumask(),
+                * CR3 and cpu_tlbstate.loaded_mm are not all in sync.
                 */
                this_cpu_write(cpu_tlbstate.loaded_mm, LOADED_MM_SWITCHING);
                barrier();
 
 static bool should_flush_tlb(int cpu, void *data)
 {
+       struct mm_struct *loaded_mm = per_cpu(cpu_tlbstate.loaded_mm, cpu);
        struct flush_tlb_info *info = data;
 
+       /*
+        * Order the 'loaded_mm' and 'is_lazy' against their
+        * write ordering in switch_mm_irqs_off(). Ensure
+        * 'is_lazy' is at least as new as 'loaded_mm'.
+        */
+       smp_rmb();
+
        /* Lazy TLB will get flushed at the next context switch. */
        if (per_cpu(cpu_tlbstate_shared.is_lazy, cpu))
                return false;
        if (!info->mm)
                return true;
 
+       /*
+        * While switching, the remote CPU could have state from
+        * either the prev or next mm. Assume the worst and flush.
+        */
+       if (loaded_mm == LOADED_MM_SWITCHING)
+               return true;
+
        /* The target mm is loaded, and the CPU is not lazy. */
-       if (per_cpu(cpu_tlbstate.loaded_mm, cpu) == info->mm)
+       if (loaded_mm == info->mm)
                return true;
 
        /* In cpumask, but not the loaded mm? Periodically remove by flushing. */