#include "mmu.h"
 #include "trace.h"
 
-static u32 xstate_required_size(u64 xstate_bv)
+static u32 xstate_required_size(u64 xstate_bv, bool compacted)
 {
        int feature_bit = 0;
        u32 ret = XSAVE_HDR_SIZE + XSAVE_HDR_OFFSET;
        xstate_bv &= XSTATE_EXTEND_MASK;
        while (xstate_bv) {
                if (xstate_bv & 0x1) {
-                       u32 eax, ebx, ecx, edx;
+                       u32 eax, ebx, ecx, edx, offset;
                        cpuid_count(0xD, feature_bit, &eax, &ebx, &ecx, &edx);
-                       ret = max(ret, eax + ebx);
+                       offset = compacted ? ret : ebx;
+                       ret = max(ret, offset + eax);
                }
 
                xstate_bv >>= 1;
                        (best->eax | ((u64)best->edx << 32)) &
                        kvm_supported_xcr0();
                vcpu->arch.guest_xstate_size = best->ebx =
-                       xstate_required_size(vcpu->arch.xcr0);
+                       xstate_required_size(vcpu->arch.xcr0, false);
        }
 
+       best = kvm_find_cpuid_entry(vcpu, 0xD, 1);
+       if (best && (best->eax & (F(XSAVES) | F(XSAVEC))))
+               best->ebx = xstate_required_size(vcpu->arch.xcr0, true);
+
        /*
         * The existing code assumes virtual address is 48-bit in the canonical
         * address checks; exit if it is ever changed.
                                goto out;
 
                        do_cpuid_1_ent(&entry[i], function, idx);
-                       if (idx == 1)
+                       if (idx == 1) {
                                entry[i].eax &= kvm_supported_word10_x86_features;
-                       else if (entry[i].eax == 0 || !(supported & mask))
+                               entry[i].ebx = 0;
+                               if (entry[i].eax & (F(XSAVES)|F(XSAVEC)))
+                                       entry[i].ebx =
+                                               xstate_required_size(supported,
+                                                                    true);
+                       } else if (entry[i].eax == 0 || !(supported & mask))
                                continue;
                        entry[i].flags |=
                               KVM_CPUID_FLAG_SIGNIFCANT_INDEX;