mmu_free_roots(vcpu);
 }
 
-static int nonpaging_init_context(struct kvm_vcpu *vcpu)
+static int nonpaging_init_context(struct kvm_vcpu *vcpu,
+                                 struct kvm_mmu *context)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
-
        context->new_cr3 = nonpaging_new_cr3;
        context->page_fault = nonpaging_page_fault;
        context->gva_to_gpa = nonpaging_gva_to_gpa;
 #include "paging_tmpl.h"
 #undef PTTYPE
 
-static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu, int level)
+static void reset_rsvds_bits_mask(struct kvm_vcpu *vcpu,
+                                 struct kvm_mmu *context,
+                                 int level)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
        int maxphyaddr = cpuid_maxphyaddr(vcpu);
        u64 exb_bit_rsvd = 0;
 
        }
 }
 
-static int paging64_init_context_common(struct kvm_vcpu *vcpu, int level)
+static int paging64_init_context_common(struct kvm_vcpu *vcpu,
+                                       struct kvm_mmu *context,
+                                       int level)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       reset_rsvds_bits_mask(vcpu, context, level);
 
        ASSERT(is_pae(vcpu));
        context->new_cr3 = paging_new_cr3;
        return 0;
 }
 
-static int paging64_init_context(struct kvm_vcpu *vcpu)
+static int paging64_init_context(struct kvm_vcpu *vcpu,
+                                struct kvm_mmu *context)
 {
-       reset_rsvds_bits_mask(vcpu, PT64_ROOT_LEVEL);
-       return paging64_init_context_common(vcpu, PT64_ROOT_LEVEL);
+       return paging64_init_context_common(vcpu, context, PT64_ROOT_LEVEL);
 }
 
-static int paging32_init_context(struct kvm_vcpu *vcpu)
+static int paging32_init_context(struct kvm_vcpu *vcpu,
+                                struct kvm_mmu *context)
 {
-       struct kvm_mmu *context = &vcpu->arch.mmu;
+       reset_rsvds_bits_mask(vcpu, context, PT32_ROOT_LEVEL);
 
-       reset_rsvds_bits_mask(vcpu, PT32_ROOT_LEVEL);
        context->new_cr3 = paging_new_cr3;
        context->page_fault = paging32_page_fault;
        context->gva_to_gpa = paging32_gva_to_gpa;
        return 0;
 }
 
-static int paging32E_init_context(struct kvm_vcpu *vcpu)
+static int paging32E_init_context(struct kvm_vcpu *vcpu,
+                                 struct kvm_mmu *context)
 {
-       reset_rsvds_bits_mask(vcpu, PT32E_ROOT_LEVEL);
-       return paging64_init_context_common(vcpu, PT32E_ROOT_LEVEL);
+       return paging64_init_context_common(vcpu, context, PT32E_ROOT_LEVEL);
 }
 
 static int init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
                context->gva_to_gpa = nonpaging_gva_to_gpa;
                context->root_level = 0;
        } else if (is_long_mode(vcpu)) {
-               reset_rsvds_bits_mask(vcpu, PT64_ROOT_LEVEL);
+               reset_rsvds_bits_mask(vcpu, context, PT64_ROOT_LEVEL);
                context->gva_to_gpa = paging64_gva_to_gpa;
                context->root_level = PT64_ROOT_LEVEL;
        } else if (is_pae(vcpu)) {
-               reset_rsvds_bits_mask(vcpu, PT32E_ROOT_LEVEL);
+               reset_rsvds_bits_mask(vcpu, context, PT32E_ROOT_LEVEL);
                context->gva_to_gpa = paging64_gva_to_gpa;
                context->root_level = PT32E_ROOT_LEVEL;
        } else {
-               reset_rsvds_bits_mask(vcpu, PT32_ROOT_LEVEL);
+               reset_rsvds_bits_mask(vcpu, context, PT32_ROOT_LEVEL);
                context->gva_to_gpa = paging32_gva_to_gpa;
                context->root_level = PT32_ROOT_LEVEL;
        }
        return 0;
 }
 
-static int init_kvm_softmmu(struct kvm_vcpu *vcpu)
+int kvm_init_shadow_mmu(struct kvm_vcpu *vcpu, struct kvm_mmu *context)
 {
        int r;
-
        ASSERT(vcpu);
        ASSERT(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
 
        if (!is_paging(vcpu))
-               r = nonpaging_init_context(vcpu);
+               r = nonpaging_init_context(vcpu, context);
        else if (is_long_mode(vcpu))
-               r = paging64_init_context(vcpu);
+               r = paging64_init_context(vcpu, context);
        else if (is_pae(vcpu))
-               r = paging32E_init_context(vcpu);
+               r = paging32E_init_context(vcpu, context);
        else
-               r = paging32_init_context(vcpu);
+               r = paging32_init_context(vcpu, context);
 
        vcpu->arch.mmu.base_role.cr4_pae = !!is_pae(vcpu);
        vcpu->arch.mmu.base_role.cr0_wp  = is_write_protection(vcpu);
+
+       return r;
+}
+EXPORT_SYMBOL_GPL(kvm_init_shadow_mmu);
+
+static int init_kvm_softmmu(struct kvm_vcpu *vcpu)
+{
+       int r = kvm_init_shadow_mmu(vcpu, &vcpu->arch.mmu);
+
        vcpu->arch.mmu.set_cr3           = kvm_x86_ops->set_cr3;
        vcpu->arch.mmu.get_cr3           = get_cr3;
        vcpu->arch.mmu.inject_page_fault = kvm_inject_page_fault;