 static void store_regs(struct kvm_vcpu *vcpu);
 static int sync_regs(struct kvm_vcpu *vcpu);
 
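+/* The SREGS2 helpers are used by kvm_arch_vcpu_ioctl() before they are defined. */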
+static int __set_sregs2(struct kvm_vcpu *vcpu, struct kvm_sregs2 *sregs2);
+static void __get_sregs2(struct kvm_vcpu *vcpu, struct kvm_sregs2 *sregs2);
+
 struct kvm_x86_ops kvm_x86_ops __read_mostly;
 EXPORT_SYMBOL_GPL(kvm_x86_ops);
 
 
        memcpy(mmu->pdptrs, pdpte, sizeof(mmu->pdptrs));
        kvm_register_mark_dirty(vcpu, VCPU_EXREG_PDPTR);
-
 out:
 
        return ret;
        case KVM_CAP_SGX_ATTRIBUTE:
 #endif
        case KVM_CAP_VM_COPY_ENC_CONTEXT_FROM:
+       case KVM_CAP_SREGS2:
                r = 1;
                break;
        case KVM_CAP_SET_GUEST_DEBUG2:
        void __user *argp = (void __user *)arg;
        int r;
        union {
+               struct kvm_sregs2 *sregs2;
                struct kvm_lapic_state *lapic;
                struct kvm_xsave *xsave;
                struct kvm_xcrs *xcrs;
                break;
        }
 #endif
+       case KVM_GET_SREGS2: {
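+               /* Zeroed allocation: unset flags and padding reach userspace as zeroes. */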
+               u.sregs2 = kzalloc(sizeof(struct kvm_sregs2), GFP_KERNEL);
+               r = -ENOMEM;
+               if (!u.sregs2)
+                       goto out;
+               __get_sregs2(vcpu, u.sregs2);
+               r = -EFAULT;
+               if (copy_to_user(argp, u.sregs2, sizeof(struct kvm_sregs2)))
+                       goto out;
+               r = 0;
+               break;
+       }
+       case KVM_SET_SREGS2: {
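+               /* memdup_user() allocates and copies the userspace struct in one step. */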
+               u.sregs2 = memdup_user(argp, sizeof(struct kvm_sregs2));
+               if (IS_ERR(u.sregs2)) {
+                       r = PTR_ERR(u.sregs2);
+                       u.sregs2 = NULL;
+                       goto out;
+               }
+               r = __set_sregs2(vcpu, u.sregs2);
+               break;
+       }
        default:
                r = -EINVAL;
        }
 }
 EXPORT_SYMBOL_GPL(kvm_get_cs_db_l_bits);
 
-static void __get_sregs(struct kvm_vcpu *vcpu, struct kvm_sregs *sregs)
+static void __get_sregs_common(struct kvm_vcpu *vcpu, struct kvm_sregs *sregs)
 {
        struct desc_ptr dt;
 
        sregs->cr8 = kvm_get_cr8(vcpu);
        sregs->efer = vcpu->arch.efer;
        sregs->apic_base = kvm_get_apic_base(vcpu);
+}
 
-       memset(sregs->interrupt_bitmap, 0, sizeof(sregs->interrupt_bitmap));
+static void __get_sregs(struct kvm_vcpu *vcpu, struct kvm_sregs *sregs)
+{
+       __get_sregs_common(vcpu, sregs);
+
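+       /* Protected guest state (e.g. SEV-ES) hides register state from userspace. */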
+       if (vcpu->arch.guest_state_protected)
+               return;
 
        if (vcpu->arch.interrupt.injected && !vcpu->arch.interrupt.soft)
                set_bit(vcpu->arch.interrupt.nr,
                        (unsigned long *)sregs->interrupt_bitmap);
 }
 
+static void __get_sregs2(struct kvm_vcpu *vcpu, struct kvm_sregs2 *sregs2)
+{
+       int i;
+
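+       /* struct kvm_sregs2 shares its leading fields with struct kvm_sregs. */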
+       __get_sregs_common(vcpu, (struct kvm_sregs *)sregs2);
+
+       if (vcpu->arch.guest_state_protected)
+               return;
+
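+       /* PDPTRs are only meaningful under PAE paging; flag them valid for userspace. */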
+       if (is_pae_paging(vcpu)) {
+               for (i = 0; i < 4; i++)
+                       sregs2->pdptrs[i] = kvm_pdptr_read(vcpu, i);
+               sregs2->flags |= KVM_SREGS2_FLAGS_PDPTRS_VALID;
+       }
+}
+
 int kvm_arch_vcpu_ioctl_get_sregs(struct kvm_vcpu *vcpu,
                                  struct kvm_sregs *sregs)
 {
        return kvm_is_valid_cr4(vcpu, sregs->cr4);
 }
 
-static int __set_sregs(struct kvm_vcpu *vcpu, struct kvm_sregs *sregs)
+static int __set_sregs_common(struct kvm_vcpu *vcpu, struct kvm_sregs *sregs,
+               int *mmu_reset_needed, bool update_pdptrs)
 {
        struct msr_data apic_base_msr;
-       int mmu_reset_needed = 0;
-       int pending_vec, max_bits, idx;
+       int idx;
        struct desc_ptr dt;
-       int ret = -EINVAL;
 
        if (!kvm_is_valid_sregs(vcpu, sregs))
-               goto out;
+               return -EINVAL;
 
        apic_base_msr.data = sregs->apic_base;
        apic_base_msr.host_initiated = true;
        if (kvm_set_apic_base(vcpu, &apic_base_msr))
-               goto out;
+               return -EINVAL;
 
        if (vcpu->arch.guest_state_protected)
-               goto skip_protected_regs;
+               return 0;
 
        dt.size = sregs->idt.limit;
        dt.address = sregs->idt.base;
        static_call(kvm_x86_set_idt)(vcpu, &dt);
 
        vcpu->arch.cr2 = sregs->cr2;
-       mmu_reset_needed |= kvm_read_cr3(vcpu) != sregs->cr3;
+       *mmu_reset_needed |= kvm_read_cr3(vcpu) != sregs->cr3;
        vcpu->arch.cr3 = sregs->cr3;
        kvm_register_mark_available(vcpu, VCPU_EXREG_CR3);
 
        kvm_set_cr8(vcpu, sregs->cr8);
 
-       mmu_reset_needed |= vcpu->arch.efer != sregs->efer;
+       *mmu_reset_needed |= vcpu->arch.efer != sregs->efer;
        static_call(kvm_x86_set_efer)(vcpu, sregs->efer);
 
-       mmu_reset_needed |= kvm_read_cr0(vcpu) != sregs->cr0;
+       *mmu_reset_needed |= kvm_read_cr0(vcpu) != sregs->cr0;
        static_call(kvm_x86_set_cr0)(vcpu, sregs->cr0);
        vcpu->arch.cr0 = sregs->cr0;
 
-       mmu_reset_needed |= kvm_read_cr4(vcpu) != sregs->cr4;
+       *mmu_reset_needed |= kvm_read_cr4(vcpu) != sregs->cr4;
        static_call(kvm_x86_set_cr4)(vcpu, sregs->cr4);
 
-       idx = srcu_read_lock(&vcpu->kvm->srcu);
-       if (is_pae_paging(vcpu)) {
-               load_pdptrs(vcpu, vcpu->arch.walk_mmu, kvm_read_cr3(vcpu));
-               mmu_reset_needed = 1;
+       if (update_pdptrs) {
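+               /* load_pdptrs() walks guest memory and thus needs the SRCU read lock. */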
+               idx = srcu_read_lock(&vcpu->kvm->srcu);
+               if (is_pae_paging(vcpu)) {
+                       load_pdptrs(vcpu, vcpu->arch.walk_mmu, kvm_read_cr3(vcpu));
+                       *mmu_reset_needed = 1;
+               }
+               srcu_read_unlock(&vcpu->kvm->srcu, idx);
        }
-       srcu_read_unlock(&vcpu->kvm->srcu, idx);
-
-       if (mmu_reset_needed)
-               kvm_mmu_reset_context(vcpu);
 
        kvm_set_segment(vcpu, &sregs->cs, VCPU_SREG_CS);
        kvm_set_segment(vcpu, &sregs->ds, VCPU_SREG_DS);
            !is_protmode(vcpu))
                vcpu->arch.mp_state = KVM_MP_STATE_RUNNABLE;
 
-skip_protected_regs:
+       return 0;
+}
+
+static int __set_sregs(struct kvm_vcpu *vcpu, struct kvm_sregs *sregs)
+{
+       int pending_vec, max_bits;
+       int mmu_reset_needed = 0;
+       int ret = __set_sregs_common(vcpu, sregs, &mmu_reset_needed, true);
+
+       if (ret)
+               return ret;
+
+       if (mmu_reset_needed)
+               kvm_mmu_reset_context(vcpu);
+
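+       /* Re-queue the pending external interrupt, if one was saved in the bitmap. */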
        max_bits = KVM_NR_INTERRUPTS;
        pending_vec = find_first_bit(
                (const unsigned long *)sregs->interrupt_bitmap, max_bits);
+
        if (pending_vec < max_bits) {
                kvm_queue_interrupt(vcpu, pending_vec, false);
                pr_debug("Set back pending irq %d\n", pending_vec);
+               kvm_make_request(KVM_REQ_EVENT, vcpu);
        }
+       return 0;
+}
 
-       kvm_make_request(KVM_REQ_EVENT, vcpu);
+static int __set_sregs2(struct kvm_vcpu *vcpu, struct kvm_sregs2 *sregs2)
+{
+       int mmu_reset_needed = 0;
+       bool valid_pdptrs = sregs2->flags & KVM_SREGS2_FLAGS_PDPTRS_VALID;
+       bool pae = (sregs2->cr0 & X86_CR0_PG) && (sregs2->cr4 & X86_CR4_PAE) &&
+               !(sregs2->efer & EFER_LMA);
+       int i, ret;
 
-       ret = 0;
-out:
-       return ret;
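+       /* Reject flag bits this kernel does not know about. */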
+       if (sregs2->flags & ~KVM_SREGS2_FLAGS_PDPTRS_VALID)
+               return -EINVAL;
+
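+       /* PDPTRs are only accepted for an unprotected guest using PAE paging. */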
+       if (valid_pdptrs && (!pae || vcpu->arch.guest_state_protected))
+               return -EINVAL;
+
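+       /* Don't reload PDPTRs from guest memory when userspace supplied them. */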
+       ret = __set_sregs_common(vcpu, (struct kvm_sregs *)sregs2,
+                                &mmu_reset_needed, !valid_pdptrs);
+       if (ret)
+               return ret;
+
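+       /* Install the PDPTR values supplied by userspace. */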
+       if (valid_pdptrs) {
+               for (i = 0; i < 4; i++)
+                       kvm_pdptr_write(vcpu, i, sregs2->pdptrs[i]);
+
+               kvm_register_mark_dirty(vcpu, VCPU_EXREG_PDPTR);
+               mmu_reset_needed = 1;
+       }
+       if (mmu_reset_needed)
+               kvm_mmu_reset_context(vcpu);
+       return 0;
 }
 
 int kvm_arch_vcpu_ioctl_set_sregs(struct kvm_vcpu *vcpu,