*/
 static struct kmem_cache *iocontext_cachep;
 
+/**
+ * get_io_context - increment reference count to io_context
+ * @ioc: io_context to get
+ *
+ * Increment reference count to @ioc.
+ */
+void get_io_context(struct io_context *ioc)
+{
+       BUG_ON(atomic_long_read(&ioc->refcount) <= 0);
+       atomic_long_inc(&ioc->refcount);
+}
+EXPORT_SYMBOL(get_io_context);
+
 static void cfq_dtor(struct io_context *ioc)
 {
        if (!hlist_empty(&ioc->cic_list)) {
 {
        struct io_context *ioc;
 
+       /* PF_EXITING prevents new io_context from being attached to @task */
+       WARN_ON_ONCE(!(current->flags & PF_EXITING));
+
        task_lock(task);
        ioc = task->io_context;
        task->io_context = NULL;
        put_io_context(ioc);
 }
 
-struct io_context *alloc_io_context(gfp_t gfp_flags, int node)
+static struct io_context *create_task_io_context(struct task_struct *task,
+                                                gfp_t gfp_flags, int node,
+                                                bool take_ref)
 {
        struct io_context *ioc;
 
        INIT_RADIX_TREE(&ioc->radix_root, GFP_ATOMIC | __GFP_HIGH);
        INIT_HLIST_HEAD(&ioc->cic_list);
 
+       /* try to install, somebody might already have beaten us to it */
+       task_lock(task);
+
+       if (!task->io_context && !(task->flags & PF_EXITING)) {
+               task->io_context = ioc;
+       } else {
+               kmem_cache_free(iocontext_cachep, ioc);
+               ioc = task->io_context;
+       }
+
+       if (ioc && take_ref)
+               get_io_context(ioc);
+
+       task_unlock(task);
        return ioc;
 }
 
  */
 struct io_context *current_io_context(gfp_t gfp_flags, int node)
 {
-       struct task_struct *tsk = current;
-       struct io_context *ret;
-
-       ret = tsk->io_context;
-       if (likely(ret))
-               return ret;
-
-       ret = alloc_io_context(gfp_flags, node);
-       if (ret) {
-               /* make sure set_task_ioprio() sees the settings above */
-               smp_wmb();
-               tsk->io_context = ret;
-       }
+       might_sleep_if(gfp_flags & __GFP_WAIT);
 
-       return ret;
+       if (current->io_context)
+               return current->io_context;
+
+       return create_task_io_context(current, gfp_flags, node, false);
 }
+EXPORT_SYMBOL(current_io_context);
 
-/*
- * If the current task has no IO context then create one and initialise it.
- * If it does have a context, take a ref on it.
+/**
+ * get_task_io_context - get io_context of a task
+ * @task: task of interest
+ * @gfp_flags: allocation flags, used if allocation is necessary
+ * @node: allocation node, used if allocation is necessary
+ *
+ * Return io_context of @task.  If it doesn't exist, it is created with
+ * @gfp_flags and @node.  The returned io_context has its reference count
+ * incremented.
  *
- * This is always called in the context of the task which submitted the I/O.
+ * This function always goes through task_lock() and it's better to use
+ * current_io_context() + get_io_context() for %current.
  */
-struct io_context *get_io_context(gfp_t gfp_flags, int node)
+struct io_context *get_task_io_context(struct task_struct *task,
+                                      gfp_t gfp_flags, int node)
 {
-       struct io_context *ioc = NULL;
-
-       /*
-        * Check for unlikely race with exiting task. ioc ref count is
-        * zero when ioc is being detached.
-        */
-       do {
-               ioc = current_io_context(gfp_flags, node);
-               if (unlikely(!ioc))
-                       break;
-       } while (!atomic_long_inc_not_zero(&ioc->refcount));
+       struct io_context *ioc;
 
-       return ioc;
+       might_sleep_if(gfp_flags & __GFP_WAIT);
+
+       task_lock(task);
+       ioc = task->io_context;
+       if (likely(ioc)) {
+               get_io_context(ioc);
+               task_unlock(task);
+               return ioc;
+       }
+       task_unlock(task);
+
+       return create_task_io_context(task, gfp_flags, node, true);
 }
-EXPORT_SYMBOL(get_io_context);
+EXPORT_SYMBOL(get_task_io_context);
 
 static int __init blk_ioc_init(void)
 {
 
 #include <linux/rbtree.h>
 #include <linux/ioprio.h>
 #include <linux/blktrace_api.h>
+#include "blk.h"
 #include "cfq.h"
 
 /*
 cfq_get_io_context(struct cfq_data *cfqd, gfp_t gfp_mask)
 {
        struct io_context *ioc = NULL;
-       struct cfq_io_context *cic;
+       struct cfq_io_context *cic = NULL;
 
        might_sleep_if(gfp_mask & __GFP_WAIT);
 
-       ioc = get_io_context(gfp_mask, cfqd->queue->node);
+       ioc = current_io_context(gfp_mask, cfqd->queue->node);
        if (!ioc)
-               return NULL;
+               goto err;
 
        cic = cfq_cic_lookup(cfqd, ioc);
        if (cic)
                goto err;
 
        if (cfq_cic_link(cfqd, ioc, cic, gfp_mask))
-               goto err_free;
-
+               goto err;
 out:
-       smp_read_barrier_depends();
+       get_io_context(ioc);
+
        if (unlikely(ioc->ioprio_changed))
                cfq_ioc_set_ioprio(ioc);
 
                cfq_ioc_set_cgroup(ioc);
 #endif
        return cic;
-err_free:
-       cfq_cic_free(cic);
 err:
-       put_io_context(ioc);
+       if (cic)
+               cfq_cic_free(cic);
        return NULL;
 }