extern void mem_cgroup_replace_page_cache(struct page *oldpage,
                                        struct page *newpage);
 
-/**
- * mem_cgroup_toggle_oom - toggle the memcg OOM killer for the current task
- * @new: true to enable, false to disable
- *
- * Toggle whether a failed memcg charge should invoke the OOM killer
- * or just return -ENOMEM.  Returns the previous toggle state.
- *
- * NOTE: Any path that enables the OOM killer before charging must
- *       call mem_cgroup_oom_synchronize() afterward to finalize the
- *       OOM handling and clean up.
- */
-static inline bool mem_cgroup_toggle_oom(bool new)
+static inline void mem_cgroup_oom_enable(void)
 {
-       bool old;
-
-       old = current->memcg_oom.may_oom;
-       current->memcg_oom.may_oom = new;
-
-       return old;
+       WARN_ON(current->memcg_oom.may_oom);
+       current->memcg_oom.may_oom = 1;
 }
 
-static inline void mem_cgroup_enable_oom(void)
+static inline void mem_cgroup_oom_disable(void)
 {
-       bool old = mem_cgroup_toggle_oom(true);
-
-       WARN_ON(old == true);
-}
-
-static inline void mem_cgroup_disable_oom(void)
-{
-       bool old = mem_cgroup_toggle_oom(false);
-
-       WARN_ON(old == false);
+       WARN_ON(!current->memcg_oom.may_oom);
+       current->memcg_oom.may_oom = 0;
 }
 
 static inline bool task_in_memcg_oom(struct task_struct *p)
 {
-       return p->memcg_oom.in_memcg_oom;
+       return p->memcg_oom.memcg;
 }
 
-bool mem_cgroup_oom_synchronize(void);
+bool mem_cgroup_oom_synchronize(bool wait);
 
 #ifdef CONFIG_MEMCG_SWAP
 extern int do_swap_account;
 {
 }
 
-static inline bool mem_cgroup_toggle_oom(bool new)
-{
-       return false;
-}
-
-static inline void mem_cgroup_enable_oom(void)
+static inline void mem_cgroup_oom_enable(void)
 {
 }
 
-static inline void mem_cgroup_disable_oom(void)
+static inline void mem_cgroup_oom_disable(void)
 {
 }
 
        return false;
 }
 
-static inline bool mem_cgroup_oom_synchronize(void)
+static inline bool mem_cgroup_oom_synchronize(bool wait)
 {
        return false;
 }
 
        struct inode *inode = mapping->host;
        pgoff_t offset = vmf->pgoff;
        struct page *page;
-       bool memcg_oom;
        pgoff_t size;
        int ret = 0;
 
                return VM_FAULT_SIGBUS;
 
        /*
-        * Do we have something in the page cache already?  Either
-        * way, try readahead, but disable the memcg OOM killer for it
-        * as readahead is optional and no errors are propagated up
-        * the fault stack.  The OOM killer is enabled while trying to
-        * instantiate the faulting page individually below.
+        * Do we have something in the page cache already?
         */
        page = find_get_page(mapping, offset);
        if (likely(page) && !(vmf->flags & FAULT_FLAG_TRIED)) {
                 * We found the page, so try async readahead before
                 * waiting for the lock.
                 */
-               memcg_oom = mem_cgroup_toggle_oom(false);
                do_async_mmap_readahead(vma, ra, file, page, offset);
-               mem_cgroup_toggle_oom(memcg_oom);
        } else if (!page) {
                /* No page in the page cache at all */
-               memcg_oom = mem_cgroup_toggle_oom(false);
                do_sync_mmap_readahead(vma, ra, file, offset);
-               mem_cgroup_toggle_oom(memcg_oom);
                count_vm_event(PGMAJFAULT);
                mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
                ret = VM_FAULT_MAJOR;
 
                memcg_wakeup_oom(memcg);
 }
 
-/*
- * try to call OOM killer
- */
 static void mem_cgroup_oom(struct mem_cgroup *memcg, gfp_t mask, int order)
 {
-       bool locked;
-       int wakeups;
-
        if (!current->memcg_oom.may_oom)
                return;
-
-       current->memcg_oom.in_memcg_oom = 1;
-
        /*
-        * As with any blocking lock, a contender needs to start
-        * listening for wakeups before attempting the trylock,
-        * otherwise it can miss the wakeup from the unlock and sleep
-        * indefinitely.  This is just open-coded because our locking
-        * is so particular to memcg hierarchies.
+        * We are in the middle of the charge context here, so we
+        * don't want to block when potentially sitting on a callstack
+        * that holds all kinds of filesystem and mm locks.
+        *
+        * Also, the caller may handle a failed allocation gracefully
+        * (like optional page cache readahead) and so an OOM killer
+        * invocation might not even be necessary.
+        *
+        * That's why we don't do anything here except remember the
+        * OOM context and then deal with it at the end of the page
+        * fault when the stack is unwound, the locks are released,
+        * and when we know whether the fault was overall successful.
         */
-       wakeups = atomic_read(&memcg->oom_wakeups);
-       mem_cgroup_mark_under_oom(memcg);
-
-       locked = mem_cgroup_oom_trylock(memcg);
-
-       if (locked)
-               mem_cgroup_oom_notify(memcg);
-
-       if (locked && !memcg->oom_kill_disable) {
-               mem_cgroup_unmark_under_oom(memcg);
-               mem_cgroup_out_of_memory(memcg, mask, order);
-               mem_cgroup_oom_unlock(memcg);
-               /*
-                * There is no guarantee that an OOM-lock contender
-                * sees the wakeups triggered by the OOM kill
-                * uncharges.  Wake any sleepers explicitely.
-                */
-               memcg_oom_recover(memcg);
-       } else {
-               /*
-                * A system call can just return -ENOMEM, but if this
-                * is a page fault and somebody else is handling the
-                * OOM already, we need to sleep on the OOM waitqueue
-                * for this memcg until the situation is resolved.
-                * Which can take some time because it might be
-                * handled by a userspace task.
-                *
-                * However, this is the charge context, which means
-                * that we may sit on a large call stack and hold
-                * various filesystem locks, the mmap_sem etc. and we
-                * don't want the OOM handler to deadlock on them
-                * while we sit here and wait.  Store the current OOM
-                * context in the task_struct, then return -ENOMEM.
-                * At the end of the page fault handler, with the
-                * stack unwound, pagefault_out_of_memory() will check
-                * back with us by calling
-                * mem_cgroup_oom_synchronize(), possibly putting the
-                * task to sleep.
-                */
-               current->memcg_oom.oom_locked = locked;
-               current->memcg_oom.wakeups = wakeups;
-               css_get(&memcg->css);
-               current->memcg_oom.wait_on_memcg = memcg;
-       }
+       css_get(&memcg->css);
+       current->memcg_oom.memcg = memcg;
+       current->memcg_oom.gfp_mask = mask;
+       current->memcg_oom.order = order;
 }
 
 /**
  * mem_cgroup_oom_synchronize - complete memcg OOM handling
+ * @handle: actually kill/wait or just clean up the OOM state
  *
- * This has to be called at the end of a page fault if the the memcg
- * OOM handler was enabled and the fault is returning %VM_FAULT_OOM.
+ * This has to be called at the end of a page fault if the memcg OOM
+ * handler was enabled.
  *
- * Memcg supports userspace OOM handling, so failed allocations must
+ * Memcg supports userspace OOM handling where failed allocations must
  * sleep on a waitqueue until the userspace task resolves the
  * situation.  Sleeping directly in the charge context with all kinds
  * of locks held is not a good idea, instead we remember an OOM state
  * in the task and mem_cgroup_oom_synchronize() has to be called at
- * the end of the page fault to put the task to sleep and clean up the
- * OOM state.
+ * the end of the page fault to complete the OOM handling.
  *
  * Returns %true if an ongoing memcg OOM situation was detected and
- * finalized, %false otherwise.
+ * completed, %false otherwise.
  */
-bool mem_cgroup_oom_synchronize(void)
+bool mem_cgroup_oom_synchronize(bool handle)
 {
+       struct mem_cgroup *memcg = current->memcg_oom.memcg;
        struct oom_wait_info owait;
-       struct mem_cgroup *memcg;
+       bool locked;
 
        /* OOM is global, do not handle */
-       if (!current->memcg_oom.in_memcg_oom)
-               return false;
-
-       /*
-        * We invoked the OOM killer but there is a chance that a kill
-        * did not free up any charges.  Everybody else might already
-        * be sleeping, so restart the fault and keep the rampage
-        * going until some charges are released.
-        */
-       memcg = current->memcg_oom.wait_on_memcg;
        if (!memcg)
-               goto out;
+               return false;
 
-       if (test_thread_flag(TIF_MEMDIE) || fatal_signal_pending(current))
-               goto out_memcg;
+       if (!handle)
+               goto cleanup;
 
        owait.memcg = memcg;
        owait.wait.flags = 0;
        INIT_LIST_HEAD(&owait.wait.task_list);
 
        prepare_to_wait(&memcg_oom_waitq, &owait.wait, TASK_KILLABLE);
-       /* Only sleep if we didn't miss any wakeups since OOM */
-       if (atomic_read(&memcg->oom_wakeups) == current->memcg_oom.wakeups)
+       mem_cgroup_mark_under_oom(memcg);
+
+       locked = mem_cgroup_oom_trylock(memcg);
+
+       if (locked)
+               mem_cgroup_oom_notify(memcg);
+
+       if (locked && !memcg->oom_kill_disable) {
+               mem_cgroup_unmark_under_oom(memcg);
+               finish_wait(&memcg_oom_waitq, &owait.wait);
+               mem_cgroup_out_of_memory(memcg, current->memcg_oom.gfp_mask,
+                                        current->memcg_oom.order);
+       } else {
                schedule();
-       finish_wait(&memcg_oom_waitq, &owait.wait);
-out_memcg:
-       mem_cgroup_unmark_under_oom(memcg);
-       if (current->memcg_oom.oom_locked) {
+               mem_cgroup_unmark_under_oom(memcg);
+               finish_wait(&memcg_oom_waitq, &owait.wait);
+       }
+
+       if (locked) {
                mem_cgroup_oom_unlock(memcg);
                /*
                 * There is no guarantee that an OOM-lock contender
                 */
                memcg_oom_recover(memcg);
        }
+cleanup:
+       current->memcg_oom.memcg = NULL;
        css_put(&memcg->css);
-       current->memcg_oom.wait_on_memcg = NULL;
-out:
-       current->memcg_oom.in_memcg_oom = 0;
        return true;
 }
 
                     || fatal_signal_pending(current)))
                goto bypass;
 
+       if (unlikely(task_in_memcg_oom(current)))
+               goto bypass;
+
        /*
         * We always charge the cgroup the mm_struct belongs to.
         * The mm_struct's mem_cgroup changes on task migration if the