tm_recheckpoint_new_task(new);
 
-       last = _switch(old_thread, new_thread);
-
-       /* Need to recalculate these after calling _switch() */
-       old_thread = &last->thread;
-       new_thread = ¤t->thread;
-
+       /*
+        * Call restore_sprs() before calling _switch(). If we move it after
+        * _switch() then we miss out on calling it for new tasks. The reason
+        * for this is we manually create a stack frame for new tasks that
+        * directly returns through ret_from_fork() or
+        * ret_from_kernel_thread(). See copy_thread() for details.
+        */
        restore_sprs(old_thread, new_thread);
 
+       last = _switch(old_thread, new_thread);
+
 #ifdef CONFIG_PPC_BOOK3S_64
        if (current_thread_info()->local_flags & _TLF_LAZY_MMU) {
                current_thread_info()->local_flags &= ~_TLF_LAZY_MMU;