#include <net/sock.h>
 #include <net/tcp.h>
 
-static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx,
-               struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE])
-{
-       u32 ret;
-
-       preempt_disable();
-       rcu_read_lock();
-       bpf_cgroup_storage_set(storage);
-       ret = BPF_PROG_RUN(prog, ctx);
-       rcu_read_unlock();
-       preempt_enable();
-
-       return ret;
-}
-
-static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, u32 *ret,
-                       u32 *time)
+static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat,
+                       u32 *retval, u32 *time)
 {
        struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = { 0 };
        enum bpf_cgroup_storage_type stype;
        u64 time_start, time_spent = 0;
+       int ret = 0;
        u32 i;
 
        for_each_cgroup_storage_type(stype) {
 
        if (!repeat)
                repeat = 1;
+
+       rcu_read_lock();
+       preempt_disable();
        time_start = ktime_get_ns();
        for (i = 0; i < repeat; i++) {
-               *ret = bpf_test_run_one(prog, ctx, storage);
+               bpf_cgroup_storage_set(storage);
+               *retval = BPF_PROG_RUN(prog, ctx);
+
+               if (signal_pending(current)) {
+                       ret = -EINTR;
+                       break;
+               }
+
                if (need_resched()) {
-                       if (signal_pending(current))
-                               break;
                        time_spent += ktime_get_ns() - time_start;
+                       preempt_enable();
+                       rcu_read_unlock();
+
                        cond_resched();
+
+                       rcu_read_lock();
+                       preempt_disable();
                        time_start = ktime_get_ns();
                }
        }
        time_spent += ktime_get_ns() - time_start;
+       preempt_enable();
+       rcu_read_unlock();
+
        do_div(time_spent, repeat);
        *time = time_spent > U32_MAX ? U32_MAX : (u32)time_spent;
 
        for_each_cgroup_storage_type(stype)
                bpf_cgroup_storage_free(storage[stype]);
 
-       return 0;
+       return ret;
 }
 
 static int bpf_test_finish(const union bpf_attr *kattr,