#ifdef CONFIG_THP_SWAP
 #define SWAPFILE_CLUSTER       HPAGE_PMD_NR
+
+#define swap_entry_size(size)  (size)
 #else
 #define SWAPFILE_CLUSTER       256
+
+/*
+ * Define swap_entry_size() as constant to let compiler to optimize
+ * out some code if !CONFIG_THP_SWAP
+ */
+#define swap_entry_size(size)  1
 #endif
 #define LATENCY_LIMIT          256
 
 /*
  * Called after dropping swapcache to decrease refcnt to swap entries.
  */
-static void swapcache_free(swp_entry_t entry)
-{
-       struct swap_info_struct *p;
-
-       p = _swap_info_get(entry);
-       if (p) {
-               if (!__swap_entry_free(p, entry, SWAP_HAS_CACHE))
-                       free_swap_slot(entry);
-       }
-}
-
-static void swapcache_free_cluster(swp_entry_t entry)
+void put_swap_page(struct page *page, swp_entry_t entry)
 {
        unsigned long offset = swp_offset(entry);
        unsigned long idx = offset / SWAPFILE_CLUSTER;
        unsigned char *map;
        unsigned int i, free_entries = 0;
        unsigned char val;
-
-       if (!IS_ENABLED(CONFIG_THP_SWAP))
-               return;
+       int size = swap_entry_size(hpage_nr_pages(page));
 
        si = _swap_info_get(entry);
        if (!si)
                return;
 
-       ci = lock_cluster(si, offset);
-       VM_BUG_ON(!cluster_is_huge(ci));
-       map = si->swap_map + offset;
-       for (i = 0; i < SWAPFILE_CLUSTER; i++) {
-               val = map[i];
-               VM_BUG_ON(!(val & SWAP_HAS_CACHE));
-               if (val == SWAP_HAS_CACHE)
-                       free_entries++;
-       }
-       if (!free_entries) {
-               for (i = 0; i < SWAPFILE_CLUSTER; i++)
-                       map[i] &= ~SWAP_HAS_CACHE;
-       }
-       cluster_clear_huge(ci);
-       unlock_cluster(ci);
-       if (free_entries == SWAPFILE_CLUSTER) {
-               spin_lock(&si->lock);
+       if (size == SWAPFILE_CLUSTER) {
                ci = lock_cluster(si, offset);
-               memset(map, 0, SWAPFILE_CLUSTER);
+               VM_BUG_ON(!cluster_is_huge(ci));
+               map = si->swap_map + offset;
+               for (i = 0; i < SWAPFILE_CLUSTER; i++) {
+                       val = map[i];
+                       VM_BUG_ON(!(val & SWAP_HAS_CACHE));
+                       if (val == SWAP_HAS_CACHE)
+                               free_entries++;
+               }
+               if (!free_entries) {
+                       for (i = 0; i < SWAPFILE_CLUSTER; i++)
+                               map[i] &= ~SWAP_HAS_CACHE;
+               }
+               cluster_clear_huge(ci);
                unlock_cluster(ci);
-               mem_cgroup_uncharge_swap(entry, SWAPFILE_CLUSTER);
-               swap_free_cluster(si, idx);
-               spin_unlock(&si->lock);
-       } else if (free_entries) {
-               for (i = 0; i < SWAPFILE_CLUSTER; i++, entry.val++) {
+               if (free_entries == SWAPFILE_CLUSTER) {
+                       spin_lock(&si->lock);
+                       ci = lock_cluster(si, offset);
+                       memset(map, 0, SWAPFILE_CLUSTER);
+                       unlock_cluster(ci);
+                       mem_cgroup_uncharge_swap(entry, SWAPFILE_CLUSTER);
+                       swap_free_cluster(si, idx);
+                       spin_unlock(&si->lock);
+                       return;
+               }
+       }
+       if (size == 1 || free_entries) {
+               for (i = 0; i < size; i++, entry.val++) {
                        if (!__swap_entry_free(si, entry, SWAP_HAS_CACHE))
                                free_swap_slot(entry);
                }
 }
 #endif
 
-void put_swap_page(struct page *page, swp_entry_t entry)
-{
-       if (!PageTransHuge(page))
-               swapcache_free(entry);
-       else
-               swapcache_free_cluster(entry);
-}
-
 static int swp_entry_cmp(const void *ent1, const void *ent2)
 {
        const swp_entry_t *e1 = ent1, *e2 = ent2;