wake_up_interruptible(&group->poll_wait);
 }
 
-static void record_times(struct psi_group_cpu *groupc, int cpu,
-                        bool memstall_tick)
+static void record_times(struct psi_group_cpu *groupc, int cpu)
 {
        u32 delta;
        u64 now;
                groupc->times[PSI_MEM_SOME] += delta;
                if (groupc->state_mask & (1 << PSI_MEM_FULL))
                        groupc->times[PSI_MEM_FULL] += delta;
-               else if (memstall_tick) {
-                       u32 sample;
-                       /*
-                        * Since we care about lost potential, a
-                        * memstall is FULL when there are no other
-                        * working tasks, but also when the CPU is
-                        * actively reclaiming and nothing productive
-                        * could run even if it were runnable.
-                        *
-                        * When the timer tick sees a reclaiming CPU,
-                        * regardless of runnable tasks, sample a FULL
-                        * tick (or less if it hasn't been a full tick
-                        * since the last state change).
-                        */
-                       sample = min(delta, (u32)jiffies_to_nsecs(1));
-                       groupc->times[PSI_MEM_FULL] += sample;
-               }
        }
 
        if (groupc->state_mask & (1 << PSI_CPU_SOME)) {
         */
        write_seqcount_begin(&groupc->seq);
 
-       record_times(groupc, cpu, false);
+       record_times(groupc, cpu);
 
        for (t = 0, m = clear; m; m &= ~(1 << t), t++) {
                if (!(m & (1 << t)))
                if (test_state(groupc->tasks, s))
                        state_mask |= (1 << s);
        }
+
+       /*
+        * Since we care about lost potential, a memstall is FULL
+        * when there are no other working tasks, but also when
+        * the CPU is actively reclaiming and nothing productive
+        * could run even if it were runnable. So when the current
+        * task in a cgroup is in_memstall, the corresponding groupc
+        * on that cpu is in PSI_MEM_FULL state.
+        */
+       if (groupc->tasks[NR_ONCPU] && cpu_curr(cpu)->in_memstall)
+               state_mask |= (1 << PSI_MEM_FULL);
+
        groupc->state_mask = state_mask;
 
        write_seqcount_end(&groupc->seq);
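
With this change, PSI no longer needs the timer tick to sample FULL time on a reclaiming CPU: as soon as the task on the CPU is in_memstall, PSI_MEM_FULL is part of state_mask, and the next record_times() folds the entire elapsed delta into times[PSI_MEM_FULL]. That is what makes the per-tick sampling path (and psi_memstall_tick below) removable. The following user-space sketch only illustrates that accounting model; it is not kernel code, and the names group_cpu, record_delta and on_state_change are invented for the example.

#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>

enum { MEM_SOME, MEM_FULL, NR_STATES };

struct group_cpu {
	uint64_t times[NR_STATES];  /* accrued stall time, ns */
	uint64_t state_start;       /* timestamp of last state change, ns */
	unsigned int state_mask;    /* one bit per currently active state */
};

/* Fold the time since the last state change into every active state. */
static void record_delta(struct group_cpu *gc, uint64_t now)
{
	uint64_t delta = now - gc->state_start;

	if (gc->state_mask & (1 << MEM_SOME))
		gc->times[MEM_SOME] += delta;
	if (gc->state_mask & (1 << MEM_FULL))
		gc->times[MEM_FULL] += delta;
	gc->state_start = now;
}

/*
 * State-based scheme: MEM_FULL goes up the moment the task running on
 * the CPU enters reclaim and stays up until the next state change, so
 * the whole delta lands in times[MEM_FULL].  The removed tick-based
 * scheme instead sampled at most one jiffy of FULL time per timer tick.
 */
static void on_state_change(struct group_cpu *gc, uint64_t now,
			    bool memstall, bool oncpu_task_in_memstall)
{
	record_delta(gc, now);
	gc->state_mask = 0;
	if (memstall)
		gc->state_mask |= 1 << MEM_SOME;
	if (oncpu_task_in_memstall)	/* implies memstall */
		gc->state_mask |= 1 << MEM_FULL;
}

int main(void)
{
	struct group_cpu gc = { .state_start = 0 };

	/* t=0: the running task starts reclaiming. */
	on_state_change(&gc, 0, true, true);
	/* t=3ms: it stops; all 3ms count as FULL, tick length irrelevant. */
	on_state_change(&gc, 3000000, false, false);

	printf("SOME=%llu ns FULL=%llu ns\n",
	       (unsigned long long)gc.times[MEM_SOME],
	       (unsigned long long)gc.times[MEM_FULL]);
	return 0;
}
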
        void *iter;
 
        if (next->pid) {
+               bool identical_state;
+
                psi_flags_change(next, 0, TSK_ONCPU);
                /*
-                * When moving state between tasks, the group that
-                * contains them both does not change: we can stop
-                * updating the tree once we reach the first common
-                * ancestor. Iterate @next's ancestors until we
-                * encounter @prev's state.
+                * When switching between tasks that have an identical
+                * runtime state, the cgroup that contains both tasks
+                * does not change: we can stop updating the tree once
+                * we reach the first common ancestor. Iterate @next's
+                * ancestors only until we encounter @prev's ONCPU.
                 */
+               identical_state = prev->psi_flags == next->psi_flags;
                iter = NULL;
                while ((group = iterate_groups(next, &iter))) {
-                       if (per_cpu_ptr(group->pcpu, cpu)->tasks[NR_ONCPU]) {
+                       if (identical_state &&
+                           per_cpu_ptr(group->pcpu, cpu)->tasks[NR_ONCPU]) {
                                common = group;
                                break;
                        }
        }
 }
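
The early break above relies on the fact that every ancestor shared by @prev and @next loses one ONCPU task and gains one during the switch, so its counts, and, when the two tasks' PSI flags are identical, all of its state bits, come out unchanged. A standalone sketch of that walk, using a made-up struct cgroup_node and switch_oncpu() in place of the real iterate_groups()/psi_group_change() machinery, might look like this:

#include <stdio.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical, simplified cgroup node: only the per-CPU ONCPU count. */
struct cgroup_node {
	struct cgroup_node *parent;
	unsigned int nr_oncpu;	/* tasks of this group running on this CPU */
};

/*
 * On a context switch, set ONCPU up @next's ancestry and clear it up
 * @prev's.  When the two tasks carry identical state, the walk up from
 * @next can stop at the first node that already has an ONCPU task (set
 * there by @prev): that node and everything above it are common
 * ancestors whose counts do not change.
 */
static void switch_oncpu(struct cgroup_node *prev_leaf,
			 struct cgroup_node *next_leaf,
			 bool identical_state)
{
	struct cgroup_node *common = NULL;
	struct cgroup_node *g;

	for (g = next_leaf; g; g = g->parent) {
		if (identical_state && g->nr_oncpu) {
			common = g;	/* first common ancestor */
			break;
		}
		g->nr_oncpu++;
	}

	/* Only @prev's chain below the common ancestor actually changes. */
	for (g = prev_leaf; g && g != common; g = g->parent)
		g->nr_oncpu--;
}

int main(void)
{
	struct cgroup_node root = { NULL, 0 };
	struct cgroup_node shared = { &root, 0 };
	struct cgroup_node prev_only = { &shared, 0 };
	struct cgroup_node next_only = { &shared, 0 };

	/* @prev was running: its whole ancestry has ONCPU set. */
	prev_only.nr_oncpu = shared.nr_oncpu = root.nr_oncpu = 1;

	switch_oncpu(&prev_only, &next_only, true);
	printf("next=%u shared=%u root=%u prev=%u\n",
	       next_only.nr_oncpu, shared.nr_oncpu,
	       root.nr_oncpu, prev_only.nr_oncpu);
	/*
	 * Prints next=1 shared=1 root=1 prev=0: the walk stopped at
	 * "shared", the first common ancestor, and never touched "root".
	 */
	return 0;
}
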
 
-void psi_memstall_tick(struct task_struct *task, int cpu)
-{
-       struct psi_group *group;
-       void *iter = NULL;
-
-       while ((group = iterate_groups(task, &iter))) {
-               struct psi_group_cpu *groupc;
-
-               groupc = per_cpu_ptr(group->pcpu, cpu);
-               write_seqcount_begin(&groupc->seq);
-               record_times(groupc, cpu, true);
-               write_seqcount_end(&groupc->seq);
-       }
-}
-
 /**
  * psi_memstall_enter - mark the beginning of a memory stall section
  * @flags: flags to handle nested sections