 #include "mmu.h"
 #include "cpuid.h"
 #include "lapic.h"
+#include "hyperv.h"
 
 #include <linux/kvm_host.h>
 #include <linux/module.h>
                bool guest_mode;
        } smm;
 
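+       /*
+        * Enlightened VMCS (eVMCS) state: guest-physical address of the
+        * current eVMCS, the backing struct page and its kernel mapping.
+        */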
+       gpa_t hv_evmcs_vmptr;
+       struct page *hv_evmcs_page;
        struct hv_enlightened_vmcs *hv_evmcs;
 };
 
 static int nested_vmx_failValid(struct kvm_vcpu *vcpu,
                                u32 vm_instruction_error)
 {
+       struct vcpu_vmx *vmx = to_vmx(vcpu);
+
        /*
         * failValid writes the error number to the current VMCS, which
         * can't be done if there isn't a current VMCS.
         */
-       if (to_vmx(vcpu)->nested.current_vmptr == -1ull)
+       if (vmx->nested.current_vmptr == -1ull && !vmx->nested.hv_evmcs)
                return nested_vmx_failInvalid(vcpu);
 
        vmx_set_rflags(vcpu, (vmx_get_rflags(vcpu)
        vmcs_write64(VMCS_LINK_POINTER, -1ull);
 }
 
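+/*
+ * Unmap and release the currently mapped enlightened VMCS page, if any,
+ * and reset the eVMCS tracking fields in struct nested_vmx.
+ */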
+static inline void nested_release_evmcs(struct kvm_vcpu *vcpu)
+{
+       struct vcpu_vmx *vmx = to_vmx(vcpu);
+
+       if (!vmx->nested.hv_evmcs)
+               return;
+
+       kunmap(vmx->nested.hv_evmcs_page);
+       kvm_release_page_dirty(vmx->nested.hv_evmcs_page);
+       vmx->nested.hv_evmcs_vmptr = -1ull;
+       vmx->nested.hv_evmcs_page = NULL;
+       vmx->nested.hv_evmcs = NULL;
+}
+
 static inline void nested_release_vmcs12(struct kvm_vcpu *vcpu)
 {
        struct vcpu_vmx *vmx = to_vmx(vcpu);
 
        kvm_mmu_free_roots(vcpu, &vcpu->arch.guest_mmu, KVM_MMU_ROOTS_ALL);
 
+       nested_release_evmcs(vcpu);
+
        free_loaded_vmcs(&vmx->nested.vmcs02);
 }
 
                return nested_vmx_failValid(vcpu,
                        VMXERR_VMCLEAR_VMXON_POINTER);
 
-       if (vmptr == vmx->nested.current_vmptr)
-               nested_release_vmcs12(vcpu);
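+       /*
+        * When an enlightened VMCS is mapped, VMCLEAR of its pointer only
+        * drops the mapping; otherwise handle a regular VMCS: release it if
+        * it is the current one and clear its launch_state in guest memory.
+        */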
+       if (vmx->nested.hv_evmcs_page) {
+               if (vmptr == vmx->nested.hv_evmcs_vmptr)
+                       nested_release_evmcs(vcpu);
+       } else {
+               if (vmptr == vmx->nested.current_vmptr)
+                       nested_release_vmcs12(vcpu);
 
-       kvm_vcpu_write_guest(vcpu,
-                       vmptr + offsetof(struct vmcs12, launch_state),
-                       &zero, sizeof(zero));
+               kvm_vcpu_write_guest(vcpu,
+                                    vmptr + offsetof(struct vmcs12,
+                                                     launch_state),
+                                    &zero, sizeof(zero));
+       }
 
        return nested_vmx_succeed(vcpu);
 }
        struct vmcs12 *vmcs12 = vmx->nested.cached_vmcs12;
        struct hv_enlightened_vmcs *evmcs = vmx->nested.hv_evmcs;
 
+       vmcs12->hdr.revision_id = evmcs->revision_id;
+
        /* HV_VMX_ENLIGHTENED_CLEAN_FIELD_NONE */
        vmcs12->tpr_threshold = evmcs->tpr_threshold;
        vmcs12->guest_rip = evmcs->guest_rip;
                return nested_vmx_failValid(vcpu,
                        VMXERR_VMPTRLD_VMXON_POINTER);
 
+       /* Forbid a regular VMPTRLD while an enlightened VMCS is in use */
+       if (vmx->nested.hv_evmcs)
+               return 1;
+
        if (vmx->nested.current_vmptr != vmptr) {
                struct vmcs12 *new_vmcs12;
                struct page *page;
        return nested_vmx_succeed(vcpu);
 }
 
+/*
+ * This is the equivalent of the nested hypervisor executing the vmptrld
+ * instruction.
+ */
+static int nested_vmx_handle_enlightened_vmptrld(struct kvm_vcpu *vcpu)
+{
+       struct vcpu_vmx *vmx = to_vmx(vcpu);
+       struct hv_vp_assist_page assist_page;
+
+       if (likely(!vmx->nested.enlightened_vmcs_enabled))
+               return 1;
+
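+       /* The current eVMCS GPA is advertised via the Hyper-V VP assist page. */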
+       if (unlikely(!kvm_hv_get_assist_page(vcpu, &assist_page)))
+               return 1;
+
+       if (unlikely(!assist_page.enlighten_vmentry))
+               return 1;
+
+       if (unlikely(assist_page.current_nested_vmcs !=
+                    vmx->nested.hv_evmcs_vmptr)) {
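+               /*
+                * Switching to an enlightened VMCS: invalidate any regular
+                * VMCS pointer that may still be loaded.
+                */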
+               if (!vmx->nested.hv_evmcs)
+                       vmx->nested.current_vmptr = -1ull;
+
+               nested_release_evmcs(vcpu);
+
+               vmx->nested.hv_evmcs_page = kvm_vcpu_gpa_to_page(
+                       vcpu, assist_page.current_nested_vmcs);
+
+               if (unlikely(is_error_page(vmx->nested.hv_evmcs_page)))
+                       return 0;
+
+               vmx->nested.hv_evmcs = kmap(vmx->nested.hv_evmcs_page);
+
+               if (vmx->nested.hv_evmcs->revision_id != VMCS12_REVISION) {
+                       nested_release_evmcs(vcpu);
+                       return 0;
+               }
+
+               vmx->nested.dirty_vmcs12 = true;
+               /*
+                * As we keep L2 state for one guest only, the 'hv_clean_fields'
+                * mask can't be used when switching between different L2
+                * guests. Reset it here for simplicity.
+                */
+               vmx->nested.hv_evmcs->hv_clean_fields &=
+                       ~HV_VMX_ENLIGHTENED_CLEAN_FIELD_ALL;
+               vmx->nested.hv_evmcs_vmptr = assist_page.current_nested_vmcs;
+
+               /*
+                * Unlike a normal vmcs12, an enlightened vmcs12 is not fully
+                * reloaded from guest memory (read-only fields, fields not
+                * present in struct hv_enlightened_vmcs, ...). Make sure there
+                * are no leftovers.
+                */
+               memset(vmx->nested.cached_vmcs12, 0,
+                      sizeof(*vmx->nested.cached_vmcs12));
+       }
+       return 1;
+}
+
 /* Emulate the VMPTRST instruction */
 static int handle_vmptrst(struct kvm_vcpu *vcpu)
 {
        if (!nested_vmx_check_permission(vcpu))
                return 1;
 
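+       /* There is no current VMCS pointer to store while an eVMCS is in use. */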
+       if (unlikely(to_vmx(vcpu)->nested.hv_evmcs))
+               return 1;
+
        if (get_vmx_mem_address(vcpu, exit_qual, instr_info, true, &gva))
                return 1;
        /* *_system ok, nested_vmx_check_permission has verified cpl=0 */
        if (!nested_vmx_check_permission(vcpu))
                return 1;
 
-       if (vmx->nested.current_vmptr == -1ull)
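+       /* Pick up the enlightened VMCS from the VP assist page, if one is in use. */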
+       if (!nested_vmx_handle_enlightened_vmptrld(vcpu))
+               return 1;
+
+       if (!vmx->nested.hv_evmcs && vmx->nested.current_vmptr == -1ull)
                return nested_vmx_failInvalid(vcpu);
 
        vmcs12 = get_vmcs12(vcpu);