#include <asm/cpu.h>
 #include <asm/traps.h>
 #include <asm/desc.h>
+#include <asm/tlbflush.h>
 
 #define MMU_QUEUE_SIZE 1024
 
        wait_queue_head_t wq;
        u32 token;
        int cpu;
+       bool halted;
+       struct mm_struct *mm;
 };
 
 static struct kvm_task_sleep_head {
        struct kvm_task_sleep_head *b = &async_pf_sleepers[key];
        struct kvm_task_sleep_node n, *e;
        DEFINE_WAIT(wait);
+       int cpu, idle;
+
+       cpu = get_cpu();
+       idle = idle_cpu(cpu);
+       put_cpu();
 
        spin_lock(&b->lock);
        e = _find_apf_task(b, token);
 
        n.token = token;
        n.cpu = smp_processor_id();
+       n.mm = current->active_mm;
+       n.halted = idle || preempt_count() > 1;
+       atomic_inc(&n.mm->mm_count);
        init_waitqueue_head(&n.wq);
        hlist_add_head(&n.link, &b->list);
        spin_unlock(&b->lock);
 
        for (;;) {
-               prepare_to_wait(&n.wq, &wait, TASK_UNINTERRUPTIBLE);
+               if (!n.halted)
+                       prepare_to_wait(&n.wq, &wait, TASK_UNINTERRUPTIBLE);
                if (hlist_unhashed(&n.link))
                        break;
-               local_irq_enable();
-               schedule();
-               local_irq_disable();
+
+               if (!n.halted) {
+                       local_irq_enable();
+                       schedule();
+                       local_irq_disable();
+               } else {
+                       /*
+                        * We cannot reschedule. So halt.
+                        */
+                       native_safe_halt();
+                       local_irq_disable();
+               }
        }
-       finish_wait(&n.wq, &wait);
+       if (!n.halted)
+               finish_wait(&n.wq, &wait);
 
        return;
 }
 static void apf_task_wake_one(struct kvm_task_sleep_node *n)
 {
        hlist_del_init(&n->link);
-       if (waitqueue_active(&n->wq))
+       if (!n->mm)
+               return;
+       mmdrop(n->mm);
+       if (n->halted)
+               smp_send_reschedule(n->cpu);
+       else if (waitqueue_active(&n->wq))
                wake_up(&n->wq);
 }
 
                }
                n->token = token;
                n->cpu = smp_processor_id();
+               n->mm = NULL;
                init_waitqueue_head(&n->wq);
                hlist_add_head(&n->link, &b->list);
        } else