#include "journal_reclaim.h"
 #include "trace.h"
 
+#include <linux/sched/mm.h>
+
 static int bch2_btree_key_cache_cmp_fn(struct rhashtable_compare_arg *arg,
                                       const void *obj)
 {
        c->nr_keys--;
 }
 
-static void bkey_cached_free(struct btree_key_cache *c,
+static void bkey_cached_free(struct btree_key_cache *bc,
                             struct bkey_cached *ck)
 {
-       list_move(&ck->list, &c->freed);
+       struct bch_fs *c = container_of(bc, struct bch_fs, btree_key_cache);
+
+       ck->btree_trans_barrier_seq =
+               start_poll_synchronize_srcu(&c->btree_trans_barrier);
+
+       list_move(&ck->list, &bc->freed);
 
        kfree(ck->k);
        ck->k           = NULL;
        struct bkey_cached_key key;
        struct btree_trans trans;
 
+       int srcu_idx = srcu_read_lock(&c->btree_trans_barrier);
+
        six_lock_read(&ck->c.lock, NULL, NULL);
        key = ck->key;
 
        if (ck->journal.seq != seq ||
            !test_bit(BKEY_CACHED_DIRTY, &ck->flags)) {
                six_unlock_read(&ck->c.lock);
-               return;
+               goto unlock;
        }
        six_unlock_read(&ck->c.lock);
 
        bch2_trans_init(&trans, c, 0, 0);
        btree_key_cache_flush_pos(&trans, key, seq, false);
        bch2_trans_exit(&trans);
+unlock:
+       srcu_read_unlock(&c->btree_trans_barrier, srcu_idx);
 }
 
 /*
 }
 #endif
 
+static unsigned long bch2_btree_key_cache_scan(struct shrinker *shrink,
+                                          struct shrink_control *sc)
+{
+       struct bch_fs *c = container_of(shrink, struct bch_fs,
+                                       btree_key_cache.shrink);
+       struct btree_key_cache *bc = &c->btree_key_cache;
+       struct bkey_cached *ck, *t;
+       size_t scanned = 0, freed = 0, nr = sc->nr_to_scan;
+       unsigned flags;
+
+       /* Return -1 if we can't do anything right now */
+       if (sc->gfp_mask & __GFP_FS)
+               mutex_lock(&bc->lock);
+       else if (!mutex_trylock(&bc->lock))
+               return -1;
+
+       flags = memalloc_nofs_save();
+
+       list_for_each_entry_safe(ck, t, &bc->freed, list) {
+               scanned++;
+
+               if (poll_state_synchronize_srcu(&c->btree_trans_barrier,
+                                               ck->btree_trans_barrier_seq)) {
+                       list_del(&ck->list);
+                       kfree(ck);
+                       freed++;
+               }
+
+               if (scanned >= nr)
+                       goto out;
+       }
+
+       list_for_each_entry_safe(ck, t, &bc->clean, list) {
+               scanned++;
+
+               if (bkey_cached_lock_for_evict(ck)) {
+                       bkey_cached_evict(bc, ck);
+                       bkey_cached_free(bc, ck);
+               }
+
+               if (scanned >= nr) {
+                       if (&t->list != &bc->clean)
+                               list_move_tail(&bc->clean, &t->list);
+                       goto out;
+               }
+       }
+out:
+       memalloc_nofs_restore(flags);
+       mutex_unlock(&bc->lock);
+
+       return freed;
+}
+
+static unsigned long bch2_btree_key_cache_count(struct shrinker *shrink,
+                                           struct shrink_control *sc)
+{
+       struct bch_fs *c = container_of(shrink, struct bch_fs,
+                                       btree_key_cache.shrink);
+       struct btree_key_cache *bc = &c->btree_key_cache;
+
+       return bc->nr_keys;
+}
+
 void bch2_fs_btree_key_cache_exit(struct btree_key_cache *bc)
 {
        struct bch_fs *c = container_of(bc, struct bch_fs, btree_key_cache);
        struct bkey_cached *ck, *n;
 
+       if (bc->shrink.list.next)
+               unregister_shrinker(&bc->shrink);
+
        mutex_lock(&bc->lock);
        list_splice(&bc->dirty, &bc->clean);
 
        INIT_LIST_HEAD(&c->dirty);
 }
 
-int bch2_fs_btree_key_cache_init(struct btree_key_cache *c)
+int bch2_fs_btree_key_cache_init(struct btree_key_cache *bc)
 {
-       return rhashtable_init(&c->table, &bch2_btree_key_cache_params);
+       struct bch_fs *c = container_of(bc, struct bch_fs, btree_key_cache);
+
+       bc->shrink.count_objects        = bch2_btree_key_cache_count;
+       bc->shrink.scan_objects         = bch2_btree_key_cache_scan;
+
+       return  register_shrinker(&bc->shrink, "%s/btree_key_cache", c->name) ?:
+               rhashtable_init(&bc->table, &bch2_btree_key_cache_params);
 }
 
 void bch2_btree_key_cache_to_text(struct printbuf *out, struct btree_key_cache *c)