extern int perf_num_counters(void);
 extern const char *perf_pmu_name(void);
-extern void __perf_event_task_sched(struct task_struct *prev,
-                                   struct task_struct *next);
+extern void __perf_event_task_sched_in(struct task_struct *prev,
+                                      struct task_struct *task);
+extern void __perf_event_task_sched_out(struct task_struct *prev,
+                                       struct task_struct *next);
 extern int perf_event_init_task(struct task_struct *child);
 extern void perf_event_exit_task(struct task_struct *child);
 extern void perf_event_free_task(struct task_struct *task);
 
 extern struct static_key_deferred perf_sched_events;
 
-static inline void perf_event_task_sched(struct task_struct *prev,
+static inline void perf_event_task_sched_in(struct task_struct *prev,
                                            struct task_struct *task)
+{
+       if (static_key_false(&perf_sched_events.key))
+               __perf_event_task_sched_in(prev, task);
+}
+
+static inline void perf_event_task_sched_out(struct task_struct *prev,
+                                            struct task_struct *next)
 {
        perf_sw_event(PERF_COUNT_SW_CONTEXT_SWITCHES, 1, NULL, 0);
 
        if (static_key_false(&perf_sched_events.key))
-               __perf_event_task_sched(prev, task);
+               __perf_event_task_sched_out(prev, next);
 }
 
 extern void perf_event_mmap(struct vm_area_struct *vma);
 extern void perf_event_task_tick(void);
 #else
 static inline void
-perf_event_task_sched(struct task_struct *prev,
-                     struct task_struct *task)                         { }
+perf_event_task_sched_in(struct task_struct *prev,
+                        struct task_struct *task)                      { }
+static inline void
+perf_event_task_sched_out(struct task_struct *prev,
+                         struct task_struct *next)                     { }
 static inline int perf_event_init_task(struct task_struct *child)      { return 0; }
 static inline void perf_event_exit_task(struct task_struct *child)     { }
 static inline void perf_event_free_task(struct task_struct *task)      { }
 
  * accessing the event control register. If a NMI hits, then it will
  * not restart the event.
  */
-static void __perf_event_task_sched_out(struct task_struct *task,
-                                       struct task_struct *next)
+void __perf_event_task_sched_out(struct task_struct *task,
+                                struct task_struct *next)
 {
        int ctxn;
 
  * accessing the event control register. If a NMI hits, then it will
  * keep the event running.
  */
-static void __perf_event_task_sched_in(struct task_struct *prev,
-                                      struct task_struct *task)
+void __perf_event_task_sched_in(struct task_struct *prev,
+                               struct task_struct *task)
 {
        struct perf_event_context *ctx;
        int ctxn;
                perf_branch_stack_sched_in(prev, task);
 }
 
-void __perf_event_task_sched(struct task_struct *prev, struct task_struct *next)
-{
-       __perf_event_task_sched_out(prev, next);
-       __perf_event_task_sched_in(prev, next);
-}
-
 static u64 perf_calculate_period(struct perf_event *event, u64 nsec, u64 count)
 {
        u64 frequency = event->attr.sample_freq;
 
                    struct task_struct *next)
 {
        sched_info_switch(prev, next);
-       perf_event_task_sched(prev, next);
+       perf_event_task_sched_out(prev, next);
        fire_sched_out_preempt_notifiers(prev, next);
        prepare_lock_switch(rq, next);
        prepare_arch_switch(next);
         */
        prev_state = prev->state;
        finish_arch_switch(prev);
+#ifdef __ARCH_WANT_INTERRUPTS_ON_CTXSW
+       local_irq_disable();
+#endif /* __ARCH_WANT_INTERRUPTS_ON_CTXSW */
+       perf_event_task_sched_in(prev, current);
+#ifdef __ARCH_WANT_INTERRUPTS_ON_CTXSW
+       local_irq_enable();
+#endif /* __ARCH_WANT_INTERRUPTS_ON_CTXSW */
        finish_lock_switch(rq, prev);
        finish_arch_post_lock_switch();