return 0;
 }
 
-static void zswap_rb_erase(struct rb_root *root, struct zswap_entry *entry)
+static bool zswap_rb_erase(struct rb_root *root, struct zswap_entry *entry)
 {
        if (!RB_EMPTY_NODE(&entry->rbnode)) {
                rb_erase(&entry->rbnode, root);
                RB_CLEAR_NODE(&entry->rbnode);
+               return true;
        }
+       return false;
 }
 
 /*
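
Returning a bool from zswap_rb_erase() is what makes invalidation
idempotent below: only the caller whose erase actually takes the entry
off the tree goes on to drop the tree's initial reference. A minimal
userspace model of that pattern (all names invented; a flag stands in
for the RB_EMPTY_NODE() check, a plain int for the refcount):

    #include <stdbool.h>
    #include <stdio.h>
    #include <stdlib.h>

    struct entry {
            bool in_tree;   /* models !RB_EMPTY_NODE(&entry->rbnode) */
            int refcount;   /* models entry->refcount */
    };

    /* models zswap_rb_erase(): reports success only the first time */
    static bool erase(struct entry *e)
    {
            if (e->in_tree) {
                    e->in_tree = false;
                    return true;
            }
            return false;
    }

    /* models zswap_entry_put(): free on the last reference */
    static void put(struct entry *e)
    {
            if (--e->refcount == 0) {
                    printf("entry freed\n");
                    free(e);
            }
    }

    /* models zswap_invalidate_entry(): drop the initial ref only if
     * we were the ones to take the entry off the tree */
    static void invalidate(struct entry *e)
    {
            if (erase(e))
                    put(e);
    }

    int main(void)
    {
            struct entry *e = malloc(sizeof(*e));

            e->in_tree = true;
            e->refcount = 2;        /* initial ref + one local ref */

            invalidate(e);  /* erases, puts the initial ref */
            invalidate(e);  /* erase fails: no double put */
            put(e);         /* drop the local ref: freed exactly once */
            return 0;
    }
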
        return NULL;
 }
 
+/*
+ * If the entry is still valid in the tree, drop the initial ref and remove it
+ * from the tree. This function must be called with an additional ref held,
+ * otherwise it may race with another invalidation freeing the entry.
+ */
 static void zswap_invalidate_entry(struct zswap_tree *tree,
                                   struct zswap_entry *entry)
 {
-       /* remove from rbtree */
-       zswap_rb_erase(&tree->rbroot, entry);
-
-       /* drop the initial reference from entry creation */
-       zswap_entry_put(tree, entry);
+       if (zswap_rb_erase(&tree->rbroot, entry))
+               zswap_entry_put(tree, entry);
 }
 
 static int zswap_reclaim_entry(struct zswap_pool *pool)
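
The rule in the comment above, seen from the caller's side: hold your
own reference whenever the entry pointer outlives a tree->lock section.
A sketch (the function and its name are invented for illustration;
zswap_entry_find_get() is zswap.c's existing search-plus-get helper):

    static void invalidate_offset_sketch(struct zswap_tree *tree, pgoff_t offset)
    {
            struct zswap_entry *entry;

            spin_lock(&tree->lock);
            /* lookup returns the entry with a reference already held */
            entry = zswap_entry_find_get(&tree->rbroot, offset);
            spin_unlock(&tree->lock);
            if (!entry)
                    return;

            /* ... tree->lock released; others may invalidate the entry ... */

            spin_lock(&tree->lock);
            zswap_invalidate_entry(tree, entry);    /* no-op if already erased */
            zswap_entry_put(tree, entry);           /* drop the local reference */
            spin_unlock(&tree->lock);
    }
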
         * swapcache. Drop the entry from zswap - unless invalidate already
         * took it out while we had the tree->lock released for IO.
         */
-       if (entry == zswap_rb_search(&tree->rbroot, swpoffset))
-               zswap_invalidate_entry(tree, entry);
+       zswap_invalidate_entry(tree, entry);
 
 put_unlock:
        /* Drop local reference */
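
This is why the zswap_rb_search() recheck above can go: both racers may
now call zswap_invalidate_entry(), and only the one whose erase succeeds
drops the initial reference. The interleaving, sketched with
illustrative reference counts (column labels are invented):

    /*
     *   zswap_reclaim_entry()               concurrent invalidation
     *   ---------------------               -----------------------
     *   lookup + take local ref   refs: 2
     *   unlock tree->lock for IO
     *                                       zswap_invalidate_entry()
     *                                         erase ok, put   refs: 1
     *   relock tree->lock
     *   zswap_invalidate_entry()
     *     erase fails, no put     refs: 1
     *   zswap_entry_put()         refs: 0, entry freed
     */
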
                count_objcg_event(entry->objcg, ZSWPIN);
 freeentry:
        spin_lock(&tree->lock);
-       zswap_entry_put(tree, entry);
        if (!ret && zswap_exclusive_loads_enabled) {
                zswap_invalidate_entry(tree, entry);
                *exclusive = true;
        } else if (entry->length) {
                spin_lock(&entry->pool->lru_lock);
                list_move(&entry->lru, &entry->pool->lru);
                spin_unlock(&entry->pool->lru_lock);
        }
+       zswap_entry_put(tree, entry);
        spin_unlock(&tree->lock);
 
        return ret;
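
The reordering in zswap_load() follows the same discipline: the local
reference taken at lookup must be the last thing dropped, so the
invalidation and LRU updates above never touch a freed entry. The
interleaving the old order (put before invalidate) allowed, sketched
with illustrative reference counts:

    /*
     *   zswap_load()                        concurrent invalidation
     *   ------------                        -----------------------
     *   lookup + take local ref   refs: 2
     *   unlock tree->lock, decompress
     *                                       zswap_invalidate_entry()
     *                                         erase ok, put   refs: 1
     *   relock tree->lock
     *   zswap_entry_put()         refs: 0, entry freed
     *   zswap_invalidate_entry()  use-after-free
     */
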