irq_stack_exit
        .endm
 
+#ifdef CONFIG_ARM64_PSEUDO_NMI
+       /*
+        * Set res to 0 if irqs were unmasked in interrupted context.
+        * Otherwise set res to non-0 value.
+        */
+       .macro  test_irqs_unmasked res:req, pmr:req
+alternative_if ARM64_HAS_IRQ_PRIO_MASKING
+       sub     \res, \pmr, #GIC_PRIO_IRQON
+alternative_else
+       mov     \res, xzr
+alternative_endif
+       .endm
+#endif
+
        .text
 
 /*
 el1_irq:
        kernel_entry 1
        enable_da_f
-#ifdef CONFIG_TRACE_IRQFLAGS
+
 #ifdef CONFIG_ARM64_PSEUDO_NMI
 alternative_if ARM64_HAS_IRQ_PRIO_MASKING
        ldr     x20, [sp, #S_PMR_SAVE]
-alternative_else
-       mov     x20, #GIC_PRIO_IRQON
-alternative_endif
-       cmp     x20, #GIC_PRIO_IRQOFF
-       /* Irqs were disabled, don't trace */
-       b.ls    1f
+alternative_else_nop_endif
+       test_irqs_unmasked      res=x0, pmr=x20
+       cbz     x0, 1f
+       bl      asm_nmi_enter
+1:
 #endif
+
+#ifdef CONFIG_TRACE_IRQFLAGS
        bl      trace_hardirqs_off
-1:
 #endif
 
        irq_handler
        bl      preempt_schedule_irq            // irq en/disable is done inside
 1:
 #endif
-#ifdef CONFIG_TRACE_IRQFLAGS
+
 #ifdef CONFIG_ARM64_PSEUDO_NMI
        /*
         * if IRQs were disabled when we received the interrupt, we have an NMI
         * and we are not re-enabling interrupt upon eret. Skip tracing.
         */
-       cmp     x20, #GIC_PRIO_IRQOFF
-       b.ls    1f
+       test_irqs_unmasked      res=x0, pmr=x20
+       cbz     x0, 1f
+       bl      asm_nmi_exit
+1:
+#endif
+
+#ifdef CONFIG_TRACE_IRQFLAGS
+#ifdef CONFIG_ARM64_PSEUDO_NMI
+       test_irqs_unmasked      res=x0, pmr=x20
+       cbnz    x0, 1f
 #endif
        bl      trace_hardirqs_on
 1:
 
 #include <linux/smp.h>
 #include <linux/init.h>
 #include <linux/irqchip.h>
+#include <linux/kprobes.h>
 #include <linux/seq_file.h>
 #include <linux/vmalloc.h>
+#include <asm/daifflags.h>
 #include <asm/vmap_stack.h>
 
 unsigned long irq_err_count;
        if (!handle_arch_irq)
                panic("No interrupt controller found.");
 }
+
+/*
+ * Stubs to make nmi_enter/exit() code callable from ASM
+ */
+asmlinkage void notrace asm_nmi_enter(void)
+{
+       nmi_enter();
+}
+NOKPROBE_SYMBOL(asm_nmi_enter);
+
+asmlinkage void notrace asm_nmi_exit(void)
+{
+       nmi_exit();
+}
+NOKPROBE_SYMBOL(asm_nmi_exit);
 
 
 static inline void gic_handle_nmi(u32 irqnr, struct pt_regs *regs)
 {
+       bool irqs_enabled = interrupts_enabled(regs);
        int err;
 
+       if (irqs_enabled)
+               nmi_enter();
+
        if (static_branch_likely(&supports_deactivate_key))
                gic_write_eoir(irqnr);
        /*
        err = handle_domain_nmi(gic_data.domain, irqnr, regs);
        if (err)
                gic_deactivate_unhandled(irqnr);
+
+       if (irqs_enabled)
+               nmi_exit();
 }
 
 static asmlinkage void __exception_irq_entry gic_handle_irq(struct pt_regs *regs)
 
  * @hwirq:     The HW irq number to convert to a logical one
  * @regs:      Register file coming from the low-level handling code
  *
+ *             This function must be called from an NMI context.
+ *
  * Returns:    0 on success, or -EINVAL if conversion has failed
  */
 int handle_domain_nmi(struct irq_domain *domain, unsigned int hwirq,
        unsigned int irq;
        int ret = 0;
 
-       nmi_enter();
+       /*
+        * NMI context needs to be setup earlier in order to deal with tracing.
+        */
+       WARN_ON(!in_nmi());
 
        irq = irq_find_mapping(domain, hwirq);
 
        else
                ret = -EINVAL;
 
-       nmi_exit();
        set_irq_regs(old_regs);
        return ret;
 }