void fpsimd_thread_switch(struct task_struct *next)
 {
+       bool wrong_task, wrong_cpu;
+
        if (!system_supports_fpsimd())
                return;
+
+       /* Save unsaved fpsimd state, if any: */
+       fpsimd_save();
+
        /*
-        * Save the current FPSIMD state to memory, but only if whatever is in
-        * the registers is in fact the most recent userland FPSIMD state of
-        * 'current'.
+        * Fix up TIF_FOREIGN_FPSTATE to correctly describe next's
+        * state.  For kernel threads, FPSIMD registers are never loaded
+        * and wrong_task and wrong_cpu will always be true.
         */
-       if (current->mm)
-               fpsimd_save();
-
-       if (next->mm) {
-               /*
-                * If we are switching to a task whose most recent userland
-                * FPSIMD state is already in the registers of *this* cpu,
-                * we can skip loading the state from memory. Otherwise, set
-                * the TIF_FOREIGN_FPSTATE flag so the state will be loaded
-                * upon the next return to userland.
-                */
-               bool wrong_task = __this_cpu_read(fpsimd_last_state.st) !=
+       wrong_task = __this_cpu_read(fpsimd_last_state.st) !=
                                        &next->thread.uw.fpsimd_state;
-               bool wrong_cpu = next->thread.fpsimd_cpu != smp_processor_id();
+       wrong_cpu = next->thread.fpsimd_cpu != smp_processor_id();
 
-               update_tsk_thread_flag(next, TIF_FOREIGN_FPSTATE,
-                                      wrong_task || wrong_cpu);
-       }
+       update_tsk_thread_flag(next, TIF_FOREIGN_FPSTATE,
+                              wrong_task || wrong_cpu);
 }
 
 void fpsimd_flush_thread(void)
 
        __this_cpu_write(kernel_neon_busy, true);
 
-       /* Save unsaved task fpsimd state, if any: */
-       if (current->mm)
-               fpsimd_save();
+       /* Save unsaved fpsimd state, if any: */
+       fpsimd_save();
 
        /* Invalidate any task state remaining in the fpsimd regs: */
        fpsimd_flush_cpu_state();
 {
        switch (cmd) {
        case CPU_PM_ENTER:
-               if (current->mm)
-                       fpsimd_save();
+               fpsimd_save();
                fpsimd_flush_cpu_state();
                break;
        case CPU_PM_EXIT: