struct blk_mq_hw_ctx *hctx,
                                          unsigned int hctx_idx)
 {
+       if (blk_mq_is_sbitmap_shared(q->tag_set->flags)) {
+               hctx->sched_tags = q->shared_sbitmap_tags;
+               return 0;
+       }
+
        hctx->sched_tags = blk_mq_alloc_map_and_rqs(q->tag_set, hctx_idx,
                                                    q->nr_requests);
 
        return 0;
 }
 
+static void blk_mq_exit_sched_shared_sbitmap(struct request_queue *queue)
+{
+       blk_mq_free_rq_map(queue->shared_sbitmap_tags);
+       queue->shared_sbitmap_tags = NULL;
+}
+
 /* called in queue's release handler, tagset has gone away */
-static void blk_mq_sched_tags_teardown(struct request_queue *q)
+static void blk_mq_sched_tags_teardown(struct request_queue *q, unsigned int flags)
 {
        struct blk_mq_hw_ctx *hctx;
        int i;
 
        queue_for_each_hw_ctx(q, hctx, i) {
                if (hctx->sched_tags) {
-                       blk_mq_free_rq_map(hctx->sched_tags, hctx->flags);
+                       if (!blk_mq_is_sbitmap_shared(q->tag_set->flags))
+                               blk_mq_free_rq_map(hctx->sched_tags);
                        hctx->sched_tags = NULL;
                }
        }
+
+       if (blk_mq_is_sbitmap_shared(flags))
+               blk_mq_exit_sched_shared_sbitmap(q);
 }
 
 static int blk_mq_init_sched_shared_sbitmap(struct request_queue *queue)
 {
        struct blk_mq_tag_set *set = queue->tag_set;
-       int alloc_policy = BLK_MQ_FLAG_TO_ALLOC_POLICY(set->flags);
-       struct blk_mq_hw_ctx *hctx;
-       int ret, i;
 
        /*
         * Set initial depth at max so that we don't need to reallocate for
         * updating nr_requests.
         */
-       ret = blk_mq_init_bitmaps(&queue->sched_bitmap_tags,
-                                 &queue->sched_breserved_tags,
-                                 MAX_SCHED_RQ, set->reserved_tags,
-                                 set->numa_node, alloc_policy);
-       if (ret)
-               return ret;
-
-       queue_for_each_hw_ctx(queue, hctx, i) {
-               hctx->sched_tags->bitmap_tags =
-                                       &queue->sched_bitmap_tags;
-               hctx->sched_tags->breserved_tags =
-                                       &queue->sched_breserved_tags;
-       }
+       queue->shared_sbitmap_tags = blk_mq_alloc_map_and_rqs(set,
+                                               BLK_MQ_NO_HCTX_IDX,
+                                               MAX_SCHED_RQ);
+       if (!queue->shared_sbitmap_tags)
+               return -ENOMEM;
 
        blk_mq_tag_update_sched_shared_sbitmap(queue);
 
        return 0;
 }
 
-static void blk_mq_exit_sched_shared_sbitmap(struct request_queue *queue)
-{
-       sbitmap_queue_free(&queue->sched_bitmap_tags);
-       sbitmap_queue_free(&queue->sched_breserved_tags);
-}
-
 int blk_mq_init_sched(struct request_queue *q, struct elevator_type *e)
 {
+       unsigned int i, flags = q->tag_set->flags;
        struct blk_mq_hw_ctx *hctx;
        struct elevator_queue *eq;
-       unsigned int i;
        int ret;
 
        if (!e) {
        q->nr_requests = 2 * min_t(unsigned int, q->tag_set->queue_depth,
                                   BLKDEV_DEFAULT_RQ);
 
-       queue_for_each_hw_ctx(q, hctx, i) {
-               ret = blk_mq_sched_alloc_map_and_rqs(q, hctx, i);
+       if (blk_mq_is_sbitmap_shared(flags)) {
+               ret = blk_mq_init_sched_shared_sbitmap(q);
                if (ret)
-                       goto err_free_map_and_rqs;
+                       return ret;
        }
 
-       if (blk_mq_is_sbitmap_shared(q->tag_set->flags)) {
-               ret = blk_mq_init_sched_shared_sbitmap(q);
+       queue_for_each_hw_ctx(q, hctx, i) {
+               ret = blk_mq_sched_alloc_map_and_rqs(q, hctx, i);
                if (ret)
                        goto err_free_map_and_rqs;
        }
 
        ret = e->ops.init_sched(q, e);
        if (ret)
-               goto err_free_sbitmap;
+               goto err_free_map_and_rqs;
 
        blk_mq_debugfs_register_sched(q);
 
 
        return 0;
 
-err_free_sbitmap:
-       if (blk_mq_is_sbitmap_shared(q->tag_set->flags))
-               blk_mq_exit_sched_shared_sbitmap(q);
 err_free_map_and_rqs:
        blk_mq_sched_free_rqs(q);
-       blk_mq_sched_tags_teardown(q);
+       blk_mq_sched_tags_teardown(q, flags);
+
        q->elevator = NULL;
        return ret;
 }
        struct blk_mq_hw_ctx *hctx;
        int i;
 
-       queue_for_each_hw_ctx(q, hctx, i) {
-               if (hctx->sched_tags)
-                       blk_mq_free_rqs(q->tag_set, hctx->sched_tags, i);
+       if (blk_mq_is_sbitmap_shared(q->tag_set->flags)) {
+               blk_mq_free_rqs(q->tag_set, q->shared_sbitmap_tags,
+                               BLK_MQ_NO_HCTX_IDX);
+       } else {
+               queue_for_each_hw_ctx(q, hctx, i) {
+                       if (hctx->sched_tags)
+                               blk_mq_free_rqs(q->tag_set,
+                                               hctx->sched_tags, i);
+               }
        }
 }
 
        blk_mq_debugfs_unregister_sched(q);
        if (e->type->ops.exit_sched)
                e->type->ops.exit_sched(e);
-       blk_mq_sched_tags_teardown(q);
-       if (blk_mq_is_sbitmap_shared(flags))
-               blk_mq_exit_sched_shared_sbitmap(q);
+       blk_mq_sched_tags_teardown(q, flags);
        q->elevator = NULL;
 }
 
 {
        if (blk_mq_is_sbitmap_shared(hctx->flags)) {
                struct request_queue *q = hctx->queue;
-               struct blk_mq_tag_set *set = q->tag_set;
 
                if (!test_bit(QUEUE_FLAG_HCTX_ACTIVE, &q->queue_flags) &&
                    !test_and_set_bit(QUEUE_FLAG_HCTX_ACTIVE, &q->queue_flags))
-                       atomic_inc(&set->active_queues_shared_sbitmap);
+                       atomic_inc(&hctx->tags->active_queues);
        } else {
                if (!test_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state) &&
                    !test_and_set_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state))
 void __blk_mq_tag_idle(struct blk_mq_hw_ctx *hctx)
 {
        struct blk_mq_tags *tags = hctx->tags;
-       struct request_queue *q = hctx->queue;
-       struct blk_mq_tag_set *set = q->tag_set;
 
        if (blk_mq_is_sbitmap_shared(hctx->flags)) {
+               struct request_queue *q = hctx->queue;
+
                if (!test_and_clear_bit(QUEUE_FLAG_HCTX_ACTIVE,
                                        &q->queue_flags))
                        return;
-               atomic_dec(&set->active_queues_shared_sbitmap);
+               atomic_dec(&tags->active_queues);
        } else {
                if (!test_and_clear_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state))
                        return;
        return 0;
 }
 
-int blk_mq_init_shared_sbitmap(struct blk_mq_tag_set *set)
-{
-       int alloc_policy = BLK_MQ_FLAG_TO_ALLOC_POLICY(set->flags);
-       int i, ret;
-
-       ret = blk_mq_init_bitmaps(&set->__bitmap_tags, &set->__breserved_tags,
-                                 set->queue_depth, set->reserved_tags,
-                                 set->numa_node, alloc_policy);
-       if (ret)
-               return ret;
-
-       for (i = 0; i < set->nr_hw_queues; i++) {
-               struct blk_mq_tags *tags = set->tags[i];
-
-               tags->bitmap_tags = &set->__bitmap_tags;
-               tags->breserved_tags = &set->__breserved_tags;
-       }
-
-       return 0;
-}
-
-void blk_mq_exit_shared_sbitmap(struct blk_mq_tag_set *set)
-{
-       sbitmap_queue_free(&set->__bitmap_tags);
-       sbitmap_queue_free(&set->__breserved_tags);
-}
-
 struct blk_mq_tags *blk_mq_init_tags(unsigned int total_tags,
                                     unsigned int reserved_tags,
-                                    int node, unsigned int flags)
+                                    int node, int alloc_policy)
 {
-       int alloc_policy = BLK_MQ_FLAG_TO_ALLOC_POLICY(flags);
        struct blk_mq_tags *tags;
 
        if (total_tags > BLK_MQ_TAG_MAX) {
        tags->nr_reserved_tags = reserved_tags;
        spin_lock_init(&tags->lock);
 
-       if (blk_mq_is_sbitmap_shared(flags))
-               return tags;
-
        if (blk_mq_init_bitmap_tags(tags, node, alloc_policy) < 0) {
                kfree(tags);
                return NULL;
        return tags;
 }
 
-void blk_mq_free_tags(struct blk_mq_tags *tags, unsigned int flags)
+void blk_mq_free_tags(struct blk_mq_tags *tags)
 {
-       if (!blk_mq_is_sbitmap_shared(flags)) {
-               sbitmap_queue_free(tags->bitmap_tags);
-               sbitmap_queue_free(tags->breserved_tags);
-       }
+       sbitmap_queue_free(tags->bitmap_tags);
+       sbitmap_queue_free(tags->breserved_tags);
        kfree(tags);
 }
 
                if (tdepth > MAX_SCHED_RQ)
                        return -EINVAL;
 
+               /*
+                * Only the sbitmap needs resizing since we allocated the max
+                * initially.
+                */
+               if (blk_mq_is_sbitmap_shared(set->flags))
+                       return 0;
+
                new = blk_mq_alloc_map_and_rqs(set, hctx->queue_num, tdepth);
                if (!new)
                        return -ENOMEM;
 
 void blk_mq_tag_resize_shared_sbitmap(struct blk_mq_tag_set *set, unsigned int size)
 {
-       sbitmap_queue_resize(&set->__bitmap_tags, size - set->reserved_tags);
+       struct blk_mq_tags *tags = set->shared_sbitmap_tags;
+
+       sbitmap_queue_resize(&tags->__bitmap_tags, size - set->reserved_tags);
 }
 
 void blk_mq_tag_update_sched_shared_sbitmap(struct request_queue *q)
 {
-       sbitmap_queue_resize(&q->sched_bitmap_tags,
+       sbitmap_queue_resize(q->shared_sbitmap_tags->bitmap_tags,
                             q->nr_requests - q->tag_set->reserved_tags);
 }
 
 
 
 extern struct blk_mq_tags *blk_mq_init_tags(unsigned int nr_tags,
                                        unsigned int reserved_tags,
-                                       int node, unsigned int flags);
-extern void blk_mq_free_tags(struct blk_mq_tags *tags, unsigned int flags);
+                                       int node, int alloc_policy);
+extern void blk_mq_free_tags(struct blk_mq_tags *tags);
 extern int blk_mq_init_bitmaps(struct sbitmap_queue *bitmap_tags,
                               struct sbitmap_queue *breserved_tags,
                               unsigned int queue_depth,
                               unsigned int reserved,
                               int node, int alloc_policy);
 
-extern int blk_mq_init_shared_sbitmap(struct blk_mq_tag_set *set);
-extern void blk_mq_exit_shared_sbitmap(struct blk_mq_tag_set *set);
 extern unsigned int blk_mq_get_tag(struct blk_mq_alloc_data *data);
 extern void blk_mq_put_tag(struct blk_mq_tags *tags, struct blk_mq_ctx *ctx,
                           unsigned int tag);
 
        struct blk_mq_tags *drv_tags;
        struct page *page;
 
-       drv_tags = set->tags[hctx_idx];
+       if (blk_mq_is_sbitmap_shared(set->flags))
+               drv_tags = set->shared_sbitmap_tags;
+       else
+               drv_tags = set->tags[hctx_idx];
 
        if (tags->static_rqs && set->ops->exit_request) {
                int i;
        }
 }
 
-void blk_mq_free_rq_map(struct blk_mq_tags *tags, unsigned int flags)
+void blk_mq_free_rq_map(struct blk_mq_tags *tags)
 {
        kfree(tags->rqs);
        tags->rqs = NULL;
        kfree(tags->static_rqs);
        tags->static_rqs = NULL;
 
-       blk_mq_free_tags(tags, flags);
+       blk_mq_free_tags(tags);
 }
 
 static struct blk_mq_tags *blk_mq_alloc_rq_map(struct blk_mq_tag_set *set,
                                               unsigned int hctx_idx,
                                               unsigned int nr_tags,
-                                              unsigned int reserved_tags,
-                                              unsigned int flags)
+                                              unsigned int reserved_tags)
 {
        struct blk_mq_tags *tags;
        int node;
        if (node == NUMA_NO_NODE)
                node = set->numa_node;
 
-       tags = blk_mq_init_tags(nr_tags, reserved_tags, node, flags);
+       tags = blk_mq_init_tags(nr_tags, reserved_tags, node,
+                               BLK_MQ_FLAG_TO_ALLOC_POLICY(set->flags));
        if (!tags)
                return NULL;
 
                                 GFP_NOIO | __GFP_NOWARN | __GFP_NORETRY,
                                 node);
        if (!tags->rqs) {
-               blk_mq_free_tags(tags, flags);
+               blk_mq_free_tags(tags);
                return NULL;
        }
 
                                        node);
        if (!tags->static_rqs) {
                kfree(tags->rqs);
-               blk_mq_free_tags(tags, flags);
+               blk_mq_free_tags(tags);
                return NULL;
        }
 
        struct blk_mq_tags *tags;
        int ret;
 
-       tags = blk_mq_alloc_rq_map(set, hctx_idx, depth, set->reserved_tags,
-                                  set->flags);
+       tags = blk_mq_alloc_rq_map(set, hctx_idx, depth, set->reserved_tags);
        if (!tags)
                return NULL;
 
        ret = blk_mq_alloc_rqs(set, tags, hctx_idx, depth);
        if (ret) {
-               blk_mq_free_rq_map(tags, set->flags);
+               blk_mq_free_rq_map(tags);
                return NULL;
        }
 
 static bool __blk_mq_alloc_map_and_rqs(struct blk_mq_tag_set *set,
                                       int hctx_idx)
 {
+       if (blk_mq_is_sbitmap_shared(set->flags)) {
+               set->tags[hctx_idx] = set->shared_sbitmap_tags;
+
+               return true;
+       }
+
        set->tags[hctx_idx] = blk_mq_alloc_map_and_rqs(set, hctx_idx,
                                                       set->queue_depth);
 
                             struct blk_mq_tags *tags,
                             unsigned int hctx_idx)
 {
-       unsigned int flags = set->flags;
-
        if (tags) {
                blk_mq_free_rqs(set, tags, hctx_idx);
-               blk_mq_free_rq_map(tags, flags);
+               blk_mq_free_rq_map(tags);
        }
 }
 
+static void __blk_mq_free_map_and_rqs(struct blk_mq_tag_set *set,
+                                     unsigned int hctx_idx)
+{
+       if (!blk_mq_is_sbitmap_shared(set->flags))
+               blk_mq_free_map_and_rqs(set, set->tags[hctx_idx], hctx_idx);
+
+       set->tags[hctx_idx] = NULL;
+}
+
 static void blk_mq_map_swqueue(struct request_queue *q)
 {
        unsigned int i, j, hctx_idx;
                         * fallback in case of a new remap fails
                         * allocation
                         */
-                       if (i && set->tags[i]) {
-                               blk_mq_free_map_and_rqs(set, set->tags[i], i);
-                               set->tags[i] = NULL;
-                       }
+                       if (i)
+                               __blk_mq_free_map_and_rqs(set, i);
 
                        hctx->tags = NULL;
                        continue;
                struct blk_mq_hw_ctx *hctx = hctxs[j];
 
                if (hctx) {
-                       blk_mq_free_map_and_rqs(set, set->tags[j], j);
-                       set->tags[j] = NULL;
+                       __blk_mq_free_map_and_rqs(set, j);
                        blk_mq_exit_hctx(q, set, hctx, j);
                        hctxs[j] = NULL;
                }
 {
        int i;
 
+       if (blk_mq_is_sbitmap_shared(set->flags)) {
+               set->shared_sbitmap_tags = blk_mq_alloc_map_and_rqs(set,
+                                               BLK_MQ_NO_HCTX_IDX,
+                                               set->queue_depth);
+               if (!set->shared_sbitmap_tags)
+                       return -ENOMEM;
+       }
+
        for (i = 0; i < set->nr_hw_queues; i++) {
                if (!__blk_mq_alloc_map_and_rqs(set, i))
                        goto out_unwind;
        return 0;
 
 out_unwind:
-       while (--i >= 0) {
-               blk_mq_free_map_and_rqs(set, set->tags[i], i);
-               set->tags[i] = NULL;
+       while (--i >= 0)
+               __blk_mq_free_map_and_rqs(set, i);
+
+       if (blk_mq_is_sbitmap_shared(set->flags)) {
+               blk_mq_free_map_and_rqs(set, set->shared_sbitmap_tags,
+                                       BLK_MQ_NO_HCTX_IDX);
        }
 
        return -ENOMEM;
        if (ret)
                goto out_free_mq_map;
 
-       if (blk_mq_is_sbitmap_shared(set->flags)) {
-               atomic_set(&set->active_queues_shared_sbitmap, 0);
-
-               if (blk_mq_init_shared_sbitmap(set)) {
-                       ret = -ENOMEM;
-                       goto out_free_mq_rq_maps;
-               }
-       }
-
        mutex_init(&set->tag_list_lock);
        INIT_LIST_HEAD(&set->tag_list);
 
        return 0;
 
-out_free_mq_rq_maps:
-       for (i = 0; i < set->nr_hw_queues; i++) {
-               blk_mq_free_map_and_rqs(set, set->tags[i], i);
-               set->tags[i] = NULL;
-       }
 out_free_mq_map:
        for (i = 0; i < set->nr_maps; i++) {
                kfree(set->map[i].mq_map);
 {
        int i, j;
 
-       for (i = 0; i < set->nr_hw_queues; i++) {
-               blk_mq_free_map_and_rqs(set, set->tags[i], i);
-               set->tags[i] = NULL;
-       }
+       for (i = 0; i < set->nr_hw_queues; i++)
+               __blk_mq_free_map_and_rqs(set, i);
 
-       if (blk_mq_is_sbitmap_shared(set->flags))
-               blk_mq_exit_shared_sbitmap(set);
+       if (blk_mq_is_sbitmap_shared(set->flags)) {
+               blk_mq_free_map_and_rqs(set, set->shared_sbitmap_tags,
+                                       BLK_MQ_NO_HCTX_IDX);
+       }
 
        for (j = 0; j < set->nr_maps; j++) {
                kfree(set->map[j].mq_map);
                if (hctx->sched_tags) {
                        ret = blk_mq_tag_update_depth(hctx, &hctx->sched_tags,
                                                      nr, true);
-                       if (blk_mq_is_sbitmap_shared(set->flags)) {
-                               hctx->sched_tags->bitmap_tags =
-                                       &q->sched_bitmap_tags;
-                               hctx->sched_tags->breserved_tags =
-                                       &q->sched_breserved_tags;
-                       }
                } else {
                        ret = blk_mq_tag_update_depth(hctx, &hctx->tags, nr,
                                                      false);
 
  */
 void blk_mq_free_rqs(struct blk_mq_tag_set *set, struct blk_mq_tags *tags,
                     unsigned int hctx_idx);
-void blk_mq_free_rq_map(struct blk_mq_tags *tags, unsigned int flags);
+void blk_mq_free_rq_map(struct blk_mq_tags *tags);
 struct blk_mq_tags *blk_mq_alloc_map_and_rqs(struct blk_mq_tag_set *set,
                                unsigned int hctx_idx, unsigned int depth);
 void blk_mq_free_map_and_rqs(struct blk_mq_tag_set *set,
 
        if (blk_mq_is_sbitmap_shared(hctx->flags)) {
                struct request_queue *q = hctx->queue;
-               struct blk_mq_tag_set *set = q->tag_set;
 
                if (!test_bit(QUEUE_FLAG_HCTX_ACTIVE, &q->queue_flags))
                        return true;
-               users = atomic_read(&set->active_queues_shared_sbitmap);
        } else {
                if (!test_bit(BLK_MQ_S_TAG_ACTIVE, &hctx->state))
                        return true;
-               users = atomic_read(&hctx->tags->active_queues);
        }
 
+       users = atomic_read(&hctx->tags->active_queues);
+
        if (!users)
                return true;
 
 
  * @flags:        Zero or more BLK_MQ_F_* flags.
  * @driver_data:   Pointer to data owned by the block driver that created this
  *                tag set.
- * @active_queues_shared_sbitmap:
- *                number of active request queues per tag set.
- * @__bitmap_tags: A shared tags sbitmap, used over all hctx's
- * @__breserved_tags:
- *                A shared reserved tags sbitmap, used over all hctx's
  * @tags:         Tag sets. One tag set per hardware queue. Has @nr_hw_queues
  *                elements.
+ * @shared_sbitmap_tags:
+ *                Shared sbitmap set of tags. Has @nr_hw_queues elements. If
+ *                set, shared by all @tags.
  * @tag_list_lock: Serializes tag_list accesses.
  * @tag_list:     List of the request queues that use this tag set. See also
  *                request_queue.tag_set_list.
        unsigned int            timeout;
        unsigned int            flags;
        void                    *driver_data;
-       atomic_t                active_queues_shared_sbitmap;
 
-       struct sbitmap_queue    __bitmap_tags;
-       struct sbitmap_queue    __breserved_tags;
        struct blk_mq_tags      **tags;
 
+       struct blk_mq_tags      *shared_sbitmap_tags;
+
        struct mutex            tag_list_lock;
        struct list_head        tag_list;
 };
        ((policy & ((1 << BLK_MQ_F_ALLOC_POLICY_BITS) - 1)) \
                << BLK_MQ_F_ALLOC_POLICY_START_BIT)
 
+#define BLK_MQ_NO_HCTX_IDX     (-1U)
+
 struct gendisk *__blk_mq_alloc_disk(struct blk_mq_tag_set *set, void *queuedata,
                struct lock_class_key *lkclass);
 #define blk_mq_alloc_disk(set, queuedata)                              \
 
 
        atomic_t                nr_active_requests_shared_sbitmap;
 
-       struct sbitmap_queue    sched_bitmap_tags;
-       struct sbitmap_queue    sched_breserved_tags;
+       struct blk_mq_tags      *shared_sbitmap_tags;
 
        struct list_head        icq_list;
 #ifdef CONFIG_BLK_CGROUP