/* If the task is part of a group prevent parallel updates to group stats */
        if (p->numa_group) {
                group_lock = &p->numa_group->lock;
-               spin_lock(group_lock);
+               spin_lock_irq(group_lock);
        }
 
        /* Find the node with the highest number of faults */
                        }
                }
 
-               spin_unlock(group_lock);
+               spin_unlock_irq(group_lock);
        }
 
        /* Preferred node as the node with the most faults */
        if (!join)
                return;
 
-       double_lock(&my_grp->lock, &grp->lock);
+       BUG_ON(irqs_disabled());
+       double_lock_irq(&my_grp->lock, &grp->lock);
 
        for (i = 0; i < NR_NUMA_HINT_FAULT_STATS * nr_node_ids; i++) {
                my_grp->faults[i] -= p->numa_faults_memory[i];
        grp->nr_tasks++;
 
        spin_unlock(&my_grp->lock);
-       spin_unlock(&grp->lock);
+       spin_unlock_irq(&grp->lock);
 
        rcu_assign_pointer(p->numa_group, grp);
 
        void *numa_faults = p->numa_faults_memory;
 
        if (grp) {
-               spin_lock(&grp->lock);
+               spin_lock_irq(&grp->lock);
                for (i = 0; i < NR_NUMA_HINT_FAULT_STATS * nr_node_ids; i++)
                        grp->faults[i] -= p->numa_faults_memory[i];
                grp->total_faults -= p->total_numa_faults;
 
                list_del(&p->numa_entry);
                grp->nr_tasks--;
-               spin_unlock(&grp->lock);
+               spin_unlock_irq(&grp->lock);
                rcu_assign_pointer(p->numa_group, NULL);
                put_numa_group(grp);
        }
 
        spin_lock_nested(l2, SINGLE_DEPTH_NESTING);
 }
 
+static inline void double_lock_irq(spinlock_t *l1, spinlock_t *l2)
+{
+       if (l1 > l2)
+               swap(l1, l2);
+
+       spin_lock_irq(l1);
+       spin_lock_nested(l2, SINGLE_DEPTH_NESTING);
+}
+
 static inline void double_raw_lock(raw_spinlock_t *l1, raw_spinlock_t *l2)
 {
        if (l1 > l2)