}
 EXPORT_SYMBOL_GPL(kvm_set_cr4);
 
+/*
+ * Invalidate the given PCID for @vcpu.
+ *
+ * If @pcid is the vCPU's active PCID, the current roots are synced and a
+ * flush of the current TLB context is requested.  Independently of that,
+ * every cached previous root whose PGD maps to @pcid is freed.
+ *
+ * Shared by the CR3 load path (when the flush is not suppressed) and by the
+ * single-context INVPCID handling, which previously open-coded this logic.
+ */
+static void kvm_invalidate_pcid(struct kvm_vcpu *vcpu, unsigned long pcid)
+{
+       struct kvm_mmu *mmu = vcpu->arch.mmu;
+       unsigned long roots_to_free = 0;
+       int i;
+
+       /*
+        * If neither the current CR3 nor any of the prev_roots use the given
+        * PCID, then nothing needs to be done here because a resync will
+        * happen anyway before switching to any other CR3.
+        */
+       if (kvm_get_active_pcid(vcpu) == pcid) {
+               kvm_mmu_sync_roots(vcpu);
+               kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
+       }
+
+       /*
+        * Build a mask of the cached previous roots that use this PCID so
+        * they can all be dropped with a single call below.
+        */
+       for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
+               if (kvm_get_pcid(vcpu, mmu->prev_roots[i].pgd) == pcid)
+                       roots_to_free |= KVM_MMU_ROOT_PREVIOUS(i);
+
+       /*
+        * Called unconditionally; roots_to_free is 0 when no previous root
+        * matched.  NOTE(review): presumably a zero mask is a no-op in
+        * kvm_mmu_free_roots() -- confirm against its definition.
+        */
+       kvm_mmu_free_roots(vcpu, mmu, roots_to_free);
+}
+
 int kvm_set_cr3(struct kvm_vcpu *vcpu, unsigned long cr3)
 {
        bool skip_tlb_flush = false;
+       unsigned long pcid = 0;
 #ifdef CONFIG_X86_64
        bool pcid_enabled = kvm_read_cr4_bits(vcpu, X86_CR4_PCIDE);
 
        if (pcid_enabled) {
                skip_tlb_flush = cr3 & X86_CR3_PCID_NOFLUSH;
                cr3 &= ~X86_CR3_PCID_NOFLUSH;
+               pcid = cr3 & X86_CR3_PCID_MASK;
        }
 #endif
 
        /* PDPTRs are always reloaded for PAE paging. */
-       if (cr3 == kvm_read_cr3(vcpu) && !is_pae_paging(vcpu)) {
-               if (!skip_tlb_flush) {
-                       kvm_mmu_sync_roots(vcpu);
-                       kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
-               }
-               return 0;
-       }
+       if (cr3 == kvm_read_cr3(vcpu) && !is_pae_paging(vcpu))
+               goto handle_tlb_flush;
 
        /*
         * Do not condition the GPA check on long mode, this helper is used to
        if (is_pae_paging(vcpu) && !load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3))
                return 1;
 
-       kvm_mmu_new_pgd(vcpu, cr3, skip_tlb_flush, skip_tlb_flush);
+       if (cr3 != kvm_read_cr3(vcpu))
+               kvm_mmu_new_pgd(vcpu, cr3, skip_tlb_flush, skip_tlb_flush);
+
        vcpu->arch.cr3 = cr3;
        kvm_register_mark_available(vcpu, VCPU_EXREG_CR3);
 
+handle_tlb_flush:
+       /*
+        * A load of CR3 that flushes the TLB flushes only the current PCID,
+        * even if PCID is disabled, in which case PCID=0 is flushed.  It's a
+        * moot point in the end because _disabling_ PCID will flush all PCIDs,
+        * and it's impossible to use a non-zero PCID when PCID is disabled,
+        * i.e. only PCID=0 can be relevant.
+        */
+       if (!skip_tlb_flush)
+               kvm_invalidate_pcid(vcpu, pcid);
+
        return 0;
 }
 EXPORT_SYMBOL_GPL(kvm_set_cr3);
 {
        bool pcid_enabled;
        struct x86_exception e;
-       unsigned i;
-       unsigned long roots_to_free = 0;
        struct {
                u64 pcid;
                u64 gla;
                        return 1;
                }
 
-               if (kvm_get_active_pcid(vcpu) == operand.pcid) {
-                       kvm_mmu_sync_roots(vcpu);
-                       kvm_make_request(KVM_REQ_TLB_FLUSH_CURRENT, vcpu);
-               }
-
-               for (i = 0; i < KVM_MMU_NUM_PREV_ROOTS; i++)
-                       if (kvm_get_pcid(vcpu, vcpu->arch.mmu->prev_roots[i].pgd)
-                           == operand.pcid)
-                               roots_to_free |= KVM_MMU_ROOT_PREVIOUS(i);
-
-               kvm_mmu_free_roots(vcpu, vcpu->arch.mmu, roots_to_free);
-               /*
-                * If neither the current cr3 nor any of the prev_roots use the
-                * given PCID, then nothing needs to be done here because a
-                * resync will happen anyway before switching to any other CR3.
-                */
-
+               kvm_invalidate_pcid(vcpu, operand.pcid);
                return kvm_skip_emulated_instruction(vcpu);
 
        case INVPCID_TYPE_ALL_NON_GLOBAL: