struct fpu *fpu = ¤t->thread.fpu;
        int cpu = smp_processor_id();
 
-       if (WARN_ON_ONCE(current->mm == NULL))
+       if (WARN_ON_ONCE(current->flags & PF_KTHREAD))
                return;
 
        if (!fpregs_state_valid(fpu, cpu)) {
  * otherwise.
  *
  * The FPU context is only stored/restored for a user task and
- * ->mm is used to distinguish between kernel and user threads.
+ * PF_KTHREAD is used to distinguish between kernel and user threads.
  */
 static inline void switch_fpu_prepare(struct fpu *old_fpu, int cpu)
 {
-       if (static_cpu_has(X86_FEATURE_FPU) && current->mm) {
+       if (static_cpu_has(X86_FEATURE_FPU) && !(current->flags & PF_KTHREAD)) {
                if (!copy_fpregs_to_fpstate(old_fpu))
                        old_fpu->last_cpu = -1;
                else