}
 
 /* You must call memblock_analyze() after this. */
-void __init memblock_enforce_memory_limit(phys_addr_t memory_limit)
+void __init memblock_enforce_memory_limit(phys_addr_t limit)
 {
        unsigned long i;
-       phys_addr_t limit;
-       struct memblock_region *p;
+       phys_addr_t max_addr = (phys_addr_t)ULLONG_MAX;
 
-       if (!memory_limit)
+       if (!limit)
                return;
 
-       /* Truncate the memblock regions to satisfy the memory limit. */
-       limit = memory_limit;
+       /* find out max address */
        for (i = 0; i < memblock.memory.cnt; i++) {
-               if (limit > memblock.memory.regions[i].size) {
-                       limit -= memblock.memory.regions[i].size;
-                       continue;
-               }
-
-               memblock.memory.regions[i].size = limit;
-               memblock.memory.cnt = i + 1;
-               break;
-       }
-
-       memory_limit = memblock_end_of_DRAM();
+               struct memblock_region *r = &memblock.memory.regions[i];
 
-       /* And truncate any reserves above the limit also. */
-       for (i = 0; i < memblock.reserved.cnt; i++) {
-               p = &memblock.reserved.regions[i];
-
-               if (p->base > memory_limit)
-                       p->size = 0;
-               else if ((p->base + p->size) > memory_limit)
-                       p->size = memory_limit - p->base;
-
-               if (p->size == 0) {
-                       memblock_remove_region(&memblock.reserved, i);
-                       i--;
+               if (limit <= r->size) {
+                       max_addr = r->base + limit;
+                       break;
                }
+               limit -= r->size;
        }
+
+       /* truncate both memory and reserved regions */
+       __memblock_remove(&memblock.memory, max_addr, (phys_addr_t)ULLONG_MAX);
+       __memblock_remove(&memblock.reserved, max_addr, (phys_addr_t)ULLONG_MAX);
 }
 
 static int __init_memblock memblock_search(struct memblock_type *type, phys_addr_t addr)