#include <linux/sched/mm.h>
 #include <linux/proc_ns.h>
 #include <linux/mount.h>
+#include <linux/min_heap.h>
 
 #include "internal.h"
 
        ctx_sched_out(&cpuctx->ctx, cpuctx, event_type);
 }
 
-static int visit_groups_merge(struct perf_event_groups *groups, int cpu,
-                             int (*func)(struct perf_event *, void *), void *data)
+static bool perf_less_group_idx(const void *l, const void *r)
 {
-       struct perf_event **evt, *evt1, *evt2;
+       const struct perf_event *le = l, *re = r;
+
+       return le->group_index < re->group_index;
+}
+
+static void swap_ptr(void *l, void *r)
+{
+       void **lp = l, **rp = r;
+
+       swap(*lp, *rp);
+}
+
+static const struct min_heap_callbacks perf_min_heap = {
+       .elem_size = sizeof(struct perf_event *),
+       .less = perf_less_group_idx,
+       .swp = swap_ptr,
+};
+
+static void __heap_add(struct min_heap *heap, struct perf_event *event)
+{
+       struct perf_event **itrs = heap->data;
+
+       if (event) {
+               itrs[heap->nr] = event;
+               heap->nr++;
+       }
+}
+
+static noinline int visit_groups_merge(struct perf_event_groups *groups,
+                               int cpu,
+                               int (*func)(struct perf_event *, void *),
+                               void *data)
+{
+       /* Space for per CPU and/or any CPU event iterators. */
+       struct perf_event *itrs[2];
+       struct min_heap event_heap = {
+               .data = itrs,
+               .nr = 0,
+               .size = ARRAY_SIZE(itrs),
+       };
+       struct perf_event **evt = event_heap.data;
        int ret;
 
-       evt1 = perf_event_groups_first(groups, -1);
-       evt2 = perf_event_groups_first(groups, cpu);
+       __heap_add(&event_heap, perf_event_groups_first(groups, -1));
+       __heap_add(&event_heap, perf_event_groups_first(groups, cpu));
 
-       while (evt1 || evt2) {
-               if (evt1 && evt2) {
-                       if (evt1->group_index < evt2->group_index)
-                               evt = &evt1;
-                       else
-                               evt = &evt2;
-               } else if (evt1) {
-                       evt = &evt1;
-               } else {
-                       evt = &evt2;
-               }
+       min_heapify_all(&event_heap, &perf_min_heap);
 
+       while (event_heap.nr) {
                ret = func(*evt, data);
                if (ret)
                        return ret;
 
                *evt = perf_event_groups_next(*evt);
+               if (*evt)
+                       min_heapify(&event_heap, 0, &perf_min_heap);
+               else
+                       min_heap_pop(&event_heap, &perf_min_heap);
        }
 
        return 0;