const u8 *old, const u8 *new, int bytes);
 int kvm_mmu_unprotect_page_virt(struct kvm_vcpu *vcpu, gva_t gva);
 void kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu);
+int kvm_mmu_load(struct kvm_vcpu *vcpu);
+void kvm_mmu_unload(struct kvm_vcpu *vcpu);
 
 int kvm_hypercall(struct kvm_vcpu *vcpu, struct kvm_run *run);
 
        return vcpu->mmu.page_fault(vcpu, gva, error_code);
 }
 
+static inline int kvm_mmu_reload(struct kvm_vcpu *vcpu)
+{
+       if (likely(vcpu->mmu.root_hpa != INVALID_PAGE))
+               return 0;
+
+       return kvm_mmu_load(vcpu);
+}
+
 static inline int is_long_mode(struct kvm_vcpu *vcpu)
 {
 #ifdef CONFIG_X86_64
 
        context->free = nonpaging_free;
        context->root_level = 0;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
-       mmu_alloc_roots(vcpu);
-       ASSERT(VALID_PAGE(context->root_hpa));
-       kvm_arch_ops->set_cr3(vcpu, context->root_hpa);
+       context->root_hpa = INVALID_PAGE;
        return 0;
 }
 
 {
        pgprintk("%s: cr3 %lx\n", __FUNCTION__, vcpu->cr3);
        mmu_free_roots(vcpu);
-       if (unlikely(vcpu->kvm->n_free_mmu_pages < KVM_MIN_FREE_MMU_PAGES))
-               kvm_mmu_free_some_pages(vcpu);
-       mmu_alloc_roots(vcpu);
-       kvm_mmu_flush_tlb(vcpu);
-       kvm_arch_ops->set_cr3(vcpu, vcpu->mmu.root_hpa);
 }
 
 static void inject_page_fault(struct kvm_vcpu *vcpu,
        context->free = paging_free;
        context->root_level = level;
        context->shadow_root_level = level;
-       mmu_alloc_roots(vcpu);
-       ASSERT(VALID_PAGE(context->root_hpa));
-       kvm_arch_ops->set_cr3(vcpu, context->root_hpa |
-                   (vcpu->cr3 & (CR3_PCD_MASK | CR3_WPT_MASK)));
+       context->root_hpa = INVALID_PAGE;
        return 0;
 }
 
        context->free = paging_free;
        context->root_level = PT32_ROOT_LEVEL;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
-       mmu_alloc_roots(vcpu);
-       ASSERT(VALID_PAGE(context->root_hpa));
-       kvm_arch_ops->set_cr3(vcpu, context->root_hpa |
-                   (vcpu->cr3 & (CR3_PCD_MASK | CR3_WPT_MASK)));
+       context->root_hpa = INVALID_PAGE;
        return 0;
 }
 
        ASSERT(vcpu);
        ASSERT(!VALID_PAGE(vcpu->mmu.root_hpa));
 
-       mmu_topup_memory_caches(vcpu);
        if (!is_paging(vcpu))
                return nonpaging_init_context(vcpu);
        else if (is_long_mode(vcpu))
 }
 
 int kvm_mmu_reset_context(struct kvm_vcpu *vcpu)
+{
+       destroy_kvm_mmu(vcpu);
+       return init_kvm_mmu(vcpu);
+}
+
+int kvm_mmu_load(struct kvm_vcpu *vcpu)
 {
        int r;
 
-       destroy_kvm_mmu(vcpu);
-       r = init_kvm_mmu(vcpu);
-       if (r < 0)
-               goto out;
+       spin_lock(&vcpu->kvm->lock);
        r = mmu_topup_memory_caches(vcpu);
+       if (r)
+               goto out;
+       mmu_alloc_roots(vcpu);
+       kvm_arch_ops->set_cr3(vcpu, vcpu->mmu.root_hpa);
+       kvm_mmu_flush_tlb(vcpu);
 out:
+       spin_unlock(&vcpu->kvm->lock);
        return r;
 }
+EXPORT_SYMBOL_GPL(kvm_mmu_load);
+
+void kvm_mmu_unload(struct kvm_vcpu *vcpu)
+{
+       mmu_free_roots(vcpu);
+}
 
 static void mmu_pte_write_zap_pte(struct kvm_vcpu *vcpu,
                                  struct kvm_mmu_page *page,