#include "blk-mq-sched.h"
 #include "blk-mq-tag.h"
 
+/*
+ * Recalculate wakeup batch when tag is shared by hctx.
+ */
+static void blk_mq_update_wake_batch(struct blk_mq_tags *tags,
+               unsigned int users)
+{
+       if (!users)
+               return;
+
+       sbitmap_queue_recalculate_wake_batch(&tags->bitmap_tags,
+                       users);
+       sbitmap_queue_recalculate_wake_batch(&tags->breserved_tags,
+                       users);
+}
+
 /*
  * If a previously inactive queue goes active, bump the active user count.
  * We need to do this before try to allocate driver tag, then even if fail
  */
 bool __blk_mq_tag_busy(struct blk_mq_hw_ctx *hctx)
 {
+       unsigned int users;
+
        if (blk_mq_is_shared_tags(hctx->flags)) {
                struct request_queue *q = hctx->queue;
 
-               if (!test_bit(QUEUE_FLAG_HCTX_ACTIVE, &q->queue_flags) &&
-                   !test_and_set_bit(QUEUE_FLAG_HCTX_ACTIVE, &q->queue_flags))
-                       atomic_inc(&hctx->tags->active_queues);
+               if (test_bit(QUEUE_FLAG_HCTX_ACTIVE, &q->queue_flags) ||
+                   test_and_set_bit(QUEUE_FLAG_HCTX_ACTIVE, &q->queue_flags)) {
+                       return true;
+               }
        } else {
-               if (!test_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state) &&
-                   !test_and_set_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state))
-                       atomic_inc(&hctx->tags->active_queues);
+               if (test_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state) ||
+                   test_and_set_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state)) {
+                       return true;
+               }
        }
 
+       users = atomic_inc_return(&hctx->tags->active_queues);
+
+       blk_mq_update_wake_batch(hctx->tags, users);
+
        return true;
 }
 
 void __blk_mq_tag_idle(struct blk_mq_hw_ctx *hctx)
 {
        struct blk_mq_tags *tags = hctx->tags;
+       unsigned int users;
 
        if (blk_mq_is_shared_tags(hctx->flags)) {
                struct request_queue *q = hctx->queue;
                        return;
        }
 
-       atomic_dec(&tags->active_queues);
+       users = atomic_dec_return(&tags->active_queues);
+
+       blk_mq_update_wake_batch(tags, users);
 
        blk_mq_tag_wakeup_all(tags, false);
 }
 
        sbitmap_free(&sbq->sb);
 }
 
+/**
+ * sbitmap_queue_recalculate_wake_batch() - Recalculate wake batch
+ * @sbq: Bitmap queue to recalculate wake batch.
+ * @users: Number of shares.
+ *
+ * Like sbitmap_queue_update_wake_batch(), this will calculate wake batch
+ * by depth. This interface is for HCTX shared tags or queue shared tags.
+ */
+void sbitmap_queue_recalculate_wake_batch(struct sbitmap_queue *sbq,
+                                           unsigned int users);
+
 /**
  * sbitmap_queue_resize() - Resize a &struct sbitmap_queue.
  * @sbq: Bitmap queue to resize.
 
 }
 EXPORT_SYMBOL_GPL(sbitmap_queue_init_node);
 
-static void sbitmap_queue_update_wake_batch(struct sbitmap_queue *sbq,
-                                           unsigned int depth)
+static inline void __sbitmap_queue_update_wake_batch(struct sbitmap_queue *sbq,
+                                           unsigned int wake_batch)
 {
-       unsigned int wake_batch = sbq_calc_wake_batch(sbq, depth);
        int i;
 
        if (sbq->wake_batch != wake_batch) {
        }
 }
 
+static void sbitmap_queue_update_wake_batch(struct sbitmap_queue *sbq,
+                                           unsigned int depth)
+{
+       unsigned int wake_batch;
+
+       wake_batch = sbq_calc_wake_batch(sbq, depth);
+       __sbitmap_queue_update_wake_batch(sbq, wake_batch);
+}
+
+void sbitmap_queue_recalculate_wake_batch(struct sbitmap_queue *sbq,
+                                           unsigned int users)
+{
+       unsigned int wake_batch;
+
+       wake_batch = clamp_val((sbq->sb.depth + users - 1) /
+                       users, 4, SBQ_WAKE_BATCH);
+       __sbitmap_queue_update_wake_batch(sbq, wake_batch);
+}
+EXPORT_SYMBOL_GPL(sbitmap_queue_recalculate_wake_batch);
+
 void sbitmap_queue_resize(struct sbitmap_queue *sbq, unsigned int depth)
 {
        sbitmap_queue_update_wake_batch(sbq, depth);