{
        unsigned long tmp;
        unsigned long m;
-       int i, k;
-       u64 base = 0;
-       int p = 0;
-       int skip;
-       int mask;
-       u64 len;
-       u64 pfn;
+       u64 base = ~0, p = 0;
+       u64 len, pfn;
+       int i = 0;
        struct scatterlist *sg;
        int entry;
        unsigned long page_shift = umem->page_shift;
        m = find_first_bit(&tmp, BITS_PER_LONG);
        if (max_page_shift)
                m = min_t(unsigned long, max_page_shift - page_shift, m);
-       skip = 1 << m;
-       mask = skip - 1;
-       i = 0;
+
        for_each_sg(umem->sg_head.sgl, sg, umem->nmap, entry) {
                len = sg_dma_len(sg) >> page_shift;
                pfn = sg_dma_address(sg) >> page_shift;
-               for (k = 0; k < len; k++) {
-                       if (!(i & mask)) {
-                               tmp = (unsigned long)pfn;
-                               m = min_t(unsigned long, m, find_first_bit(&tmp, BITS_PER_LONG));
-                               skip = 1 << m;
-                               mask = skip - 1;
-                               base = pfn;
-                               p = 0;
-                       } else {
-                               if (base + p != pfn) {
-                                       tmp = (unsigned long)p;
-                                       m = find_first_bit(&tmp, BITS_PER_LONG);
-                                       skip = 1 << m;
-                                       mask = skip - 1;
-                                       base = pfn;
-                                       p = 0;
-                               }
-                       }
-                       p++;
-                       i++;
+               if (base + p != pfn) {
+                       /* If either the offset or the new
+                        * base are unaligned update m
+                        */
+                       tmp = (unsigned long)(pfn | p);
+                       if (!IS_ALIGNED(tmp, 1 << m))
+                               m = find_first_bit(&tmp, BITS_PER_LONG);
+
+                       base = pfn;
+                       p = 0;
                }
+
+               p += len;
+               i += len;
        }
 
        if (i) {