OOM_SCAN_SELECT,        /* always select this thread first */
 };
 
-extern void compare_swap_oom_score_adj(short old_val, short new_val);
-extern short test_set_oom_score_adj(short new_val);
+/* Thread is the potential origin of an oom condition; kill first on oom */
+#define OOM_FLAG_ORIGIN                ((__force oom_flags_t)0x1)
+
+static inline void set_current_oom_origin(void)
+{
+       current->signal->oom_flags |= OOM_FLAG_ORIGIN;
+}
+
+static inline void clear_current_oom_origin(void)
+{
+       current->signal->oom_flags &= ~OOM_FLAG_ORIGIN;
+}
+
+static inline bool oom_task_origin(const struct task_struct *p)
+{
+       return !!(p->signal->oom_flags & OOM_FLAG_ORIGIN);
+}
 
 extern unsigned long oom_badness(struct task_struct *p,
                struct mem_cgroup *memcg, const nodemask_t *nodemask,
 
        if (ksm_run != flags) {
                ksm_run = flags;
                if (flags & KSM_RUN_UNMERGE) {
-                       short oom_score_adj;
-
-                       oom_score_adj = test_set_oom_score_adj(OOM_SCORE_ADJ_MAX);
+                       set_current_oom_origin();
                        err = unmerge_and_remove_all_rmap_items();
-                       compare_swap_oom_score_adj(OOM_SCORE_ADJ_MAX,
-                                                               oom_score_adj);
+                       clear_current_oom_origin();
                        if (err) {
                                ksm_run = KSM_RUN_STOP;
                                count = err;
 
 int sysctl_oom_dump_tasks = 1;
 static DEFINE_SPINLOCK(zone_scan_lock);
 
-/*
- * compare_swap_oom_score_adj() - compare and swap current's oom_score_adj
- * @old_val: old oom_score_adj for compare
- * @new_val: new oom_score_adj for swap
- *
- * Sets the oom_score_adj value for current to @new_val iff its present value is
- * @old_val.  Usually used to reinstate a previous value to prevent racing with
- * userspacing tuning the value in the interim.
- */
-void compare_swap_oom_score_adj(short old_val, short new_val)
-{
-       struct sighand_struct *sighand = current->sighand;
-
-       spin_lock_irq(&sighand->siglock);
-       if (current->signal->oom_score_adj == old_val)
-               current->signal->oom_score_adj = new_val;
-       trace_oom_score_adj_update(current);
-       spin_unlock_irq(&sighand->siglock);
-}
-
-/**
- * test_set_oom_score_adj() - set current's oom_score_adj and return old value
- * @new_val: new oom_score_adj value
- *
- * Sets the oom_score_adj value for current to @new_val with proper
- * synchronization and returns the old value.  Usually used to temporarily
- * set a value, save the old value in the caller, and then reinstate it later.
- */
-short test_set_oom_score_adj(short new_val)
-{
-       struct sighand_struct *sighand = current->sighand;
-       int old_val;
-
-       spin_lock_irq(&sighand->siglock);
-       old_val = current->signal->oom_score_adj;
-       current->signal->oom_score_adj = new_val;
-       trace_oom_score_adj_update(current);
-       spin_unlock_irq(&sighand->siglock);
-
-       return old_val;
-}
-
 #ifdef CONFIG_NUMA
 /**
  * has_intersects_mems_allowed() - check task eligiblity for kill
        if (!task->mm)
                return OOM_SCAN_CONTINUE;
 
+       /*
+        * If task is allocating a lot of memory and has been marked to be
+        * killed first if it triggers an oom, then select it.
+        */
+       if (oom_task_origin(task))
+               return OOM_SCAN_SELECT;
+
        if (task->flags & PF_EXITING && !force_kill) {
                /*
                 * If this task is not being ptraced on exit, then wait for it
 
        struct address_space *mapping;
        struct inode *inode;
        struct filename *pathname;
-       short oom_score_adj;
        int i, type, prev;
        int err;
 
        p->flags &= ~SWP_WRITEOK;
        spin_unlock(&swap_lock);
 
-       oom_score_adj = test_set_oom_score_adj(OOM_SCORE_ADJ_MAX);
+       set_current_oom_origin();
        err = try_to_unuse(type, false, 0); /* force all pages to be unused */
-       compare_swap_oom_score_adj(OOM_SCORE_ADJ_MAX, oom_score_adj);
+       clear_current_oom_origin();
 
        if (err) {
                /* re-insert swap space back into swap_list */