#include "tdx.h"
 #include "tdx_arch.h"
 
+static void vt_disable_virtualization_cpu(void)
+{
+       /* Note, TDX *and* VMX need to be disabled if TDX is enabled. */
+       if (enable_tdx)
+               tdx_disable_virtualization_cpu();
+       vmx_disable_virtualization_cpu();
+}
+
 static __init int vt_hardware_setup(void)
 {
        int ret;
        vmx_vcpu_reset(vcpu, init_event);
 }
 
+static void vt_vcpu_load(struct kvm_vcpu *vcpu, int cpu)
+{
+       if (is_td_vcpu(vcpu)) {
+               tdx_vcpu_load(vcpu, cpu);
+               return;
+       }
+
+       vmx_vcpu_load(vcpu, cpu);
+}
+
 static void vt_flush_tlb_all(struct kvm_vcpu *vcpu)
 {
        if (is_td_vcpu(vcpu)) {
        .hardware_unsetup = vmx_hardware_unsetup,
 
        .enable_virtualization_cpu = vmx_enable_virtualization_cpu,
-       .disable_virtualization_cpu = vmx_disable_virtualization_cpu,
+       .disable_virtualization_cpu = vt_disable_virtualization_cpu,
        .emergency_disable_virtualization_cpu = vmx_emergency_disable_virtualization_cpu,
 
        .has_emulated_msr = vmx_has_emulated_msr,
        .vcpu_reset = vt_vcpu_reset,
 
        .prepare_switch_to_guest = vmx_prepare_switch_to_guest,
-       .vcpu_load = vmx_vcpu_load,
+       .vcpu_load = vt_vcpu_load,
        .vcpu_put = vmx_vcpu_put,
 
        .update_exception_bitmap = vmx_update_exception_bitmap,
 
 }
 
 
+/*
+ * A per-CPU list of TD vCPUs associated with a given CPU.
+ * Protected by interrupt mask. Only manipulated by the CPU owning this per-CPU
+ * list.
+ * - When a vCPU is loaded onto a CPU, it is removed from the per-CPU list of
+ *   the old CPU during the IPI callback running on the old CPU, and then added
+ *   to the per-CPU list of the new CPU.
+ * - When a TD is tearing down, all vCPUs are disassociated from their current
+ *   running CPUs and removed from the per-CPU list during the IPI callback
+ *   running on those CPUs.
+ * - When a CPU is brought down, traverse the per-CPU list to disassociate all
+ *   associated TD vCPUs and remove them from the per-CPU list.
+ */
+static DEFINE_PER_CPU(struct list_head, associated_tdvcpus);
+
 static inline void tdx_hkid_free(struct kvm_tdx *kvm_tdx)
 {
        tdx_guest_keyid_free(kvm_tdx->hkid);
        return kvm_tdx->hkid > 0;
 }
 
+static inline void tdx_disassociate_vp(struct kvm_vcpu *vcpu)
+{
+       lockdep_assert_irqs_disabled();
+
+       list_del(&to_tdx(vcpu)->cpu_list);
+
+       /*
+        * Ensure tdx->cpu_list is updated before setting vcpu->cpu to -1,
+        * otherwise, a different CPU can see vcpu->cpu = -1 and add the vCPU
+        * to its list before it's deleted from this CPU's list.
+        */
+       smp_wmb();
+
+       vcpu->cpu = -1;
+}
+
 static void tdx_clear_page(struct page *page)
 {
        const void *zero_page = (const void *) page_to_virt(ZERO_PAGE(0));
        __free_page(ctrl_page);
 }
 
+struct tdx_flush_vp_arg {
+       struct kvm_vcpu *vcpu;
+       u64 err;
+};
+
+static void tdx_flush_vp(void *_arg)
+{
+       struct tdx_flush_vp_arg *arg = _arg;
+       struct kvm_vcpu *vcpu = arg->vcpu;
+       u64 err;
+
+       arg->err = 0;
+       lockdep_assert_irqs_disabled();
+
+       /* Task migration can race with CPU offlining. */
+       if (unlikely(vcpu->cpu != raw_smp_processor_id()))
+               return;
+
+       /*
+        * No need to do TDH_VP_FLUSH if the vCPU hasn't been initialized.  The
+        * list tracking still needs to be updated so that it's correct if/when
+        * the vCPU does get initialized.
+        */
+       if (to_tdx(vcpu)->state != VCPU_TD_STATE_UNINITIALIZED) {
+               /*
+                * No need to retry.  TDX Resources needed for TDH.VP.FLUSH are:
+                * TDVPR as exclusive, TDR as shared, and TDCS as shared.  This
+                * vp flush function is called when destructing vCPU/TD or vCPU
+                * migration.  No other thread uses TDVPR in those cases.
+                */
+               err = tdh_vp_flush(&to_tdx(vcpu)->vp);
+               if (unlikely(err && err != TDX_VCPU_NOT_ASSOCIATED)) {
+                       /*
+                        * This function is called in IPI context. Do not use
+                        * printk to avoid console semaphore.
+                        * The caller prints out the error message, instead.
+                        */
+                       if (err)
+                               arg->err = err;
+               }
+       }
+
+       tdx_disassociate_vp(vcpu);
+}
+
+static void tdx_flush_vp_on_cpu(struct kvm_vcpu *vcpu)
+{
+       struct tdx_flush_vp_arg arg = {
+               .vcpu = vcpu,
+       };
+       int cpu = vcpu->cpu;
+
+       if (unlikely(cpu == -1))
+               return;
+
+       smp_call_function_single(cpu, tdx_flush_vp, &arg, 1);
+       if (KVM_BUG_ON(arg.err, vcpu->kvm))
+               pr_tdx_error(TDH_VP_FLUSH, arg.err);
+}
+
+void tdx_disable_virtualization_cpu(void)
+{
+       int cpu = raw_smp_processor_id();
+       struct list_head *tdvcpus = &per_cpu(associated_tdvcpus, cpu);
+       struct tdx_flush_vp_arg arg;
+       struct vcpu_tdx *tdx, *tmp;
+       unsigned long flags;
+
+       local_irq_save(flags);
+       /* Safe variant needed as tdx_disassociate_vp() deletes the entry. */
+       list_for_each_entry_safe(tdx, tmp, tdvcpus, cpu_list) {
+               arg.vcpu = &tdx->vcpu;
+               tdx_flush_vp(&arg);
+       }
+       local_irq_restore(flags);
+}
+
 #define TDX_SEAMCALL_RETRIES 10000
 
 static void smp_func_do_phymem_cache_wb(void *unused)
        bool packages_allocated, targets_allocated;
        struct kvm_tdx *kvm_tdx = to_kvm_tdx(kvm);
        cpumask_var_t packages, targets;
-       u64 err;
+       struct kvm_vcpu *vcpu;
+       unsigned long j;
        int i;
+       u64 err;
 
        if (!is_hkid_assigned(kvm_tdx))
                return;
 
-       /* KeyID has been allocated but guest is not yet configured */
-       if (!kvm_tdx->td.tdr_page) {
-               tdx_hkid_free(kvm_tdx);
-               return;
-       }
-
        packages_allocated = zalloc_cpumask_var(&packages, GFP_KERNEL);
        targets_allocated = zalloc_cpumask_var(&targets, GFP_KERNEL);
        cpus_read_lock();
 
+       kvm_for_each_vcpu(j, vcpu, kvm)
+               tdx_flush_vp_on_cpu(vcpu);
+
        /*
         * TDH.PHYMEM.CACHE.WB tries to acquire the TDX module global lock
         * and can fail with TDX_OPERAND_BUSY when it fails to get the lock.
         * After the above flushing vps, there should be no more vCPU
         * associations, as all vCPU fds have been released at this stage.
         */
+       err = tdh_mng_vpflushdone(&kvm_tdx->td);
+       if (err == TDX_FLUSHVP_NOT_DONE)
+               goto out;
+       if (KVM_BUG_ON(err, kvm)) {
+               pr_tdx_error(TDH_MNG_VPFLUSHDONE, err);
+               pr_err("tdh_mng_vpflushdone() failed. HKID %d is leaked.\n",
+                      kvm_tdx->hkid);
+               goto out;
+       }
+
        for_each_online_cpu(i) {
                if (packages_allocated &&
                    cpumask_test_and_set_cpu(topology_physical_package_id(i),
                tdx_hkid_free(kvm_tdx);
        }
 
+out:
        mutex_unlock(&tdx_lock);
        cpus_read_unlock();
        free_cpumask_var(targets);
        return 0;
 }
 
+void tdx_vcpu_load(struct kvm_vcpu *vcpu, int cpu)
+{
+       struct vcpu_tdx *tdx = to_tdx(vcpu);
+
+       if (vcpu->cpu == cpu || !is_hkid_assigned(to_kvm_tdx(vcpu->kvm)))
+               return;
+
+       tdx_flush_vp_on_cpu(vcpu);
+
+       KVM_BUG_ON(cpu != raw_smp_processor_id(), vcpu->kvm);
+       local_irq_disable();
+       /*
+        * Pairs with the smp_wmb() in tdx_disassociate_vp() to ensure
+        * vcpu->cpu is read before tdx->cpu_list.
+        */
+       smp_rmb();
+
+       list_add(&tdx->cpu_list, &per_cpu(associated_tdvcpus, cpu));
+       local_irq_enable();
+}
+
 void tdx_vcpu_free(struct kvm_vcpu *vcpu)
 {
        struct kvm_tdx *kvm_tdx = to_kvm_tdx(vcpu->kvm);
 
 int __init tdx_bringup(void)
 {
-       int r;
+       int r, i;
+
+       /* tdx_disable_virtualization_cpu() uses associated_tdvcpus. */
+       for_each_possible_cpu(i)
+               INIT_LIST_HEAD(&per_cpu(associated_tdvcpus, i));
 
        if (!enable_tdx)
                return 0;
 
 void vmx_setup_mce(struct kvm_vcpu *vcpu);
 
 #ifdef CONFIG_KVM_INTEL_TDX
+void tdx_disable_virtualization_cpu(void);
 int tdx_vm_init(struct kvm *kvm);
 void tdx_mmu_release_hkid(struct kvm *kvm);
 void tdx_vm_destroy(struct kvm *kvm);
 
 int tdx_vcpu_create(struct kvm_vcpu *vcpu);
 void tdx_vcpu_free(struct kvm_vcpu *vcpu);
+void tdx_vcpu_load(struct kvm_vcpu *vcpu, int cpu);
 
 int tdx_vcpu_ioctl(struct kvm_vcpu *vcpu, void __user *argp);
 
 void tdx_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa, int root_level);
 int tdx_gmem_private_max_mapping_level(struct kvm *kvm, kvm_pfn_t pfn);
 #else
+static inline void tdx_disable_virtualization_cpu(void) {}
 static inline int tdx_vm_init(struct kvm *kvm) { return -EOPNOTSUPP; }
 static inline void tdx_mmu_release_hkid(struct kvm *kvm) {}
 static inline void tdx_vm_destroy(struct kvm *kvm) {}
 
 static inline int tdx_vcpu_create(struct kvm_vcpu *vcpu) { return -EOPNOTSUPP; }
 static inline void tdx_vcpu_free(struct kvm_vcpu *vcpu) {}
+static inline void tdx_vcpu_load(struct kvm_vcpu *vcpu, int cpu) {}
 
 static inline int tdx_vcpu_ioctl(struct kvm_vcpu *vcpu, void __user *argp) { return -EOPNOTSUPP; }