Currently we traverse all symbols of all modules to find the specified
function for the specified module. But in reality, we just need to find
the given module and then traverse all the symbols in it.
Let's add a new parameter 'const char *modname' to function
module_kallsyms_on_each_symbol(), then we can compare the module names
directly in this function and call hook 'fn' after matching. If 'modname'
is NULL, the symbols of all modules are still traversed for compatibility
with other usage cases.
Phase1: mod1-->mod2..(subsequent modules do not need to be compared)
                |
Phase2:          -->f1-->f2-->f3
Assuming that there are m modules, each module has n symbols on average,
then the time complexity is reduced from O(m * n) to O(m) + O(n).
Reviewed-by: Petr Mladek <pmladek@suse.com>
Acked-by: Song Liu <song@kernel.org>
Signed-off-by: Zhen Lei <thunder.leizhen@huawei.com>
Signed-off-by: Jiri Olsa <jolsa@kernel.org>
Acked-by: Miroslav Benes <mbenes@suse.cz>
Reviewed-by: Luis Chamberlain <mcgrof@kernel.org>
Link: https://lore.kernel.org/r/20230116101009.23694-2-jolsa@kernel.org
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
 #endif /* CONFIG_MODULE_SIG */
 
 #if defined(CONFIG_MODULES) && defined(CONFIG_KALLSYMS)
-int module_kallsyms_on_each_symbol(int (*fn)(void *, const char *,
+int module_kallsyms_on_each_symbol(const char *modname,
+                                  int (*fn)(void *, const char *,
                                             struct module *, unsigned long),
                                   void *data);
 #else
-static inline int module_kallsyms_on_each_symbol(int (*fn)(void *, const char *,
+static inline int module_kallsyms_on_each_symbol(const char *modname,
+                                                int (*fn)(void *, const char *,
                                                 struct module *, unsigned long),
                                                 void *data)
 {
 
 }
 
 struct klp_find_arg {
-       const char *objname;
        const char *name;
        unsigned long addr;
        unsigned long count;
 {
        struct klp_find_arg *args = data;
 
-       if ((mod && !args->objname) || (!mod && args->objname))
-               return 0;
-
        if (strcmp(args->name, name))
                return 0;
 
-       if (args->objname && strcmp(args->objname, mod->name))
-               return 0;
-
        return klp_match_callback(data, addr);
 }
 
                                  unsigned long sympos, unsigned long *addr)
 {
        struct klp_find_arg args = {
-               .objname = objname,
                .name = name,
                .addr = 0,
                .count = 0,
        };
 
        if (objname)
-               module_kallsyms_on_each_symbol(klp_find_callback, &args);
+               module_kallsyms_on_each_symbol(objname, klp_find_callback, &args);
        else
                kallsyms_on_each_match_symbol(klp_match_callback, name, &args);
 
 
        return ret;
 }
 
-int module_kallsyms_on_each_symbol(int (*fn)(void *, const char *,
+int module_kallsyms_on_each_symbol(const char *modname,
+                                  int (*fn)(void *, const char *,
                                             struct module *, unsigned long),
                                   void *data)
 {
                if (mod->state == MODULE_STATE_UNFORMED)
                        continue;
 
+               if (modname && strcmp(modname, mod->name))
+                       continue;
+
                /* Use rcu_dereference_sched() to remain compliant with the sparse tool */
                preempt_disable();
                kallsyms = rcu_dereference_sched(mod->kallsyms);
                        if (ret != 0)
                                goto out;
                }
+
+               /*
+                * The given module is found, the subsequent modules do not
+                * need to be compared.
+                */
+               if (modname)
+                       break;
        }
 out:
        mutex_unlock(&module_mutex);
 
        int err;
 
        /* We return either err < 0 in case of error, ... */
-       err = module_kallsyms_on_each_symbol(module_callback, &args);
+       err = module_kallsyms_on_each_symbol(NULL, module_callback, &args);
        if (err) {
                kprobe_multi_put_modules(args.mods, args.mods_cnt);
                kfree(args.mods);
 
        found_all = kallsyms_on_each_symbol(kallsyms_callback, &args);
        if (found_all)
                return 0;
-       found_all = module_kallsyms_on_each_symbol(kallsyms_callback, &args);
+       found_all = module_kallsyms_on_each_symbol(NULL, kallsyms_callback, &args);
        return found_all ? 0 : -ESRCH;
 }