* a pointer, but small integers make for the smallest compare
  * instructions.
  */
-#define SWAP_WORDS_64 (swap_func_t)0
-#define SWAP_WORDS_32 (swap_func_t)1
-#define SWAP_BYTES    (swap_func_t)2
+#define SWAP_WORDS_64 (swap_r_func_t)0
+#define SWAP_WORDS_32 (swap_r_func_t)1
+#define SWAP_BYTES    (swap_r_func_t)2
+#define SWAP_WRAPPER  (swap_r_func_t)3
+
+struct wrapper {
+       cmp_func_t cmp;
+       swap_func_t swap;
+};
 
 /*
  * The function pointer is last to make tail calls most efficient if the
  * compiler decides not to inline this function.
  */
-static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func)
+static void do_swap(void *a, void *b, size_t size, swap_r_func_t swap_func, const void *priv)
 {
+       if (swap_func == SWAP_WRAPPER) {
+               ((const struct wrapper *)priv)->swap(a, b, (int)size);
+               return;
+       }
+
        if (swap_func == SWAP_WORDS_64)
                swap_words_64(a, b, size);
        else if (swap_func == SWAP_WORDS_32)
        else if (swap_func == SWAP_BYTES)
                swap_bytes(a, b, size);
        else
-               swap_func(a, b, (int)size);
+               swap_func(a, b, (int)size, priv);
 }
 
 #define _CMP_WRAPPER ((cmp_r_func_t)0L)
 static int do_cmp(const void *a, const void *b, cmp_r_func_t cmp, const void *priv)
 {
        if (cmp == _CMP_WRAPPER)
-               return ((cmp_func_t)(priv))(a, b);
+               return ((const struct wrapper *)priv)->cmp(a, b);
        return cmp(a, b, priv);
 }
 
  */
 void sort_r(void *base, size_t num, size_t size,
            cmp_r_func_t cmp_func,
-           swap_func_t swap_func,
+           swap_r_func_t swap_func,
            const void *priv)
 {
        /* pre-scale counters for performance */
        if (!a)         /* num < 2 || size == 0 */
                return;
 
+       /* called from 'sort' without swap function, let's pick the default */
+       if (swap_func == SWAP_WRAPPER && !((struct wrapper *)priv)->swap)
+               swap_func = NULL;
+
        if (!swap_func) {
                if (is_aligned(base, size, 8))
                        swap_func = SWAP_WORDS_64;
                if (a)                  /* Building heap: sift down --a */
                        a -= size;
                else if (n -= size)     /* Sorting: Extract root to --n */
-                       do_swap(base, base + n, size, swap_func);
+                       do_swap(base, base + n, size, swap_func, priv);
                else                    /* Sort complete */
                        break;
 
                c = b;                  /* Where "a" belongs */
                while (b != a) {        /* Shift it into place */
                        b = parent(b, lsbit, size);
-                       do_swap(base + b, base + c, size, swap_func);
+                       do_swap(base + b, base + c, size, swap_func, priv);
                }
        }
 }
          cmp_func_t cmp_func,
          swap_func_t swap_func)
 {
-       return sort_r(base, num, size, _CMP_WRAPPER, swap_func, cmp_func);
+       struct wrapper w = {
+               .cmp  = cmp_func,
+               .swap = swap_func,
+       };
+
+       return sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w);
 }
 EXPORT_SYMBOL(sort);