        u64 bad_mt_xwr;
 };
 
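+/*
+ * kvm_mmu_root_info caches a root that was previously in use: the guest CR3
+ * value it was built for and the host physical address of the root shadow
+ * page. KVM_MMU_ROOT_INFO_INVALID denotes an entry with no cached root.
+ */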
+struct kvm_mmu_root_info {
+       gpa_t cr3;
+       hpa_t hpa;
+};
+
+#define KVM_MMU_ROOT_INFO_INVALID \
+       ((struct kvm_mmu_root_info) { .cr3 = INVALID_PAGE, .hpa = INVALID_PAGE })
+
 /*
  * x86 supports 4 paging modes (5-level 64-bit, 4-level 64-bit, 3-level 32-bit,
  * and 2-level 32-bit).  The kvm_mmu structure abstracts the details of the
        u8 shadow_root_level;
        u8 ept_ad;
        bool direct_map;
+       struct kvm_mmu_root_info prev_root;
 
        /*
         * Bitmap; bit set = permission fault
 int kvm_mmu_load(struct kvm_vcpu *vcpu);
 void kvm_mmu_unload(struct kvm_vcpu *vcpu);
 void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu);
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu);
+void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, bool free_prev_root);
 gpa_t translate_nested_gpa(struct kvm_vcpu *vcpu, gpa_t gpa, u32 access,
                           struct x86_exception *exception);
 gpa_t kvm_mmu_gva_to_gpa_read(struct kvm_vcpu *vcpu, gva_t gva,
 int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t gva, u64 error_code,
                       void *insn, int insn_len);
 void kvm_mmu_invlpg(struct kvm_vcpu *vcpu, gva_t gva);
-void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu);
+void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3);
 
 void kvm_enable_tdp(void);
 void kvm_disable_tdp(void);
 
        *root_hpa = INVALID_PAGE;
 }
 
-void kvm_mmu_free_roots(struct kvm_vcpu *vcpu)
+void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, bool free_prev_root)
 {
        int i;
        LIST_HEAD(invalid_list);
        struct kvm_mmu *mmu = &vcpu->arch.mmu;
 
-       if (!VALID_PAGE(mmu->root_hpa))
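+       /*
+        * Return early without taking mmu_lock if there is nothing to free:
+        * the current root is invalid, and the cached previous root is either
+        * invalid or not being freed.
+        */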
+       if (!VALID_PAGE(mmu->root_hpa) &&
+           (!VALID_PAGE(mmu->prev_root.hpa) || !free_prev_root))
                return;
 
        spin_lock(&vcpu->kvm->mmu_lock);
 
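+       /* Free the cached previous root only when the caller requests it. */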
+       if (free_prev_root)
+               mmu_free_root_page(vcpu->kvm, &mmu->prev_root.hpa,
+                                  &invalid_list);
+
        if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
            (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
                mmu_free_root_page(vcpu->kvm, &mmu->root_hpa, &invalid_list);
        context->root_level = 0;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->root_hpa = INVALID_PAGE;
+       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = true;
        context->nx = false;
 }
 
-void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu)
+static bool fast_cr3_switch(struct kvm_vcpu *vcpu, gpa_t new_cr3)
+{
+       struct kvm_mmu *mmu = &vcpu->arch.mmu;
+
+       /*
+        * For now, limit the fast switch to 64-bit hosts+VMs in order to avoid
+        * having to deal with PDPTEs. We may add support for 32-bit hosts/VMs
+        * later if necessary.
+        */
+       if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
+           mmu->root_level >= PT64_ROOT_4LEVEL) {
+               gpa_t prev_cr3 = mmu->prev_root.cr3;
+
+               if (mmu_check_root(vcpu, new_cr3 >> PAGE_SHIFT))
+                       return false;
+
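+               /*
+                * Swap the current root with the cached previous root and
+                * remember the CR3 that the outgoing root was built for, so
+                * that a later switch back to it can also use the fast path.
+                */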
+               swap(mmu->root_hpa, mmu->prev_root.hpa);
+               mmu->prev_root.cr3 = kvm_read_cr3(vcpu);
+
+               if (new_cr3 == prev_cr3 && VALID_PAGE(mmu->root_hpa)) {
+                       /*
+                        * It is possible that the cached previous root page is
+                        * obsolete because of a change in the MMU
+                        * generation number. However, that is accompanied by
+                        * KVM_REQ_MMU_RELOAD, which will free the root that we
+                        * have set here and allocate a new one.
+                        */
+
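+                       /*
+                        * The cached root being reused still has to be
+                        * synced with the guest page tables, which would
+                        * otherwise be done as part of a full root reload,
+                        * so request the sync here. Also reset the root
+                        * page's write-flooding count now that it is back
+                        * in use.
+                        */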
+                       kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
+                       __clear_sp_write_flooding_count(
+                               page_header(mmu->root_hpa));
+
+                       mmu->set_cr3(vcpu, mmu->root_hpa);
+
+                       return true;
+               }
+       }
+
+       return false;
+}
+
+void kvm_mmu_new_cr3(struct kvm_vcpu *vcpu, gpa_t new_cr3)
 {
-       kvm_mmu_free_roots(vcpu);
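+       /*
+        * If a fast switch to a cached root is not possible, free the current
+        * root (but keep the cached previous one) so that a new root is
+        * allocated on the next kvm_mmu_load().
+        */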
+       if (!fast_cr3_switch(vcpu, new_cr3))
+               kvm_mmu_free_roots(vcpu, false);
 }
 
 static unsigned long get_cr3(struct kvm_vcpu *vcpu)
        context->update_pte = paging64_update_pte;
        context->shadow_root_level = level;
        context->root_hpa = INVALID_PAGE;
+       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = false;
 }
 
        context->update_pte = paging32_update_pte;
        context->shadow_root_level = PT32E_ROOT_LEVEL;
        context->root_hpa = INVALID_PAGE;
+       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = false;
 }
 
        context->update_pte = nonpaging_update_pte;
        context->shadow_root_level = kvm_x86_ops->get_tdp_level(vcpu);
        context->root_hpa = INVALID_PAGE;
+       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = true;
        context->set_cr3 = kvm_x86_ops->set_tdp_cr3;
        context->get_cr3 = get_cr3;
        context->update_pte = ept_update_pte;
        context->root_level = PT64_ROOT_4LEVEL;
        context->root_hpa = INVALID_PAGE;
+       context->prev_root = KVM_MMU_ROOT_INFO_INVALID;
        context->direct_map = false;
        context->base_role.ad_disabled = !accessed_dirty;
        context->base_role.guest_mode = 1;
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
-       kvm_mmu_free_roots(vcpu);
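+       /* A full unload discards the cached previous root as well. */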
+       kvm_mmu_free_roots(vcpu, true);
        WARN_ON(VALID_PAGE(vcpu->arch.mmu.root_hpa));
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unload);
 {
        vcpu->arch.walk_mmu = &vcpu->arch.mmu;
        vcpu->arch.mmu.root_hpa = INVALID_PAGE;
+       vcpu->arch.mmu.prev_root = KVM_MMU_ROOT_INFO_INVALID;
        vcpu->arch.mmu.translate_gpa = translate_gpa;
        vcpu->arch.nested_mmu.translate_gpa = translate_nested_gpa;