.endm
 
-.macro SAVE_ALL_NMI
+.macro SAVE_ALL_NMI cr3_reg:req
        SAVE_ALL
+
+       /*
+        * Now switch the CR3 when PTI is enabled.
+        *
+        * We can enter with either user or kernel cr3, the code will
+        * store the old cr3 in \cr3_reg and switches to the kernel cr3
+        * if necessary.
+        */
+       SWITCH_TO_KERNEL_CR3 scratch_reg=\cr3_reg
+
+.Lend_\@:
 .endm
 /*
  * This is a sneaky trick to help the unwinder find pt_regs on the stack.  The
        POP_GS_EX
 .endm
 
-.macro RESTORE_ALL_NMI pop=0
+.macro RESTORE_ALL_NMI cr3_reg:req pop=0
+       /*
+        * Now switch the CR3 when PTI is enabled.
+        *
+        * We enter with kernel cr3 and switch the cr3 to the value
+        * stored on \cr3_reg, which is either a user or a kernel cr3.
+        */
+       ALTERNATIVE "jmp .Lswitched_\@", "", X86_FEATURE_PTI
+
+       testl   $PTI_SWITCH_MASK, \cr3_reg
+       jz      .Lswitched_\@
+
+       /* User cr3 in \cr3_reg - write it to hardware cr3 */
+       movl    \cr3_reg, %cr3
+
+.Lswitched_\@:
+
        RESTORE_REGS pop=\pop
 .endm
 
 #endif
 
        pushl   %eax                            # pt_regs->orig_ax
-       SAVE_ALL_NMI
+       SAVE_ALL_NMI cr3_reg=%edi
        ENCODE_FRAME_POINTER
        xorl    %edx, %edx                      # zero error code
        movl    %esp, %eax                      # pt_regs pointer
 
 .Lnmi_return:
        CHECK_AND_APPLY_ESPFIX
-       RESTORE_ALL_NMI pop=4
+       RESTORE_ALL_NMI cr3_reg=%edi pop=4
        jmp     .Lirq_return
 
 #ifdef CONFIG_X86_ESPFIX32
        pushl   16(%esp)
        .endr
        pushl   %eax
-       SAVE_ALL_NMI
+       SAVE_ALL_NMI cr3_reg=%edi
        ENCODE_FRAME_POINTER
        FIXUP_ESPFIX_STACK                      # %eax == %esp
        xorl    %edx, %edx                      # zero error code
        call    do_nmi
-       RESTORE_ALL_NMI
+       RESTORE_ALL_NMI cr3_reg=%edi
        lss     12+4(%esp), %esp                # back to espfix stack
        jmp     .Lirq_return
 #endif