.pcpu = &system_group_pcpu,
 };
 
+static DEFINE_PER_CPU(seqcount_t, psi_seq);
+
+static inline void psi_write_begin(int cpu)
+{
+       write_seqcount_begin(per_cpu_ptr(&psi_seq, cpu));
+}
+
+static inline void psi_write_end(int cpu)
+{
+       write_seqcount_end(per_cpu_ptr(&psi_seq, cpu));
+}
+
+static inline u32 psi_read_begin(int cpu)
+{
+       return read_seqcount_begin(per_cpu_ptr(&psi_seq, cpu));
+}
+
+static inline bool psi_read_retry(int cpu, u32 seq)
+{
+       return read_seqcount_retry(per_cpu_ptr(&psi_seq, cpu), seq);
+}
+
 static void psi_avgs_work(struct work_struct *work);
 
 static void poll_timer_fn(struct timer_list *t);
 
        group->enabled = true;
        for_each_possible_cpu(cpu)
-               seqcount_init(&per_cpu_ptr(group->pcpu, cpu)->seq);
+               seqcount_init(per_cpu_ptr(&psi_seq, cpu));
        group->avg_last_update = sched_clock();
        group->avg_next_update = group->avg_last_update + psi_period;
        mutex_init(&group->avgs_lock);
 
        /* Snapshot a coherent view of the CPU state */
        do {
-               seq = read_seqcount_begin(&groupc->seq);
+               seq = psi_read_begin(cpu);
                now = cpu_clock(cpu);
                memcpy(times, groupc->times, sizeof(groupc->times));
                state_mask = groupc->state_mask;
                state_start = groupc->state_start;
                if (cpu == current_cpu)
                        memcpy(tasks, groupc->tasks, sizeof(groupc->tasks));
-       } while (read_seqcount_retry(&groupc->seq, seq));
+       } while (psi_read_retry(cpu, seq));
 
        /* Calculate state time deltas against the previous snapshot */
        for (s = 0; s < NR_PSI_STATES; s++) {
                groupc->times[PSI_NONIDLE] += delta;
 }
 
+#define for_each_group(iter, group) \
+       for (typeof(group) iter = group; iter; iter = iter->parent)
+
 static void psi_group_change(struct psi_group *group, int cpu,
                             unsigned int clear, unsigned int set,
-                            bool wake_clock)
+                            u64 now, bool wake_clock)
 {
        struct psi_group_cpu *groupc;
        unsigned int t, m;
        u32 state_mask;
-       u64 now;
 
        lockdep_assert_rq_held(cpu_rq(cpu));
        groupc = per_cpu_ptr(group->pcpu, cpu);
 
-       /*
-        * First we update the task counts according to the state
-        * change requested through the @clear and @set bits.
-        *
-        * Then if the cgroup PSI stats accounting enabled, we
-        * assess the aggregate resource states this CPU's tasks
-        * have been in since the last change, and account any
-        * SOME and FULL time these may have resulted in.
-        */
-       write_seqcount_begin(&groupc->seq);
-       now = cpu_clock(cpu);
-
        /*
         * Start with TSK_ONCPU, which doesn't have a corresponding
         * task count - it's just a boolean flag directly encoded in
 
                groupc->state_mask = state_mask;
 
-               write_seqcount_end(&groupc->seq);
                return;
        }
 
 
        groupc->state_mask = state_mask;
 
-       write_seqcount_end(&groupc->seq);
-
        if (state_mask & group->rtpoll_states)
                psi_schedule_rtpoll_work(group, 1, false);
 
 void psi_task_change(struct task_struct *task, int clear, int set)
 {
        int cpu = task_cpu(task);
-       struct psi_group *group;
+       u64 now;
 
        if (!task->pid)
                return;
 
        psi_flags_change(task, clear, set);
 
-       group = task_psi_group(task);
-       do {
-               psi_group_change(group, cpu, clear, set, true);
-       } while ((group = group->parent));
+       psi_write_begin(cpu);
+       now = cpu_clock(cpu);
+       for_each_group(group, task_psi_group(task))
+               psi_group_change(group, cpu, clear, set, now, true);
+       psi_write_end(cpu);
 }
 
 void psi_task_switch(struct task_struct *prev, struct task_struct *next,
                     bool sleep)
 {
-       struct psi_group *group, *common = NULL;
+       struct psi_group *common = NULL;
        int cpu = task_cpu(prev);
+       u64 now;
+
+       psi_write_begin(cpu);
+       now = cpu_clock(cpu);
 
        if (next->pid) {
                psi_flags_change(next, 0, TSK_ONCPU);
                 * ancestors with @prev, those will already have @prev's
                 * TSK_ONCPU bit set, and we can stop the iteration there.
                 */
-               group = task_psi_group(next);
-               do {
-                       if (per_cpu_ptr(group->pcpu, cpu)->state_mask &
-                           PSI_ONCPU) {
+               for_each_group(group, task_psi_group(next)) {
+                       struct psi_group_cpu *groupc = per_cpu_ptr(group->pcpu, cpu);
+
+                       if (groupc->state_mask & PSI_ONCPU) {
                                common = group;
                                break;
                        }
-
-                       psi_group_change(group, cpu, 0, TSK_ONCPU, true);
-               } while ((group = group->parent));
+                       psi_group_change(group, cpu, 0, TSK_ONCPU, now, true);
+               }
        }
 
        if (prev->pid) {
 
                psi_flags_change(prev, clear, set);
 
-               group = task_psi_group(prev);
-               do {
+               for_each_group(group, task_psi_group(prev)) {
                        if (group == common)
                                break;
-                       psi_group_change(group, cpu, clear, set, wake_clock);
-               } while ((group = group->parent));
+                       psi_group_change(group, cpu, clear, set, now, wake_clock);
+               }
 
                /*
                 * TSK_ONCPU is handled up to the common ancestor. If there are
                 */
                if ((prev->psi_flags ^ next->psi_flags) & ~TSK_ONCPU) {
                        clear &= ~TSK_ONCPU;
-                       for (; group; group = group->parent)
-                               psi_group_change(group, cpu, clear, set, wake_clock);
+                       for_each_group(group, common)
+                               psi_group_change(group, cpu, clear, set, now, wake_clock);
                }
        }
+       psi_write_end(cpu);
 }
 
 #ifdef CONFIG_IRQ_TIME_ACCOUNTING
 void psi_account_irqtime(struct rq *rq, struct task_struct *curr, struct task_struct *prev)
 {
        int cpu = task_cpu(curr);
-       struct psi_group *group;
        struct psi_group_cpu *groupc;
        s64 delta;
        u64 irq;
+       u64 now;
 
        if (static_branch_likely(&psi_disabled) || !irqtime_enabled())
                return;
                return;
 
        lockdep_assert_rq_held(rq);
-       group = task_psi_group(curr);
-       if (prev && task_psi_group(prev) == group)
+       if (prev && task_psi_group(prev) == task_psi_group(curr))
                return;
 
        irq = irq_time_read(cpu);
                return;
        rq->psi_irq_time = irq;
 
-       do {
-               u64 now;
+       psi_write_begin(cpu);
+       now = cpu_clock(cpu);
 
+       for_each_group(group, task_psi_group(curr)) {
                if (!group->enabled)
                        continue;
 
                groupc = per_cpu_ptr(group->pcpu, cpu);
 
-               write_seqcount_begin(&groupc->seq);
-               now = cpu_clock(cpu);
-
                record_times(groupc, now);
                groupc->times[PSI_IRQ_FULL] += delta;
 
-               write_seqcount_end(&groupc->seq);
-
                if (group->rtpoll_states & (1 << PSI_IRQ_FULL))
                        psi_schedule_rtpoll_work(group, 1, false);
-       } while ((group = group->parent));
+       }
+       psi_write_end(cpu);
 }
 #endif /* CONFIG_IRQ_TIME_ACCOUNTING */
 
                return;
 
        for_each_possible_cpu(cpu) {
-               struct rq *rq = cpu_rq(cpu);
-               struct rq_flags rf;
+               u64 now;
 
-               rq_lock_irq(rq, &rf);
-               psi_group_change(group, cpu, 0, 0, true);
-               rq_unlock_irq(rq, &rf);
+               guard(rq_lock_irq)(cpu_rq(cpu));
+
+               psi_write_begin(cpu);
+               now = cpu_clock(cpu);
+               psi_group_change(group, cpu, 0, 0, now, true);
+               psi_write_end(cpu);
        }
 }
 #endif /* CONFIG_CGROUPS */