#include <linux/sbitmap.h>
 #include <linux/seq_file.h>
 
+/*
+ * See if we have deferred clears that we can batch move
+ */
+static inline bool sbitmap_deferred_clear(struct sbitmap *sb, int index)
+{
+       unsigned long mask, val;
+       unsigned long __maybe_unused flags;
+       bool ret = false;
+
+       /* Silence bogus lockdep warning */
+#if defined(CONFIG_LOCKDEP)
+       local_irq_save(flags);
+#endif
+       spin_lock(&sb->map[index].swap_lock);
+
+       if (!sb->map[index].cleared)
+               goto out_unlock;
+
+       /*
+        * First get a stable cleared mask, setting the old mask to 0.
+        */
+       do {
+               mask = sb->map[index].cleared;
+       } while (cmpxchg(&sb->map[index].cleared, mask, 0) != mask);
+
+       /*
+        * Now clear the masked bits in our free word
+        */
+       do {
+               val = sb->map[index].word;
+       } while (cmpxchg(&sb->map[index].word, val, val & ~mask) != val);
+
+       ret = true;
+out_unlock:
+       spin_unlock(&sb->map[index].swap_lock);
+#if defined(CONFIG_LOCKDEP)
+       local_irq_restore(flags);
+#endif
+       return ret;
+}
+
 int sbitmap_init_node(struct sbitmap *sb, unsigned int depth, int shift,
                      gfp_t flags, int node)
 {
        unsigned int bits_per_word = 1U << sb->shift;
        unsigned int i;
 
+       for (i = 0; i < sb->map_nr; i++)
+               sbitmap_deferred_clear(sb, i);
+
        sb->depth = depth;
        sb->map_nr = DIV_ROUND_UP(sb->depth, bits_per_word);
 
        return nr;
 }
 
-/*
- * See if we have deferred clears that we can batch move
- */
-static inline bool sbitmap_deferred_clear(struct sbitmap *sb, int index)
-{
-       unsigned long mask, val;
-       unsigned long __maybe_unused flags;
-       bool ret = false;
-
-       /* Silence bogus lockdep warning */
-#if defined(CONFIG_LOCKDEP)
-       local_irq_save(flags);
-#endif
-       spin_lock(&sb->map[index].swap_lock);
-
-       if (!sb->map[index].cleared)
-               goto out_unlock;
-
-       /*
-        * First get a stable cleared mask, setting the old mask to 0.
-        */
-       do {
-               mask = sb->map[index].cleared;
-       } while (cmpxchg(&sb->map[index].cleared, mask, 0) != mask);
-
-       /*
-        * Now clear the masked bits in our free word
-        */
-       do {
-               val = sb->map[index].word;
-       } while (cmpxchg(&sb->map[index].word, val, val & ~mask) != val);
-
-       ret = true;
-out_unlock:
-       spin_unlock(&sb->map[index].swap_lock);
-#if defined(CONFIG_LOCKDEP)
-       local_irq_restore(flags);
-#endif
-       return ret;
-}
-
 static int sbitmap_find_bit_in_index(struct sbitmap *sb, int index,
                                     unsigned int alloc_hint, bool round_robin)
 {
        index = SB_NR_TO_INDEX(sb, alloc_hint);
 
        for (i = 0; i < sb->map_nr; i++) {
+again:
                nr = __sbitmap_get_word(&sb->map[index].word,
                                        min(sb->map[index].depth, shallow_depth),
                                        SB_NR_TO_BIT(sb, alloc_hint), true);
                        break;
                }
 
+               if (sbitmap_deferred_clear(sb, index))
+                       goto again;
+
                /* Jump to next index. */
                index++;
                alloc_hint = index << sb->shift;
        unsigned int i;
 
        for (i = 0; i < sb->map_nr; i++) {
-               if (sb->map[i].word)
+               if (sb->map[i].word & ~sb->map[i].cleared)
                        return true;
        }
        return false;
 
        for (i = 0; i < sb->map_nr; i++) {
                const struct sbitmap_word *word = &sb->map[i];
+               unsigned long mask = word->word & ~word->cleared;
                unsigned long ret;
 
-               ret = find_first_zero_bit(&word->word, word->depth);
+               ret = find_first_zero_bit(&mask, word->depth);
                if (ret < word->depth)
                        return true;
        }