atomic_t udp_memory_allocated;
 EXPORT_SYMBOL(udp_memory_allocated);
 
+#define PORTS_PER_CHAIN (65536 / UDP_HTABLE_SIZE)
+
 static int udp_lib_lport_inuse(struct net *net, __u16 num,
                               const struct udp_hslot *hslot,
+                              unsigned long *bitmap,
                               struct sock *sk,
                               int (*saddr_comp)(const struct sock *sk1,
                                                 const struct sock *sk2))
        sk_nulls_for_each(sk2, node, &hslot->head)
                if (net_eq(sock_net(sk2), net)                  &&
                    sk2 != sk                                   &&
-                   sk2->sk_hash == num                         &&
+                   (bitmap || sk2->sk_hash == num)             &&
                    (!sk2->sk_reuse || !sk->sk_reuse)           &&
                    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if
                        || sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
-                   (*saddr_comp)(sk, sk2))
-                       return 1;
+                   (*saddr_comp)(sk, sk2)) {
+                       if (bitmap)
+                               __set_bit(sk2->sk_hash / UDP_HTABLE_SIZE,
+                                         bitmap);
+                       else
+                               return 1;
+               }
        return 0;
 }
 
        if (!snum) {
                int low, high, remaining;
                unsigned rand;
-               unsigned short first;
+               unsigned short first, last;
+               DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN);
 
                inet_get_local_port_range(&low, &high);
                remaining = (high - low) + 1;
 
                rand = net_random();
-               snum = first = rand % remaining + low;
-               rand |= 1;
-               for (;;) {
-                       hslot = &udptable->hash[udp_hashfn(net, snum)];
+               first = (((u64)rand * remaining) >> 32) + low;
+               /*
+                * force rand to be an odd multiple of UDP_HTABLE_SIZE
+                */
+               rand = (rand | 1) * UDP_HTABLE_SIZE;
+               for (last = first + UDP_HTABLE_SIZE; first != last; first++) {
+                       hslot = &udptable->hash[udp_hashfn(net, first)];
+                       bitmap_zero(bitmap, PORTS_PER_CHAIN);
                        spin_lock_bh(&hslot->lock);
-                       if (!udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp))
-                               break;
-                       spin_unlock_bh(&hslot->lock);
+                       udp_lib_lport_inuse(net, snum, hslot, bitmap, sk,
+                                           saddr_comp);
+
+                       snum = first;
+                       /*
+                        * Iterate on all possible values of snum for this hash.
+                        * Using steps of an odd multiple of UDP_HTABLE_SIZE
+                        * give us randomization and full range coverage.
+                        */
                        do {
-                               snum = snum + rand;
-                       } while (snum < low || snum > high);
-                       if (snum == first)
-                               goto fail;
+                               if (low <= snum && snum <= high &&
+                                   !test_bit(snum / UDP_HTABLE_SIZE, bitmap))
+                                       goto found;
+                               snum += rand;
+                       } while (snum != first);
+                       spin_unlock_bh(&hslot->lock);
                }
+               goto fail;
        } else {
                hslot = &udptable->hash[udp_hashfn(net, snum)];
                spin_lock_bh(&hslot->lock);
-               if (udp_lib_lport_inuse(net, snum, hslot, sk, saddr_comp))
+               if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, saddr_comp))
                        goto fail_unlock;
        }
+found:
        inet_sk(sk)->num = snum;
        sk->sk_hash = snum;
        if (sk_unhashed(sk)) {