vmcb12 = map.hva;
 
+       if (WARN_ON_ONCE(!svm->nested.initialized))
+               return -EINVAL;
+
        if (!nested_vmcb_checks(svm, vmcb12)) {
                vmcb12->control.exit_code    = SVM_EXIT_ERR;
                vmcb12->control.exit_code_hi = 0;
        return 0;
 }
 
+int svm_allocate_nested(struct vcpu_svm *svm)
+{
+       struct page *hsave_page;
+
+       if (svm->nested.initialized)
+               return 0;
+
+       hsave_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+       if (!hsave_page)
+               return -ENOMEM;
+       svm->nested.hsave = page_address(hsave_page);
+
+       svm->nested.msrpm = svm_vcpu_alloc_msrpm();
+       if (!svm->nested.msrpm)
+               goto err_free_hsave;
+       svm_vcpu_init_msrpm(&svm->vcpu, svm->nested.msrpm);
+
+       svm->nested.initialized = true;
+       return 0;
+
+err_free_hsave:
+       __free_page(hsave_page);
+       return -ENOMEM;
+}
+
+void svm_free_nested(struct vcpu_svm *svm)
+{
+       if (!svm->nested.initialized)
+               return;
+
+       svm_vcpu_free_msrpm(svm->nested.msrpm);
+       svm->nested.msrpm = NULL;
+
+       __free_page(virt_to_page(svm->nested.hsave));
+       svm->nested.hsave = NULL;
+
+       svm->nested.initialized = false;
+}
+
 /*
  * Forcibly leave nested mode in order to be able to reset the VCPU later on.
  */
 
 int svm_set_efer(struct kvm_vcpu *vcpu, u64 efer)
 {
        struct vcpu_svm *svm = to_svm(vcpu);
+       u64 old_efer = vcpu->arch.efer;
        vcpu->arch.efer = efer;
 
        if (!npt_enabled) {
                        efer &= ~EFER_LME;
        }
 
-       if (!(efer & EFER_SVME)) {
-               svm_leave_nested(svm);
-               svm_set_gif(svm, true);
+       if ((old_efer & EFER_SVME) != (efer & EFER_SVME)) {
+               if (!(efer & EFER_SVME)) {
+                       svm_leave_nested(svm);
+                       svm_set_gif(svm, true);
+
+                       /*
+                        * Free the nested guest state, unless we are in SMM.
+                        * In this case we will return to the nested guest
+                        * as soon as we leave SMM.
+                        */
+                       if (!is_smm(&svm->vcpu))
+                               svm_free_nested(svm);
+
+               } else {
+                       int ret = svm_allocate_nested(svm);
+
+                       if (ret) {
+                               vcpu->arch.efer = old_efer;
+                               return ret;
+                       }
+               }
        }
 
        svm->vmcb->save.efer = efer | EFER_SVME;
        set_msr_interception_bitmap(vcpu, msrpm, msr, read, write);
 }
 
-static u32 *svm_vcpu_alloc_msrpm(void)
+u32 *svm_vcpu_alloc_msrpm(void)
 {
        struct page *pages = alloc_pages(GFP_KERNEL_ACCOUNT, MSRPM_ALLOC_ORDER);
        u32 *msrpm;
        return msrpm;
 }
 
-static void svm_vcpu_init_msrpm(struct kvm_vcpu *vcpu, u32 *msrpm)
+void svm_vcpu_init_msrpm(struct kvm_vcpu *vcpu, u32 *msrpm)
 {
        int i;
 
        }
 }
 
-static void svm_vcpu_free_msrpm(u32 *msrpm)
+
+void svm_vcpu_free_msrpm(u32 *msrpm)
 {
        __free_pages(virt_to_page(msrpm), MSRPM_ALLOC_ORDER);
 }
 {
        struct vcpu_svm *svm;
        struct page *vmcb_page;
-       struct page *hsave_page;
        int err;
 
        BUILD_BUG_ON(offsetof(struct vcpu_svm, vcpu) != 0);
        if (!vmcb_page)
                goto out;
 
-       hsave_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
-       if (!hsave_page)
-               goto error_free_vmcb_page;
-
        err = avic_init_vcpu(svm);
        if (err)
-               goto error_free_hsave_page;
+               goto error_free_vmcb_page;
 
        /* We initialize this flag to true to make sure that the is_running
         * bit would be set the first time the vcpu is loaded.
        if (irqchip_in_kernel(vcpu->kvm) && kvm_apicv_activated(vcpu->kvm))
                svm->avic_is_running = true;
 
-       svm->nested.hsave = page_address(hsave_page);
-
        svm->msrpm = svm_vcpu_alloc_msrpm();
        if (!svm->msrpm)
-               goto error_free_hsave_page;
+               goto error_free_vmcb_page;
 
        svm_vcpu_init_msrpm(vcpu, svm->msrpm);
 
-       svm->nested.msrpm = svm_vcpu_alloc_msrpm();
-       if (!svm->nested.msrpm)
-               goto error_free_msrpm;
-
-       /* We only need the L1 pass-through MSR state, so leave vcpu as NULL */
-       svm_vcpu_init_msrpm(vcpu, svm->nested.msrpm);
-
        svm->vmcb = page_address(vmcb_page);
        svm->vmcb_pa = __sme_set(page_to_pfn(vmcb_page) << PAGE_SHIFT);
        svm->asid_generation = 0;
 
        return 0;
 
-error_free_msrpm:
-       svm_vcpu_free_msrpm(svm->msrpm);
-error_free_hsave_page:
-       __free_page(hsave_page);
 error_free_vmcb_page:
        __free_page(vmcb_page);
 out:
         */
        svm_clear_current_vmcb(svm->vmcb);
 
+       svm_free_nested(svm);
+
        __free_page(pfn_to_page(__sme_clr(svm->vmcb_pa) >> PAGE_SHIFT));
        __free_pages(virt_to_page(svm->msrpm), MSRPM_ALLOC_ORDER);
-       __free_page(virt_to_page(svm->nested.hsave));
-       __free_pages(virt_to_page(svm->nested.msrpm), MSRPM_ALLOC_ORDER);
 }
 
 static void svm_vcpu_load(struct kvm_vcpu *vcpu, int cpu)
                                         gpa_to_gfn(vmcb12_gpa), &map) == -EINVAL)
                                return 1;
 
+                       if (svm_allocate_nested(svm))
+                               return 1;
+
                        ret = enter_svm_guest_mode(svm, vmcb12_gpa, map.hva);
                        kvm_vcpu_unmap(&svm->vcpu, &map, true);
                }
 
 
        /* cache for control fields of the guest */
        struct vmcb_control_area ctl;
+
+       bool initialized;
 };
 
 struct vcpu_svm {
 #define MSR_INVALID                            0xffffffffU
 
 u32 svm_msrpm_offset(u32 msr);
+u32 *svm_vcpu_alloc_msrpm(void);
+void svm_vcpu_init_msrpm(struct kvm_vcpu *vcpu, u32 *msrpm);
+void svm_vcpu_free_msrpm(u32 *msrpm);
+
 int svm_set_efer(struct kvm_vcpu *vcpu, u64 efer);
 void svm_set_cr0(struct kvm_vcpu *vcpu, unsigned long cr0);
 int svm_set_cr4(struct kvm_vcpu *vcpu, unsigned long cr4);
 int enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb_gpa,
                         struct vmcb *nested_vmcb);
 void svm_leave_nested(struct vcpu_svm *svm);
+void svm_free_nested(struct vcpu_svm *svm);
+int svm_allocate_nested(struct vcpu_svm *svm);
 int nested_svm_vmrun(struct vcpu_svm *svm);
 void nested_svm_vmloadsave(struct vmcb *from_vmcb, struct vmcb *to_vmcb);
 int nested_svm_vmexit(struct vcpu_svm *svm);