}
 
 /*
- * The cluster ci decreases one usage. If the usage counter becomes 0,
+ * The cluster ci decreases @nr_pages usage. If the usage counter becomes 0,
  * which means no page in the cluster is in use, we can optionally discard
  * the cluster and add it to free cluster list.
  */
-static void dec_cluster_info_page(struct swap_info_struct *p, struct swap_cluster_info *ci)
+static void dec_cluster_info_page(struct swap_info_struct *p,
+                                 struct swap_cluster_info *ci, int nr_pages)
 {
        if (!p->cluster_info)
                return;
 
-       VM_BUG_ON(ci->count == 0);
+       VM_BUG_ON(ci->count < nr_pages);
        VM_BUG_ON(cluster_is_free(ci));
        lockdep_assert_held(&p->lock);
        lockdep_assert_held(&ci->lock);
-       ci->count--;
+       ci->count -= nr_pages;
 
        if (!ci->count) {
                free_cluster(p, ci);
        return n_ret;
 }
 
-static void swap_free_cluster(struct swap_info_struct *si, unsigned long idx)
-{
-       unsigned long offset = idx * SWAPFILE_CLUSTER;
-       struct swap_cluster_info *ci;
-
-       ci = lock_cluster(si, offset);
-       memset(si->swap_map + offset, 0, SWAPFILE_CLUSTER);
-       ci->count = 0;
-       free_cluster(si, ci);
-       unlock_cluster(ci);
-       swap_range_free(si, offset, SWAPFILE_CLUSTER);
-}
-
 int get_swap_pages(int n_goal, swp_entry_t swp_entries[], int entry_order)
 {
        int order = swap_entry_order(entry_order);
        return usage;
 }
 
-static void swap_entry_free(struct swap_info_struct *p, swp_entry_t entry)
+/*
+ * Drop the last HAS_CACHE flag of swap entries, caller have to
+ * ensure all entries belong to the same cgroup.
+ */
+static void swap_entry_range_free(struct swap_info_struct *p, swp_entry_t entry,
+                                 unsigned int nr_pages)
 {
-       struct swap_cluster_info *ci;
        unsigned long offset = swp_offset(entry);
-       unsigned char count;
+       unsigned char *map = p->swap_map + offset;
+       unsigned char *map_end = map + nr_pages;
+       struct swap_cluster_info *ci;
 
        ci = lock_cluster(p, offset);
-       count = p->swap_map[offset];
-       VM_BUG_ON(count != SWAP_HAS_CACHE);
-       p->swap_map[offset] = 0;
-       dec_cluster_info_page(p, ci);
+       do {
+               VM_BUG_ON(*map != SWAP_HAS_CACHE);
+               *map = 0;
+       } while (++map < map_end);
+       dec_cluster_info_page(p, ci, nr_pages);
        unlock_cluster(ci);
 
-       mem_cgroup_uncharge_swap(entry, 1);
-       swap_range_free(p, offset, 1);
+       mem_cgroup_uncharge_swap(entry, nr_pages);
+       swap_range_free(p, offset, nr_pages);
 }
 
 static void cluster_swap_free_nr(struct swap_info_struct *sis,
 void put_swap_folio(struct folio *folio, swp_entry_t entry)
 {
        unsigned long offset = swp_offset(entry);
-       unsigned long idx = offset / SWAPFILE_CLUSTER;
        struct swap_cluster_info *ci;
        struct swap_info_struct *si;
        unsigned char *map;
                return;
 
        ci = lock_cluster_or_swap_info(si, offset);
-       if (size == SWAPFILE_CLUSTER) {
+       if (size > 1) {
                map = si->swap_map + offset;
-               for (i = 0; i < SWAPFILE_CLUSTER; i++) {
+               for (i = 0; i < size; i++) {
                        val = map[i];
                        VM_BUG_ON(!(val & SWAP_HAS_CACHE));
                        if (val == SWAP_HAS_CACHE)
                                free_entries++;
                }
-               if (free_entries == SWAPFILE_CLUSTER) {
+               if (free_entries == size) {
                        unlock_cluster_or_swap_info(si, ci);
                        spin_lock(&si->lock);
-                       mem_cgroup_uncharge_swap(entry, SWAPFILE_CLUSTER);
-                       swap_free_cluster(si, idx);
+                       swap_entry_range_free(si, entry, size);
                        spin_unlock(&si->lock);
                        return;
                }
        for (i = 0; i < n; ++i) {
                p = swap_info_get_cont(entries[i], prev);
                if (p)
-                       swap_entry_free(p, entries[i]);
+                       swap_entry_range_free(p, entries[i], 1);
                prev = p;
        }
        if (p)