#define MLX5_U64_4K_PAGE_MASK ((~(u64)0U) << PAGE_SHIFT)
 
-static void free_fwp(struct mlx5_core_dev *dev, struct fw_page *fwp)
+static void free_fwp(struct mlx5_core_dev *dev, struct fw_page *fwp,
+                    bool in_free_list)
 {
-       int n = (fwp->addr & ~MLX5_U64_4K_PAGE_MASK) >> MLX5_ADAPTER_PAGE_SHIFT;
-
-       fwp->free_count++;
-       set_bit(n, &fwp->bitmask);
-       if (fwp->free_count == MLX5_NUM_4K_IN_PAGE) {
-               rb_erase(&fwp->rb_node, &dev->priv.page_root);
-               if (fwp->free_count != 1)
-                       list_del(&fwp->list);
-               dma_unmap_page(dev->device, fwp->addr & MLX5_U64_4K_PAGE_MASK,
-                              PAGE_SIZE, DMA_BIDIRECTIONAL);
-               __free_page(fwp->page);
-               kfree(fwp);
-       } else if (fwp->free_count == 1) {
-               list_add(&fwp->list, &dev->priv.free_list);
-       }
+       rb_erase(&fwp->rb_node, &dev->priv.page_root);
+       if (in_free_list)
+               list_del(&fwp->list);
+       dma_unmap_page(dev->device, fwp->addr & MLX5_U64_4K_PAGE_MASK,
+                      PAGE_SIZE, DMA_BIDIRECTIONAL);
+       __free_page(fwp->page);
+       kfree(fwp);
 }
 
-static void free_addr(struct mlx5_core_dev *dev, u64 addr)
+static void free_4k(struct mlx5_core_dev *dev, u64 addr)
 {
        struct fw_page *fwp;
+       int n;
 
        fwp = find_fw_page(dev, addr & MLX5_U64_4K_PAGE_MASK);
        if (!fwp) {
                mlx5_core_warn_rl(dev, "page not found\n");
                return;
        }
-       free_fwp(dev, fwp);
+       n = (addr & ~MLX5_U64_4K_PAGE_MASK) >> MLX5_ADAPTER_PAGE_SHIFT;
+       fwp->free_count++;
+       set_bit(n, &fwp->bitmask);
+       if (fwp->free_count == MLX5_NUM_4K_IN_PAGE)
+               free_fwp(dev, fwp, fwp->free_count != 1);
+       else if (fwp->free_count == 1)
+               list_add(&fwp->list, &dev->priv.free_list);
 }
 
 static int alloc_system_page(struct mlx5_core_dev *dev, u16 func_id)
 
 out_4k:
        for (i--; i >= 0; i--)
-               free_addr(dev, MLX5_GET64(manage_pages_in, in, pas[i]));
+               free_4k(dev, MLX5_GET64(manage_pages_in, in, pas[i]));
 out_free:
        kvfree(in);
        if (notify_fail)
                p = rb_next(p);
                if (fwp->func_id != func_id)
                        continue;
-               free_fwp(dev, fwp);
-               npages++;
+               npages += (MLX5_NUM_4K_IN_PAGE - fwp->free_count);
+               free_fwp(dev, fwp, fwp->free_count);
        }
 
        dev->priv.fw_pages -= npages;
        }
 
        for (i = 0; i < num_claimed; i++)
-               free_addr(dev, MLX5_GET64(manage_pages_out, out, pas[i]));
+               free_4k(dev, MLX5_GET64(manage_pages_out, out, pas[i]));
 
        if (nclaimed)
                *nclaimed = num_claimed;