.macro  kernel_exit, el
        .if     \el != 0
+       disable_daif
+
        /* Restore the task's original addr_limit. */
        ldr     x20, [sp, #S_ORIG_ADDR_LIMIT]
        str     x20, [tsk, #TSK_TI_ADDR_LIMIT]
        mov     x2, sp                          // struct pt_regs
        bl      do_mem_abort
 
-       // disable interrupts before pulling preserved data off the stack
-       disable_irq
        kernel_exit 1
 el1_sp_pc:
        /*
  * and this includes saving x0 back into the kernel stack.
  */
 ret_fast_syscall:
-       disable_irq                             // disable interrupts
+       disable_daif
        str     x0, [sp, #S_X0]                 // returned x0
        ldr     x1, [tsk, #TSK_TI_FLAGS]        // re-check for syscall tracing
        and     x2, x1, #_TIF_SYSCALL_WORK
        enable_step_tsk x1, x2
        kernel_exit 0
 ret_fast_syscall_trace:
-       enable_irq                              // enable interrupts
+       enable_daif
        b       __sys_trace_return_skipped      // we already saved x0
 
 /*
  * "slow" syscall return path.
  */
 ret_to_user:
-       disable_irq                             // disable interrupts
+       disable_daif
        ldr     x1, [tsk, #TSK_TI_FLAGS]
        and     x2, x1, #_TIF_WORK_MASK
        cbnz    x2, work_pending
 
 #include <linux/ratelimit.h>
 #include <linux/syscalls.h>
 
+#include <asm/daifflags.h>
 #include <asm/debug-monitors.h>
 #include <asm/elf.h>
 #include <asm/cacheflush.h>
                addr_limit_user_check();
 
                if (thread_flags & _TIF_NEED_RESCHED) {
+                       /* Unmask Debug and SError for the next task */
+                       local_daif_restore(DAIF_PROCCTX_NOIRQ);
+
                        schedule();
                } else {
-                       local_irq_enable();
+                       local_daif_restore(DAIF_PROCCTX);
 
                        if (thread_flags & _TIF_UPROBE)
                                uprobe_notify_resume(regs);
                                fpsimd_restore_current_state();
                }
 
-               local_irq_disable();
+               local_daif_mask();
                thread_flags = READ_ONCE(current_thread_info()->flags);
        } while (thread_flags & _TIF_WORK_MASK);
 }