*next = &next_p->thread;
        int cpu = smp_processor_id();
        struct tss_struct *tss = &per_cpu(init_tss, cpu);
+       bool preload_fpu;
 
        /* never put a printk in __switch_to... printk() calls wake_up*() indirectly */
 
-       __unlazy_fpu(prev_p);
+       /*
+        * If the task has used fpu the last 5 timeslices, just do a full
+        * restore of the math state immediately to avoid the trap; the
+        * chances of needing FPU soon are obviously high now
+        */
+       preload_fpu = tsk_used_math(next_p) && next_p->fpu_counter > 5;
 
+       __unlazy_fpu(prev_p);
 
        /* we're going to use this soon, after a few expensive things */
-       if (next_p->fpu_counter > 5)
+       if (preload_fpu)
                prefetch(next->xstate);
 
        /*
                     task_thread_info(next_p)->flags & _TIF_WORK_CTXSW_NEXT))
                __switch_to_xtra(prev_p, next_p, tss);
 
+       /* If we're going to preload the fpu context, make sure clts
+          is run while we're batching the cpu state updates. */
+       if (preload_fpu)
+               clts();
+
        /*
         * Leave lazy mode, flushing any hypercalls made here.
         * This must be done before restoring TLS segments so
         */
        arch_end_context_switch(next_p);
 
-       /* If the task has used fpu the last 5 timeslices, just do a full
-        * restore of the math state immediately to avoid the trap; the
-        * chances of needing FPU soon are obviously high now
-        *
-        * tsk_used_math() checks prevent calling math_state_restore(),
-        * which can sleep in the case of !tsk_used_math()
-        */
-       if (tsk_used_math(next_p) && next_p->fpu_counter > 5)
-               math_state_restore();
+       if (preload_fpu)
+               __math_state_restore();
 
        /*
         * Restore %gs if needed (which is common)