* current mmu mode.
  */
 struct kvm_mmu {
-       void (*set_cr3)(struct kvm_vcpu *vcpu, unsigned long root);
        unsigned long (*get_guest_pgd)(struct kvm_vcpu *vcpu);
        u64 (*get_pdptr)(struct kvm_vcpu *vcpu, int index);
        int (*page_fault)(struct kvm_vcpu *vcpu, gpa_t cr2_or_gpa, u32 err,
        int (*get_tdp_level)(struct kvm_vcpu *vcpu);
        u64 (*get_mt_mask)(struct kvm_vcpu *vcpu, gfn_t gfn, bool is_mmio);
 
-       void (*set_tdp_cr3)(struct kvm_vcpu *vcpu, unsigned long cr3);
-
        bool (*has_wbinvd_exit)(void);
 
        u64 (*read_l1_tsc_offset)(struct kvm_vcpu *vcpu);
 
        return kvm_get_pcid(vcpu, kvm_read_cr3(vcpu));
 }
 
-static inline void kvm_mmu_load_cr3(struct kvm_vcpu *vcpu)
+static inline void kvm_mmu_load_pgd(struct kvm_vcpu *vcpu)
 {
        if (VALID_PAGE(vcpu->arch.mmu->root_hpa))
-               vcpu->arch.mmu->set_cr3(vcpu, vcpu->arch.mmu->root_hpa |
-                                             kvm_get_active_pcid(vcpu));
+               kvm_x86_ops->set_cr3(vcpu, vcpu->arch.mmu->root_hpa |
+                                    kvm_get_active_pcid(vcpu));
 }
 
 int kvm_tdp_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 
        context->update_pte = nonpaging_update_pte;
        context->shadow_root_level = kvm_x86_ops->get_tdp_level(vcpu);
        context->direct_map = true;
-       context->set_cr3 = kvm_x86_ops->set_tdp_cr3;
        context->get_guest_pgd = get_cr3;
        context->get_pdptr = kvm_pdptr_read;
        context->inject_page_fault = kvm_inject_page_fault;
        struct kvm_mmu *context = vcpu->arch.mmu;
 
        kvm_init_shadow_mmu(vcpu);
-       context->set_cr3           = kvm_x86_ops->set_cr3;
        context->get_guest_pgd     = get_cr3;
        context->get_pdptr         = kvm_pdptr_read;
        context->inject_page_fault = kvm_inject_page_fault;
 
        return pdpte;
 }
 
-static void nested_svm_set_tdp_cr3(struct kvm_vcpu *vcpu,
-                                  unsigned long root)
-{
-       struct vcpu_svm *svm = to_svm(vcpu);
-
-       svm->vmcb->control.nested_cr3 = __sme_set(root);
-       mark_dirty(svm->vmcb, VMCB_NPT);
-}
-
 static void nested_svm_inject_npf_exit(struct kvm_vcpu *vcpu,
                                       struct x86_exception *fault)
 {
 
        vcpu->arch.mmu = &vcpu->arch.guest_mmu;
        kvm_init_shadow_mmu(vcpu);
-       vcpu->arch.mmu->set_cr3           = nested_svm_set_tdp_cr3;
        vcpu->arch.mmu->get_guest_pgd     = nested_svm_get_tdp_cr3;
        vcpu->arch.mmu->get_pdptr         = nested_svm_get_tdp_pdptr;
        vcpu->arch.mmu->inject_page_fault = nested_svm_inject_npf_exit;
 static void svm_set_cr3(struct kvm_vcpu *vcpu, unsigned long root)
 {
        struct vcpu_svm *svm = to_svm(vcpu);
+       bool update_guest_cr3 = true;
+       unsigned long cr3;
 
-       svm->vmcb->save.cr3 = __sme_set(root);
-       mark_dirty(svm->vmcb, VMCB_CR);
-}
-
-static void set_tdp_cr3(struct kvm_vcpu *vcpu, unsigned long root)
-{
-       struct vcpu_svm *svm = to_svm(vcpu);
+       cr3 = __sme_set(root);
+       if (npt_enabled) {
+               svm->vmcb->control.nested_cr3 = cr3;
+               mark_dirty(svm->vmcb, VMCB_NPT);
 
-       svm->vmcb->control.nested_cr3 = __sme_set(root);
-       mark_dirty(svm->vmcb, VMCB_NPT);
+               /* Loading L2's CR3 is handled by enter_svm_guest_mode.  */
+               if (is_guest_mode(vcpu))
+                       update_guest_cr3 = false;
+               else if (test_bit(VCPU_EXREG_CR3, (ulong *)&vcpu->arch.regs_avail))
+                       cr3 = vcpu->arch.cr3;
+               else /* CR3 is already up-to-date.  */
+                       update_guest_cr3 = false;
+       }
 
-       /* Also sync guest cr3 here in case we live migrate */
-       svm->vmcb->save.cr3 = kvm_read_cr3(vcpu);
-       mark_dirty(svm->vmcb, VMCB_CR);
+       if (update_guest_cr3) {
+               svm->vmcb->save.cr3 = cr3;
+               mark_dirty(svm->vmcb, VMCB_CR);
+       }
 }
 
 static int is_disabled(void)
        .read_l1_tsc_offset = svm_read_l1_tsc_offset,
        .write_l1_tsc_offset = svm_write_l1_tsc_offset,
 
-       .set_tdp_cr3 = set_tdp_cr3,
-
        .check_intercept = svm_check_intercept,
        .handle_exit_irqoff = svm_handle_exit_irqoff,