        context->update_pte = nonpaging_update_pte;
        context->root_level = 0;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
-       context->root_hpa = INVALID_PAGE;
-       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = true;
        context->nx = false;
 }
        context->invlpg = paging64_invlpg;
        context->update_pte = paging64_update_pte;
        context->shadow_root_level = level;
-       context->root_hpa = INVALID_PAGE;
-       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = false;
 }
 
        context->invlpg = paging32_invlpg;
        context->update_pte = paging32_update_pte;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
-       context->root_hpa = INVALID_PAGE;
-       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = false;
 }
 
        context->invlpg = nonpaging_invlpg;
        context->update_pte = nonpaging_update_pte;
        context->shadow_root_level = kvm_x86_ops->get_tdp_level(vcpu);
-       context->root_hpa = INVALID_PAGE;
-       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = true;
        context->set_cr3 = kvm_x86_ops->set_tdp_cr3;
        context->get_cr3 = get_cr3;
 {
        struct kvm_mmu *context = &vcpu->arch.mmu;
 
-       MMU_WARN_ON(VALID_PAGE(context->root_hpa));
-
        if (!is_paging(vcpu))
                nonpaging_init_context(vcpu, context);
        else if (is_long_mode(vcpu))
        union kvm_mmu_page_role root_page_role =
                kvm_calc_shadow_ept_root_page_role(vcpu, accessed_dirty);
 
-       MMU_WARN_ON(VALID_PAGE(context->root_hpa));
-
        context->shadow_root_level = PT64_ROOT_4LEVEL;
 
        context->nx = true;
        context->invlpg = ept_invlpg;
        context->update_pte = ept_update_pte;
        context->root_level = PT64_ROOT_4LEVEL;
-       context->root_hpa = INVALID_PAGE;
-       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = false;
        context->base_role.word = root_page_role.word & mmu_base_role_mask.word;
        update_permission_bitmask(vcpu, context, true);
        update_last_nonleaf_level(vcpu, g_context);
 }
 
-static void init_kvm_mmu(struct kvm_vcpu *vcpu)
+void kvm_init_mmu(struct kvm_vcpu *vcpu, bool reset_roots)
 {
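+       /*
+        * When reset_roots is set, invalidate both the current root and
+        * the cached previous root so that a fresh root is allocated the
+        * next time the MMU is loaded.
+        */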
+       if (reset_roots) {
+               vcpu->arch.mmu.root_hpa = INVALID_PAGE;
+               vcpu->arch.mmu.prev_root = KVM_MMU_ROOT_INFO_INVALID;
+       }
+
        if (mmu_is_nested(vcpu))
                init_kvm_nested_mmu(vcpu);
        else if (tdp_enabled)
                init_kvm_tdp_mmu(vcpu);
        else
                init_kvm_softmmu(vcpu);
 }
+EXPORT_SYMBOL_GPL(kvm_init_mmu);
 
 static union kvm_mmu_page_role
 kvm_mmu_calc_root_page_role(struct kvm_vcpu *vcpu)
 void kvm_mmu_reset_context(struct kvm_vcpu *vcpu)
 {
        kvm_mmu_unload(vcpu);
-       init_kvm_mmu(vcpu);
+       kvm_init_mmu(vcpu, true);
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_reset_context);
 
 {
        MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu.root_hpa));
 
-       init_kvm_mmu(vcpu);
+       kvm_init_mmu(vcpu, true);
 }
 
 static void kvm_mmu_invalidate_zap_pages_in_memslot(struct kvm *kvm,