extern void init_idle(struct task_struct *idle, int cpu);
 
 extern int sched_fork(unsigned long clone_flags, struct task_struct *p);
-extern void sched_post_fork(struct task_struct *p);
+extern void sched_post_fork(struct task_struct *p,
+                           struct kernel_clone_args *kargs);
 extern void sched_dead(struct task_struct *p);
 
 void __noreturn do_task_dead(void);
 
  */
 int sched_fork(unsigned long clone_flags, struct task_struct *p)
 {
-       unsigned long flags;
-
        __sched_fork(clone_flags, p);
        /*
         * We mark the process as NEW here. This guarantees that
 
        init_entity_runnable_average(&p->se);
 
-       /*
-        * The child is not yet in the pid-hash so no cgroup attach races,
-        * and the cgroup is pinned to this child due to cgroup_fork()
-        * is ran before sched_fork().
-        *
-        * Silence PROVE_RCU.
-        */
-       raw_spin_lock_irqsave(&p->pi_lock, flags);
-       rseq_migrate(p);
-       /*
-        * We're setting the CPU for the first time, we don't migrate,
-        * so use __set_task_cpu().
-        */
-       __set_task_cpu(p, smp_processor_id());
-       if (p->sched_class->task_fork)
-               p->sched_class->task_fork(p);
-       raw_spin_unlock_irqrestore(&p->pi_lock, flags);
-
 #ifdef CONFIG_SCHED_INFO
        if (likely(sched_info_on()))
                memset(&p->sched_info, 0, sizeof(p->sched_info));
        return 0;
 }
 
-void sched_post_fork(struct task_struct *p)
+void sched_post_fork(struct task_struct *p, struct kernel_clone_args *kargs)
 {
+       unsigned long flags;
+#ifdef CONFIG_CGROUP_SCHED
+       struct task_group *tg;
+#endif
+
+       raw_spin_lock_irqsave(&p->pi_lock, flags);
+#ifdef CONFIG_CGROUP_SCHED
+       tg = container_of(kargs->cset->subsys[cpu_cgrp_id],
+                         struct task_group, css);
+       p->sched_task_group = autogroup_task_group(p, tg);
+#endif
+       rseq_migrate(p);
+       /*
+        * We're setting the CPU for the first time, we don't migrate,
+        * so use __set_task_cpu().
+        */
+       __set_task_cpu(p, smp_processor_id());
+       if (p->sched_class->task_fork)
+               p->sched_class->task_fork(p);
+       raw_spin_unlock_irqrestore(&p->pi_lock, flags);
+
        uclamp_post_fork(p);
 }