#include <linux/export.h>
 
 #include <asm/processor.h>
+#include <asm/traps.h>
 #include <asm/mce.h>
 #include <asm/msr.h>
 
 {
        struct mca_config *cfg = &mca_cfg;
        struct mce m, *final;
+       enum ctx_state prev_state;
        int i;
        int worst = 0;
        int severity;
        DECLARE_BITMAP(valid_banks, MAX_NR_BANKS);
        char *msg = "Unknown";
 
+       prev_state = ist_enter(regs);
+
        this_cpu_inc(mce_exception_count);
 
        if (!cfg->banks)
        mce_wrmsrl(MSR_IA32_MCG_STATUS, 0);
 out:
        sync_core();
+       ist_exit(regs, prev_state);
 }
 EXPORT_SYMBOL_GPL(do_machine_check);
 
 
 #include <linux/smp.h>
 
 #include <asm/processor.h>
+#include <asm/traps.h>
 #include <asm/mce.h>
 #include <asm/msr.h>
 
 /* Machine check handler for Pentium class Intel CPUs: */
 static void pentium_machine_check(struct pt_regs *regs, long error_code)
 {
+       enum ctx_state prev_state;
        u32 loaddr, hi, lotype;
 
+       prev_state = ist_enter(regs);
+
        rdmsr(MSR_IA32_P5_MC_ADDR, loaddr, hi);
        rdmsr(MSR_IA32_P5_MC_TYPE, lotype, hi);
 
        }
 
        add_taint(TAINT_MACHINE_CHECK, LOCKDEP_NOW_UNRELIABLE);
+
+       ist_exit(regs, prev_state);
 }
 
 /* Set up machine check reporting for processors with Intel style MCE: */
 
        preempt_count_dec();
 }
 
+enum ctx_state ist_enter(struct pt_regs *regs)
+{
+       /*
+        * We are atomic because we're on the IST stack (or we're on x86_32,
+        * in which case we still shouldn't schedule.
+        */
+       preempt_count_add(HARDIRQ_OFFSET);
+
+       if (user_mode_vm(regs)) {
+               /* Other than that, we're just an exception. */
+               return exception_enter();
+       } else {
+               /*
+                * We might have interrupted pretty much anything.  In
+                * fact, if we're a machine check, we can even interrupt
+                * NMI processing.  We don't want in_nmi() to return true,
+                * but we need to notify RCU.
+                */
+               rcu_nmi_enter();
+               return IN_KERNEL;  /* the value is irrelevant. */
+       }
+}
+
+void ist_exit(struct pt_regs *regs, enum ctx_state prev_state)
+{
+       preempt_count_sub(HARDIRQ_OFFSET);
+
+       if (user_mode_vm(regs))
+               return exception_exit(prev_state);
+       else
+               rcu_nmi_exit();
+}
+
 static nokprobe_inline int
 do_trap_no_signal(struct task_struct *tsk, int trapnr, char *str,
                  struct pt_regs *regs, long error_code)
         * end up promoting it to a doublefault.  In that case, modify
         * the stack to make it look like we just entered the #GP
         * handler from user space, similar to bad_iret.
+        *
+        * No need for ist_enter here because we don't use RCU.
         */
        if (((long)regs->sp >> PGDIR_SHIFT) == ESPFIX_PGD_ENTRY &&
                regs->cs == __KERNEL_CS &&
                normal_regs->orig_ax = 0;  /* Missing (lost) #GP error code */
                regs->ip = (unsigned long)general_protection;
                regs->sp = (unsigned long)&normal_regs->orig_ax;
+
                return;
        }
 #endif
 
-       exception_enter();
-       /* Return not checked because double check cannot be ignored */
+       ist_enter(regs);  /* Discard prev_state because we won't return. */
        notify_die(DIE_TRAP, str, regs, error_code, X86_TRAP_DF, SIGSEGV);
 
        tsk->thread.error_code = error_code;
        if (poke_int3_handler(regs))
                return;
 
-       prev_state = exception_enter();
+       prev_state = ist_enter(regs);
 #ifdef CONFIG_KGDB_LOW_LEVEL_TRAP
        if (kgdb_ll_trap(DIE_INT3, "int3", regs, error_code, X86_TRAP_BP,
                                SIGTRAP) == NOTIFY_STOP)
        preempt_conditional_cli(regs);
        debug_stack_usage_dec();
 exit:
-       exception_exit(prev_state);
+       ist_exit(regs, prev_state);
 }
 NOKPROBE_SYMBOL(do_int3);
 
        unsigned long dr6;
        int si_code;
 
-       prev_state = exception_enter();
+       prev_state = ist_enter(regs);
 
        get_debugreg(dr6, 6);
 
        debug_stack_usage_dec();
 
 exit:
-       exception_exit(prev_state);
+       ist_exit(regs, prev_state);
 }
 NOKPROBE_SYMBOL(do_debug);