return 0;
 }
 
-static int sev_cpuid_hv(struct cpuid_leaf *leaf)
+static int __sev_cpuid_hv_msr(struct cpuid_leaf *leaf)
 {
        int ret;
 
        return ret;
 }
 
+static int __sev_cpuid_hv_ghcb(struct ghcb *ghcb, struct es_em_ctxt *ctxt, struct cpuid_leaf *leaf)
+{
+       u32 cr4 = native_read_cr4();
+       int ret;
+
+       ghcb_set_rax(ghcb, leaf->fn);
+       ghcb_set_rcx(ghcb, leaf->subfn);
+
+       if (cr4 & X86_CR4_OSXSAVE)
+               /* Safe to read xcr0 */
+               ghcb_set_xcr0(ghcb, xgetbv(XCR_XFEATURE_ENABLED_MASK));
+       else
+               /* xgetbv will cause #UD - use reset value for xcr0 */
+               ghcb_set_xcr0(ghcb, 1);
+
+       ret = sev_es_ghcb_hv_call(ghcb, ctxt, SVM_EXIT_CPUID, 0, 0);
+       if (ret != ES_OK)
+               return ret;
+
+       if (!(ghcb_rax_is_valid(ghcb) &&
+             ghcb_rbx_is_valid(ghcb) &&
+             ghcb_rcx_is_valid(ghcb) &&
+             ghcb_rdx_is_valid(ghcb)))
+               return ES_VMM_ERROR;
+
+       leaf->eax = ghcb->save.rax;
+       leaf->ebx = ghcb->save.rbx;
+       leaf->ecx = ghcb->save.rcx;
+       leaf->edx = ghcb->save.rdx;
+
+       return ES_OK;
+}
+
+static int sev_cpuid_hv(struct ghcb *ghcb, struct es_em_ctxt *ctxt, struct cpuid_leaf *leaf)
+{
+       return ghcb ? __sev_cpuid_hv_ghcb(ghcb, ctxt, leaf)
+                   : __sev_cpuid_hv_msr(leaf);
+}
+
 /*
  * This may be called early while still running on the initial identity
  * mapping. Use RIP-relative addressing to obtain the correct address
        return false;
 }
 
-static void snp_cpuid_hv(struct cpuid_leaf *leaf)
+static void snp_cpuid_hv(struct ghcb *ghcb, struct es_em_ctxt *ctxt, struct cpuid_leaf *leaf)
 {
-       if (sev_cpuid_hv(leaf))
+       if (sev_cpuid_hv(ghcb, ctxt, leaf))
                sev_es_terminate(SEV_TERM_SET_LINUX, GHCB_TERM_CPUID_HV);
 }
 
-static int snp_cpuid_postprocess(struct cpuid_leaf *leaf)
+static int snp_cpuid_postprocess(struct ghcb *ghcb, struct es_em_ctxt *ctxt,
+                                struct cpuid_leaf *leaf)
 {
        struct cpuid_leaf leaf_hv = *leaf;
 
        switch (leaf->fn) {
        case 0x1:
-               snp_cpuid_hv(&leaf_hv);
+               snp_cpuid_hv(ghcb, ctxt, &leaf_hv);
 
                /* initial APIC ID */
                leaf->ebx = (leaf_hv.ebx & GENMASK(31, 24)) | (leaf->ebx & GENMASK(23, 0));
                break;
        case 0xB:
                leaf_hv.subfn = 0;
-               snp_cpuid_hv(&leaf_hv);
+               snp_cpuid_hv(ghcb, ctxt, &leaf_hv);
 
                /* extended APIC ID */
                leaf->edx = leaf_hv.edx;
                }
                break;
        case 0x8000001E:
-               snp_cpuid_hv(&leaf_hv);
+               snp_cpuid_hv(ghcb, ctxt, &leaf_hv);
 
                /* extended APIC ID */
                leaf->eax = leaf_hv.eax;
  * Returns -EOPNOTSUPP if feature not enabled. Any other non-zero return value
  * should be treated as fatal by caller.
  */
-static int snp_cpuid(struct cpuid_leaf *leaf)
+static int snp_cpuid(struct ghcb *ghcb, struct es_em_ctxt *ctxt, struct cpuid_leaf *leaf)
 {
        const struct snp_cpuid_table *cpuid_table = snp_cpuid_get_table();
 
                        return 0;
        }
 
-       return snp_cpuid_postprocess(leaf);
+       return snp_cpuid_postprocess(ghcb, ctxt, leaf);
 }
 
 /*
        leaf.fn = fn;
        leaf.subfn = subfn;
 
-       ret = snp_cpuid(&leaf);
+       ret = snp_cpuid(NULL, NULL, &leaf);
        if (!ret)
                goto cpuid_done;
 
        if (ret != -EOPNOTSUPP)
                goto fail;
 
-       if (sev_cpuid_hv(&leaf))
+       if (__sev_cpuid_hv_msr(&leaf))
                goto fail;
 
 cpuid_done:
        return ret;
 }
 
-static int vc_handle_cpuid_snp(struct pt_regs *regs)
+static int vc_handle_cpuid_snp(struct ghcb *ghcb, struct es_em_ctxt *ctxt)
 {
+       struct pt_regs *regs = ctxt->regs;
        struct cpuid_leaf leaf;
        int ret;
 
        leaf.fn = regs->ax;
        leaf.subfn = regs->cx;
-       ret = snp_cpuid(&leaf);
+       ret = snp_cpuid(ghcb, ctxt, &leaf);
        if (!ret) {
                regs->ax = leaf.eax;
                regs->bx = leaf.ebx;
        enum es_result ret;
        int snp_cpuid_ret;
 
-       snp_cpuid_ret = vc_handle_cpuid_snp(regs);
+       snp_cpuid_ret = vc_handle_cpuid_snp(ghcb, ctxt);
        if (!snp_cpuid_ret)
                return ES_OK;
        if (snp_cpuid_ret != -EOPNOTSUPP)