src = &prev->host_state;
        dest = &vmx->loaded_vmcs->host_state;
 
-       vmx_set_vmcs_host_state(dest, src->cr3, src->fs_sel, src->gs_sel,
-                               src->fs_base, src->gs_base);
+       vmx_set_host_fs_gs(dest, src->fs_sel, src->gs_sel, src->fs_base, src->gs_base);
        dest->ldt_sel = src->ldt_sel;
 #ifdef CONFIG_X86_64
        dest->ds_sel = src->ds_sel;
 
                wrmsrl(MSR_IA32_RTIT_CTL, vmx->pt_desc.host.ctl);
 }
 
-void vmx_set_vmcs_host_state(struct vmcs_host_state *host, unsigned long cr3,
-                            u16 fs_sel, u16 gs_sel,
-                            unsigned long fs_base, unsigned long gs_base)
+void vmx_set_host_fs_gs(struct vmcs_host_state *host, u16 fs_sel, u16 gs_sel,
+                       unsigned long fs_base, unsigned long gs_base)
 {
-       if (unlikely(cr3 != host->cr3)) {
-               vmcs_writel(HOST_CR3, cr3);
-               host->cr3 = cr3;
-       }
        if (unlikely(fs_sel != host->fs_sel)) {
                if (!(fs_sel & 7))
                        vmcs_write16(HOST_FS_SELECTOR, fs_sel);
 #ifdef CONFIG_X86_64
        int cpu = raw_smp_processor_id();
 #endif
+       unsigned long cr3;
        unsigned long fs_base, gs_base;
        u16 fs_sel, gs_sel;
        int i;
        gs_base = segment_base(gs_sel);
 #endif
 
-       vmx_set_vmcs_host_state(host_state, __get_current_cr3_fast(),
-                               fs_sel, gs_sel, fs_base, gs_base);
+       vmx_set_host_fs_gs(host_state, fs_sel, gs_sel, fs_base, gs_base);
+
+       /* Host CR3 including its PCID is stable when guest state is loaded. */
+       cr3 = __get_current_cr3_fast();
+       if (unlikely(cr3 != host_state->cr3)) {
+               vmcs_writel(HOST_CR3, cr3);
+               host_state->cr3 = cr3;
+       }
 
        vmx->guest_state_loaded = true;
 }
 
 void free_vpid(int vpid);
 void vmx_set_constant_host_state(struct vcpu_vmx *vmx);
 void vmx_prepare_switch_to_guest(struct kvm_vcpu *vcpu);
-void vmx_set_vmcs_host_state(struct vmcs_host_state *host, unsigned long cr3,
-                            u16 fs_sel, u16 gs_sel,
-                            unsigned long fs_base, unsigned long gs_base);
+void vmx_set_host_fs_gs(struct vmcs_host_state *host, u16 fs_sel, u16 gs_sel,
+                       unsigned long fs_base, unsigned long gs_base);
 int vmx_get_cpl(struct kvm_vcpu *vcpu);
 bool vmx_emulation_required(struct kvm_vcpu *vcpu);
 unsigned long vmx_get_rflags(struct kvm_vcpu *vcpu);