        lockdep_assert_rq_held(rq);
        rq->scx.flags |= SCX_RQ_IN_BALANCE;
+       rq->scx.flags &= ~SCX_RQ_BAL_KEEP;
 
        if (static_branch_unlikely(&scx_ops_cpu_preempt) &&
            unlikely(rq->scx.cpu_released)) {
        }
 
        if (prev_on_scx) {
-               WARN_ON_ONCE(local && (prev->scx.flags & SCX_TASK_BAL_KEEP));
                update_curr_scx(rq);
 
                /*
                 *
                 * When balancing a remote CPU for core-sched, there won't be a
                 * following put_prev_task_scx() call and we don't own
-                * %SCX_TASK_BAL_KEEP. Instead, pick_task_scx() will test the
-                * same conditions later and pick @rq->curr accordingly.
+                * %SCX_RQ_BAL_KEEP. Instead, pick_task_scx() will test the same
+                * conditions later and pick @rq->curr accordingly.
                 */
                if ((prev->scx.flags & SCX_TASK_QUEUED) &&
                    prev->scx.slice && !scx_ops_bypassing()) {
                        if (local)
-                               prev->scx.flags |= SCX_TASK_BAL_KEEP;
+                               rq->scx.flags |= SCX_RQ_BAL_KEEP;
                        goto has_tasks;
                }
        }
        if ((prev->scx.flags & SCX_TASK_QUEUED) &&
            (!static_branch_unlikely(&scx_ops_enq_last) || scx_ops_bypassing())) {
                if (local)
-                       prev->scx.flags |= SCX_TASK_BAL_KEEP;
+                       rq->scx.flags |= SCX_RQ_BAL_KEEP;
                goto has_tasks;
        }
        rq->scx.flags &= ~SCX_RQ_IN_BALANCE;
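
The hunks above move the "keep running @prev" decision from the per-task
%SCX_TASK_BAL_KEEP flag to the rq-level %SCX_RQ_BAL_KEEP flag: it is cleared
when balancing starts, set (only for the local balance) while @prev is still
queued with slice left, and consumed later when the next task is picked. The
stand-alone sketch below is a minimal model of that order of operations only;
the sketch_* types, macros and helpers are illustrative stand-ins, not the
kernel code, and the bypass/enq_last special cases are omitted.

#include <stdbool.h>
#include <stdio.h>

#define SKETCH_RQ_IN_BALANCE    (1 << 0)  /* stand-in for SCX_RQ_IN_BALANCE */
#define SKETCH_RQ_BAL_KEEP      (1 << 1)  /* stand-in for SCX_RQ_BAL_KEEP */

struct sketch_task {
        bool queued;            /* stand-in for SCX_TASK_QUEUED */
        unsigned long slice;    /* remaining slice */
};

struct sketch_rq {
        unsigned int flags;
};

/* Balance step: the keep decision now lives on the rq, not on the task. */
static bool sketch_balance(struct sketch_rq *rq, struct sketch_task *prev,
                           bool local)
{
        rq->flags |= SKETCH_RQ_IN_BALANCE;
        rq->flags &= ~SKETCH_RQ_BAL_KEEP;

        if (prev->queued && prev->slice) {
                /* only the local balance owns the flag; remote CPUs don't */
                if (local)
                        rq->flags |= SKETCH_RQ_BAL_KEEP;
                rq->flags &= ~SKETCH_RQ_IN_BALANCE;
                return true;
        }

        rq->flags &= ~SKETCH_RQ_IN_BALANCE;
        return false;
}

/* Pick step: honor the rq-level flag instead of a per-task one. */
static struct sketch_task *sketch_pick(struct sketch_rq *rq,
                                       struct sketch_task *prev)
{
        if (rq->flags & SKETCH_RQ_BAL_KEEP)
                return prev;
        return NULL;    /* would pop the first task from the local DSQ */
}

int main(void)
{
        struct sketch_task prev = { .queued = true, .slice = 20000 };
        struct sketch_rq rq = { .flags = 0 };
        bool has_tasks = sketch_balance(&rq, &prev, true);

        printf("has_tasks: %s, kept prev: %s\n",
               has_tasks ? "yes" : "no",
               sketch_pick(&rq, &prev) == &prev ? "yes" : "no");
        return 0;
}

Because the flag now belongs to the rq, it no longer implies anything about
@prev itself by the time it is consumed, which is presumably why the pick path
gains the WARN_ON_ONCE(prev->sched_class != &ext_sched_class) check further
down.
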
                SCX_CALL_OP_TASK(SCX_KF_REST, stopping, p, true);
 
        if (p->scx.flags & SCX_TASK_QUEUED) {
-               p->scx.flags &= ~SCX_TASK_BAL_KEEP;
-
                set_task_runnable(rq, p);
 
                /*
         * if necessary and keep running @prev. Otherwise, pop the first one
         * from the local DSQ.
         */
-       if (prev->scx.flags & SCX_TASK_BAL_KEEP) {
-               prev->scx.flags &= ~SCX_TASK_BAL_KEEP;
+       if ((rq->scx.flags & SCX_RQ_BAL_KEEP) &&
+           !WARN_ON_ONCE(prev->sched_class != &ext_sched_class)) {
                p = prev;
                if (!p->scx.slice)
                        p->scx.slice = SCX_SLICE_DFL;
  *
  * As put_prev_task_scx() hasn't been called on remote CPUs, we can't just look
  * at the first task in the local dsq. @rq->curr has to be considered explicitly
- * to mimic %SCX_TASK_BAL_KEEP.
+ * to mimic %SCX_RQ_BAL_KEEP.
  */
 static struct task_struct *pick_task_scx(struct rq *rq)
 {
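
As the comment above says, on a remote CPU no put_prev_task_scx() call has run
and that CPU's balance never set %SCX_RQ_BAL_KEEP, so @rq->curr has to be
weighed against the head of the local DSQ explicitly. The self-contained
sketch below models only that comparison, assuming demo_* stand-in types, a
first-local pointer, and a core_sched_at-style FIFO tie-break; it is an
illustration, not the kernel implementation.

#include <stdbool.h>
#include <stdio.h>

/* Simplified stand-ins for the state the comment refers to. */
struct demo_task {
        bool queued;              /* models SCX_TASK_QUEUED */
        unsigned long slice;      /* remaining slice */
        unsigned long long seq;   /* models core_sched_at-style FIFO order */
};

struct demo_rq {
        struct demo_task *curr;        /* models @rq->curr */
        struct demo_task *first_local; /* head of the local DSQ, may be NULL */
};

/*
 * Mimic %SCX_RQ_BAL_KEEP without a preceding put_prev step: prefer the
 * running task while it is still queued, has slice left, and is not behind
 * the first local-DSQ task in FIFO order.
 */
static struct demo_task *demo_pick(struct demo_rq *rq)
{
        struct demo_task *curr = rq->curr;
        struct demo_task *first = rq->first_local;

        if (curr && curr->queued) {
                if (!first)
                        return curr;
                if (curr->slice && curr->seq <= first->seq)
                        return curr;
        }
        return first;
}

int main(void)
{
        struct demo_task curr  = { .queued = true, .slice = 10000, .seq = 5 };
        struct demo_task first = { .queued = true, .slice = 20000, .seq = 7 };
        struct demo_rq rq = { .curr = &curr, .first_local = &first };

        printf("picked curr: %s\n", demo_pick(&rq) == &curr ? "yes" : "no");
        return 0;
}
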
  *
  * b. ops.dispatch() is ignored.
  *
- * c. balance_scx() does not set %SCX_TASK_BAL_KEEP on non-zero slice as slice
+ * c. balance_scx() does not set %SCX_RQ_BAL_KEEP on non-zero slice as slice
  *    can't be trusted. Whenever a tick triggers, the running task is rotated to
  *    the tail of the queue with core_sched_at touched.
  *