sched_ext: Move dsq_hash into scx_sched
author    Tejun Heo <tj@kernel.org>
          Tue, 29 Apr 2025 18:40:10 +0000 (08:40 -1000)
committer Tejun Heo <tj@kernel.org>
          Tue, 29 Apr 2025 18:40:10 +0000 (08:40 -1000)
User DSQs are going to become per scheduler instance. Move dsq_hash into
scx_sched. This shifts the code that assumes scx_root to be the only
scx_sched instance up the call stack, but doesn't remove those assumptions yet.

v2: Add missing rcu_read_lock() in scx_bpf_destroy_dsq() as per Andrea.

Signed-off-by: Tejun Heo <tj@kernel.org>
Reviewed-by: Andrea Righi <arighi@nvidia.com>
Acked-by: Changwoo Min <changwoo@igalia.com>
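
In short: the file-scope static struct rhashtable dsq_hash becomes a dsq_hash
member of struct scx_sched, created when a scheduler instance is allocated and
torn down from its RCU free path. A condensed view of the resulting lifecycle,
lifted from the hunks below (not complete code):

	/* scx_alloc_and_add_sched(): create the per-instance table */
	ret = rhashtable_init(&sch->dsq_hash, &dsq_hash_params);
	if (ret < 0)
		goto err_free_ei;

	/* scx_sched_free_rcu_work(): destroy remaining DSQs, then the table */
	rhashtable_free_and_destroy(&sch->dsq_hash, NULL, NULL);

Tying the table's lifetime to the instance is what allows user DSQs to become
per scheduler instance later in the series.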
kernel/sched/ext.c

index 91b89cff52bf0af49bfd27010438eb274a2189c0..62ec2d1b195e1ea1f0dc5440b89023510f1c036d 100644
@@ -770,6 +770,7 @@ struct scx_sched {
        struct sched_ext_ops    ops;
        DECLARE_BITMAP(has_op, SCX_OPI_END);
 
+       struct rhashtable       dsq_hash;
        bool                    warned_zero_slice;
 
        atomic_t                exit_kind;
@@ -1016,7 +1017,6 @@ static const struct rhashtable_params dsq_hash_params = {
        .head_offset            = offsetof(struct scx_dispatch_q, hash_node),
 };
 
-static struct rhashtable dsq_hash;
 static LLIST_HEAD(dsqs_to_free);
 
 /* dispatch buf */
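
For context, dsq_hash_params (only its tail is visible above) follows the
usual rhashtable shape: the key is the DSQ's u64 id and the linkage is an
embedded rhash_head. A generic sketch of such a params block, with
hypothetical struct and field names:

	#include <linux/rhashtable.h>

	/* Hypothetical object keyed by a u64 id, mirroring scx_dispatch_q. */
	struct obj {
		u64			id;		/* lookup key */
		struct rhash_head	hash_node;	/* table linkage */
	};

	static const struct rhashtable_params obj_params = {
		.key_len	= sizeof_field(struct obj, id),
		.key_offset	= offsetof(struct obj, id),
		.head_offset	= offsetof(struct obj, hash_node),
	};

With this in place, rhashtable_lookup_fast(&ht, &id, obj_params) resolves an
id to its object under RCU, which is exactly what find_user_dsq() does below.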
@@ -1114,9 +1114,9 @@ static struct scx_dispatch_q *find_global_dsq(struct task_struct *p)
        return global_dsqs[cpu_to_node(task_cpu(p))];
 }
 
-static struct scx_dispatch_q *find_user_dsq(u64 dsq_id)
+static struct scx_dispatch_q *find_user_dsq(struct scx_sched *sch, u64 dsq_id)
 {
-       return rhashtable_lookup_fast(&dsq_hash, &dsq_id, dsq_hash_params);
+       return rhashtable_lookup_fast(&sch->dsq_hash, &dsq_id, dsq_hash_params);
 }
 
 /*
@@ -2059,7 +2059,8 @@ static void dispatch_dequeue(struct rq *rq, struct task_struct *p)
                raw_spin_unlock(&dsq->lock);
 }
 
-static struct scx_dispatch_q *find_dsq_for_dispatch(struct rq *rq, u64 dsq_id,
+static struct scx_dispatch_q *find_dsq_for_dispatch(struct scx_sched *sch,
+                                                   struct rq *rq, u64 dsq_id,
                                                    struct task_struct *p)
 {
        struct scx_dispatch_q *dsq;
@@ -2079,7 +2080,7 @@ static struct scx_dispatch_q *find_dsq_for_dispatch(struct rq *rq, u64 dsq_id,
        if (dsq_id == SCX_DSQ_GLOBAL)
                dsq = find_global_dsq(p);
        else
-               dsq = find_user_dsq(dsq_id);
+               dsq = find_user_dsq(sch, dsq_id);
 
        if (unlikely(!dsq)) {
                scx_error("non-existent DSQ 0x%llx for %s[%d]",
@@ -2120,11 +2121,12 @@ static void mark_direct_dispatch(struct task_struct *ddsp_task,
        p->scx.ddsp_enq_flags = enq_flags;
 }
 
-static void direct_dispatch(struct task_struct *p, u64 enq_flags)
+static void direct_dispatch(struct scx_sched *sch, struct task_struct *p,
+                           u64 enq_flags)
 {
        struct rq *rq = task_rq(p);
        struct scx_dispatch_q *dsq =
-               find_dsq_for_dispatch(rq, p->scx.ddsp_dsq_id, p);
+               find_dsq_for_dispatch(sch, rq, p->scx.ddsp_dsq_id, p);
 
        touch_core_sched_dispatch(rq, p);
 
@@ -2183,6 +2185,7 @@ static bool scx_rq_online(struct rq *rq)
 static void do_enqueue_task(struct rq *rq, struct task_struct *p, u64 enq_flags,
                            int sticky_cpu)
 {
+       struct scx_sched *sch = scx_root;
        struct task_struct **ddsp_taskp;
        unsigned long qseq;
 
@@ -2249,7 +2252,7 @@ static void do_enqueue_task(struct rq *rq, struct task_struct *p, u64 enq_flags,
        return;
 
 direct:
-       direct_dispatch(p, enq_flags);
+       direct_dispatch(sch, p, enq_flags);
        return;
 
 local:
@@ -2909,7 +2912,8 @@ static void dispatch_to_local_dsq(struct rq *rq, struct scx_dispatch_q *dst_dsq,
  * was valid in the first place. Make sure that the task is still owned by the
  * BPF scheduler and claim the ownership before dispatching.
  */
-static void finish_dispatch(struct rq *rq, struct task_struct *p,
+static void finish_dispatch(struct scx_sched *sch, struct rq *rq,
+                           struct task_struct *p,
                            unsigned long qseq_at_dispatch,
                            u64 dsq_id, u64 enq_flags)
 {
@@ -2962,7 +2966,7 @@ retry:
 
        BUG_ON(!(p->scx.flags & SCX_TASK_QUEUED));
 
-       dsq = find_dsq_for_dispatch(this_rq(), dsq_id, p);
+       dsq = find_dsq_for_dispatch(sch, this_rq(), dsq_id, p);
 
        if (dsq->id == SCX_DSQ_LOCAL)
                dispatch_to_local_dsq(rq, dsq, p, enq_flags);
@@ -2970,7 +2974,7 @@ retry:
                dispatch_enqueue(dsq, p, enq_flags | SCX_ENQ_CLEAR_OPSS);
 }
 
-static void flush_dispatch_buf(struct rq *rq)
+static void flush_dispatch_buf(struct scx_sched *sch, struct rq *rq)
 {
        struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx);
        u32 u;
@@ -2978,7 +2982,7 @@ static void flush_dispatch_buf(struct rq *rq)
        for (u = 0; u < dspc->cursor; u++) {
                struct scx_dsp_buf_ent *ent = &dspc->buf[u];
 
-               finish_dispatch(rq, ent->task, ent->qseq, ent->dsq_id,
+               finish_dispatch(sch, rq, ent->task, ent->qseq, ent->dsq_id,
                                ent->enq_flags);
        }
 
@@ -2988,6 +2992,7 @@ static void flush_dispatch_buf(struct rq *rq)
 
 static int balance_one(struct rq *rq, struct task_struct *prev)
 {
+       struct scx_sched *sch = scx_root;
        struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx);
        bool prev_on_scx = prev->sched_class == &ext_sched_class;
        bool prev_on_rq = prev->scx.flags & SCX_TASK_QUEUED;
@@ -3055,7 +3060,7 @@ static int balance_one(struct rq *rq, struct task_struct *prev)
                SCX_CALL_OP(SCX_KF_DISPATCH, dispatch, rq, cpu_of(rq),
                            prev_on_scx ? prev : NULL);
 
-               flush_dispatch_buf(rq);
+               flush_dispatch_buf(sch, rq);
 
                if (prev_on_rq && prev->scx.slice) {
                        rq->scx.flags |= SCX_RQ_BAL_KEEP;
@@ -3149,11 +3154,12 @@ static void process_ddsp_deferred_locals(struct rq *rq)
         */
        while ((p = list_first_entry_or_null(&rq->scx.ddsp_deferred_locals,
                                struct task_struct, scx.dsq_list.node))) {
+               struct scx_sched *sch = scx_root;
                struct scx_dispatch_q *dsq;
 
                list_del_init(&p->scx.dsq_list.node);
 
-               dsq = find_dsq_for_dispatch(rq, p->scx.ddsp_dsq_id, p);
+               dsq = find_dsq_for_dispatch(sch, rq, p->scx.ddsp_dsq_id, p);
                if (!WARN_ON_ONCE(dsq->id != SCX_DSQ_LOCAL))
                        dispatch_to_local_dsq(rq, dsq, p, p->scx.ddsp_enq_flags);
        }
@@ -4207,14 +4213,14 @@ static void free_dsq_irq_workfn(struct irq_work *irq_work)
 
 static DEFINE_IRQ_WORK(free_dsq_irq_work, free_dsq_irq_workfn);
 
-static void destroy_dsq(u64 dsq_id)
+static void destroy_dsq(struct scx_sched *sch, u64 dsq_id)
 {
        struct scx_dispatch_q *dsq;
        unsigned long flags;
 
        rcu_read_lock();
 
-       dsq = find_user_dsq(dsq_id);
+       dsq = find_user_dsq(sch, dsq_id);
        if (!dsq)
                goto out_unlock_rcu;
 
@@ -4226,7 +4232,8 @@ static void destroy_dsq(u64 dsq_id)
                goto out_unlock_dsq;
        }
 
-       if (rhashtable_remove_fast(&dsq_hash, &dsq->hash_node, dsq_hash_params))
+       if (rhashtable_remove_fast(&sch->dsq_hash, &dsq->hash_node,
+                                  dsq_hash_params))
                goto out_unlock_dsq;
 
        /*
@@ -4400,7 +4407,21 @@ static void scx_sched_free_rcu_work(struct work_struct *work)
 {
        struct rcu_work *rcu_work = to_rcu_work(work);
        struct scx_sched *sch = container_of(rcu_work, struct scx_sched, rcu_work);
+       struct rhashtable_iter rht_iter;
+       struct scx_dispatch_q *dsq;
+
+       rhashtable_walk_enter(&sch->dsq_hash, &rht_iter);
+       do {
+               rhashtable_walk_start(&rht_iter);
+
+               while ((dsq = rhashtable_walk_next(&rht_iter)) && !IS_ERR(dsq))
+                       destroy_dsq(sch, dsq->id);
+
+               rhashtable_walk_stop(&rht_iter);
+       } while (dsq == ERR_PTR(-EAGAIN));
+       rhashtable_walk_exit(&rht_iter);
 
+       rhashtable_free_and_destroy(&sch->dsq_hash, NULL, NULL);
        free_exit_info(sch->exit_info);
        kfree(sch);
 }
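
The free path takes over the destroy-all walk removed from
scx_disable_workfn() further down. rhashtable walks can return
ERR_PTR(-EAGAIN) if the table is resized mid-walk, so the loop restarts until
a pass completes cleanly. The pattern in isolation, with a hypothetical
visit() callback:

	/* Resize-safe walk over every entry of an rhashtable. */
	static void walk_all(struct rhashtable *ht, void (*visit)(void *obj))
	{
		struct rhashtable_iter iter;
		void *obj;

		rhashtable_walk_enter(ht, &iter);
		do {
			rhashtable_walk_start(&iter);

			while ((obj = rhashtable_walk_next(&iter)) && !IS_ERR(obj))
				visit(obj);

			rhashtable_walk_stop(&iter);
		} while (obj == ERR_PTR(-EAGAIN));
		rhashtable_walk_exit(&iter);
	}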
@@ -4704,8 +4725,6 @@ static void scx_disable_workfn(struct kthread_work *work)
        struct scx_exit_info *ei = sch->exit_info;
        struct scx_task_iter sti;
        struct task_struct *p;
-       struct rhashtable_iter rht_iter;
-       struct scx_dispatch_q *dsq;
        int kind, cpu;
 
        kind = atomic_read(&sch->exit_kind);
@@ -4837,17 +4856,6 @@ static void scx_disable_workfn(struct kthread_work *work)
         */
        kobject_del(&sch->kobj);
 
-       rhashtable_walk_enter(&dsq_hash, &rht_iter);
-       do {
-               rhashtable_walk_start(&rht_iter);
-
-               while ((dsq = rhashtable_walk_next(&rht_iter)) && !IS_ERR(dsq))
-                       destroy_dsq(dsq->id);
-
-               rhashtable_walk_stop(&rht_iter);
-       } while (dsq == ERR_PTR(-EAGAIN));
-       rhashtable_walk_exit(&rht_iter);
-
        free_percpu(scx_dsp_ctx);
        scx_dsp_ctx = NULL;
        scx_dsp_max_batch = 0;
@@ -5251,6 +5259,10 @@ static struct scx_sched *scx_alloc_and_add_sched(struct sched_ext_ops *ops)
                goto err_free_sch;
        }
 
+       ret = rhashtable_init(&sch->dsq_hash, &dsq_hash_params);
+       if (ret < 0)
+               goto err_free_ei;
+
        atomic_set(&sch->exit_kind, SCX_EXIT_NONE);
        sch->ops = *ops;
        ops->priv = sch;
@@ -5258,10 +5270,12 @@ static struct scx_sched *scx_alloc_and_add_sched(struct sched_ext_ops *ops)
        sch->kobj.kset = scx_kset;
        ret = kobject_init_and_add(&sch->kobj, &scx_ktype, NULL, "root");
        if (ret < 0)
-               goto err_free_ei;
+               goto err_free_hash;
 
        return sch;
 
+err_free_hash:
+       rhashtable_free_and_destroy(&sch->dsq_hash, NULL, NULL);
 err_free_ei:
        free_exit_info(sch->exit_info);
 err_free_sch:
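
Applied, the allocation path keeps the usual kernel goto-unwind shape: each
successfully acquired resource gets a label that releases it, and the labels
run in reverse order of acquisition. A condensed view of the resulting error
handling (surrounding allocations omitted):

	ret = rhashtable_init(&sch->dsq_hash, &dsq_hash_params);
	if (ret < 0)
		goto err_free_ei;
	...
	ret = kobject_init_and_add(&sch->kobj, &scx_ktype, NULL, "root");
	if (ret < 0)
		goto err_free_hash;

	return sch;

err_free_hash:
	rhashtable_free_and_destroy(&sch->dsq_hash, NULL, NULL);
err_free_ei:
	free_exit_info(sch->exit_info);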
@@ -6102,7 +6116,6 @@ void __init init_sched_ext_class(void)
        WRITE_ONCE(v, SCX_ENQ_WAKEUP | SCX_DEQ_SLEEP | SCX_KICK_PREEMPT |
                   SCX_TG_ONLINE);
 
-       BUG_ON(rhashtable_init(&dsq_hash, &dsq_hash_params));
        scx_idle_init_masks();
 
        scx_kick_cpus_pnt_seqs =
@@ -6303,6 +6316,7 @@ static const struct btf_kfunc_id_set scx_kfunc_set_enqueue_dispatch = {
 static bool scx_dsq_move(struct bpf_iter_scx_dsq_kern *kit,
                         struct task_struct *p, u64 dsq_id, u64 enq_flags)
 {
+       struct scx_sched *sch = scx_root;
        struct scx_dispatch_q *src_dsq = kit->dsq, *dst_dsq;
        struct rq *this_rq, *src_rq, *locked_rq;
        bool dispatched = false;
@@ -6355,7 +6369,7 @@ static bool scx_dsq_move(struct bpf_iter_scx_dsq_kern *kit,
        }
 
        /* @p is still on $src_dsq and stable, determine the destination */
-       dst_dsq = find_dsq_for_dispatch(this_rq, dsq_id, p);
+       dst_dsq = find_dsq_for_dispatch(sch, this_rq, dsq_id, p);
 
        /*
         * Apply vtime and slice updates before moving so that the new time is
@@ -6435,15 +6449,16 @@ __bpf_kfunc void scx_bpf_dispatch_cancel(void)
  */
 __bpf_kfunc bool scx_bpf_dsq_move_to_local(u64 dsq_id)
 {
+       struct scx_sched *sch = scx_root;
        struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx);
        struct scx_dispatch_q *dsq;
 
        if (!scx_kf_allowed(SCX_KF_DISPATCH))
                return false;
 
-       flush_dispatch_buf(dspc->rq);
+       flush_dispatch_buf(sch, dspc->rq);
 
-       dsq = find_user_dsq(dsq_id);
+       dsq = find_user_dsq(sch, dsq_id);
        if (unlikely(!dsq)) {
                scx_error("invalid DSQ ID 0x%016llx", dsq_id);
                return false;
@@ -6704,6 +6719,7 @@ __bpf_kfunc_start_defs();
 __bpf_kfunc s32 scx_bpf_create_dsq(u64 dsq_id, s32 node)
 {
        struct scx_dispatch_q *dsq;
+       struct scx_sched *sch;
        s32 ret;
 
        if (unlikely(node >= (int)nr_node_ids ||
@@ -6719,8 +6735,16 @@ __bpf_kfunc s32 scx_bpf_create_dsq(u64 dsq_id, s32 node)
 
        init_dsq(dsq, dsq_id);
 
-       ret = rhashtable_lookup_insert_fast(&dsq_hash, &dsq->hash_node,
-                                           dsq_hash_params);
+       rcu_read_lock();
+
+       sch = rcu_dereference(scx_root);
+       if (sch)
+               ret = rhashtable_lookup_insert_fast(&sch->dsq_hash, &dsq->hash_node,
+                                                   dsq_hash_params);
+       else
+               ret = -ENODEV;
+
+       rcu_read_unlock();
        if (ret)
                kfree(dsq);
        return ret;
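
The kfuncs now pin the root scheduler with RCU before touching its hash table
and fail with -ENODEV when no scheduler is loaded; per the v2 note,
scx_bpf_destroy_dsq() gains the same rcu_read_lock() coverage below. A minimal
sketch of the access pattern, assuming scx_root is an RCU-protected pointer;
with_root_sched() is a hypothetical helper, not part of the patch:

	static int with_root_sched(int (*op)(struct scx_sched *sch))
	{
		struct scx_sched *sch;
		int ret;

		rcu_read_lock();
		sch = rcu_dereference(scx_root);
		ret = sch ? op(sch) : -ENODEV;	/* no scheduler loaded */
		rcu_read_unlock();

		return ret;
	}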
@@ -6819,11 +6843,18 @@ out:
  */
 __bpf_kfunc s32 scx_bpf_dsq_nr_queued(u64 dsq_id)
 {
+       struct scx_sched *sch;
        struct scx_dispatch_q *dsq;
        s32 ret;
 
        preempt_disable();
 
+       sch = rcu_dereference_sched(scx_root);
+       if (!sch) {
+               ret = -ENODEV;
+               goto out;
+       }
+
        if (dsq_id == SCX_DSQ_LOCAL) {
                ret = READ_ONCE(this_rq()->scx.local_dsq.nr);
                goto out;
@@ -6835,7 +6866,7 @@ __bpf_kfunc s32 scx_bpf_dsq_nr_queued(u64 dsq_id)
                        goto out;
                }
        } else {
-               dsq = find_user_dsq(dsq_id);
+               dsq = find_user_dsq(sch, dsq_id);
                if (dsq) {
                        ret = READ_ONCE(dsq->nr);
                        goto out;
@@ -6858,7 +6889,13 @@ out:
  */
 __bpf_kfunc void scx_bpf_destroy_dsq(u64 dsq_id)
 {
-       destroy_dsq(dsq_id);
+       struct scx_sched *sch;
+
+       rcu_read_lock();
+       sch = rcu_dereference(scx_root);
+       if (sch)
+               destroy_dsq(sch, dsq_id);
+       rcu_read_unlock();
 }
 
 /**
@@ -6875,16 +6912,21 @@ __bpf_kfunc int bpf_iter_scx_dsq_new(struct bpf_iter_scx_dsq *it, u64 dsq_id,
                                     u64 flags)
 {
        struct bpf_iter_scx_dsq_kern *kit = (void *)it;
+       struct scx_sched *sch;
 
        BUILD_BUG_ON(sizeof(struct bpf_iter_scx_dsq_kern) >
                     sizeof(struct bpf_iter_scx_dsq));
        BUILD_BUG_ON(__alignof__(struct bpf_iter_scx_dsq_kern) !=
                     __alignof__(struct bpf_iter_scx_dsq));
 
+       sch = rcu_dereference(scx_root);
+       if (!sch)
+               return -ENODEV;
+
        if (flags & ~__SCX_DSQ_ITER_USER_FLAGS)
                return -EINVAL;
 
-       kit->dsq = find_user_dsq(dsq_id);
+       kit->dsq = find_user_dsq(sch, dsq_id);
        if (!kit->dsq)
                return -ENOENT;