vmx_vcpu_reset(vcpu, init_event);
 }
 
+static void vt_flush_tlb_all(struct kvm_vcpu *vcpu)
+{
+       if (is_td_vcpu(vcpu)) {
+               tdx_flush_tlb_all(vcpu);
+               return;
+       }
+
+       vmx_flush_tlb_all(vcpu);
+}
+
+static void vt_flush_tlb_current(struct kvm_vcpu *vcpu)
+{
+       if (is_td_vcpu(vcpu)) {
+               tdx_flush_tlb_current(vcpu);
+               return;
+       }
+
+       vmx_flush_tlb_current(vcpu);
+}
+
+static void vt_flush_tlb_gva(struct kvm_vcpu *vcpu, gva_t addr)
+{
+       if (is_td_vcpu(vcpu))
+               return;
+
+       vmx_flush_tlb_gva(vcpu, addr);
+}
+
+static void vt_flush_tlb_guest(struct kvm_vcpu *vcpu)
+{
+       if (is_td_vcpu(vcpu))
+               return;
+
+       vmx_flush_tlb_guest(vcpu);
+}
+
 static void vt_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa,
                            int pgd_level)
 {
        .set_rflags = vmx_set_rflags,
        .get_if_flag = vmx_get_if_flag,
 
-       .flush_tlb_all = vmx_flush_tlb_all,
-       .flush_tlb_current = vmx_flush_tlb_current,
-       .flush_tlb_gva = vmx_flush_tlb_gva,
-       .flush_tlb_guest = vmx_flush_tlb_guest,
+       .flush_tlb_all = vt_flush_tlb_all,
+       .flush_tlb_current = vt_flush_tlb_current,
+       .flush_tlb_gva = vt_flush_tlb_gva,
+       .flush_tlb_guest = vt_flush_tlb_guest,
 
        .vcpu_pre_run = vmx_vcpu_pre_run,
        .vcpu_run = vmx_vcpu_run,
 
 #include "x86_ops.h"
 #include "lapic.h"
 #include "tdx.h"
+#include "vmx.h"
 #include "mmu/spte.h"
 
 #pragma GCC poison to_vmx
        td_vmcs_write64(to_tdx(vcpu), SHARED_EPT_POINTER, root_hpa);
 }
 
+/*
+ * Ensure shared and private EPTs to be flushed on all vCPUs.
+ * tdh_mem_track() is the only caller that increases TD epoch. An increase in
+ * the TD epoch (e.g., to value "N + 1") is successful only if no vCPUs are
+ * running in guest mode with the value "N - 1".
+ *
+ * A successful execution of tdh_mem_track() ensures that vCPUs can only run in
+ * guest mode with TD epoch value "N" if no TD exit occurs after the TD epoch
+ * being increased to "N + 1".
+ *
+ * Kicking off all vCPUs after that further results in no vCPUs can run in guest
+ * mode with TD epoch value "N", which unblocks the next tdh_mem_track() (e.g.
+ * to increase TD epoch to "N + 2").
+ *
+ * TDX module will flush EPT on the next TD enter and make vCPUs to run in
+ * guest mode with TD epoch value "N + 1".
+ *
+ * kvm_make_all_cpus_request() guarantees all vCPUs are out of guest mode by
+ * waiting empty IPI handler ack_kick().
+ *
+ * No action is required to the vCPUs being kicked off since the kicking off
+ * occurs certainly after TD epoch increment and before the next
+ * tdh_mem_track().
+ */
+static void __always_unused tdx_track(struct kvm *kvm)
+{
+       struct kvm_tdx *kvm_tdx = to_kvm_tdx(kvm);
+       u64 err;
+
+       /* If TD isn't finalized, it's before any vcpu running. */
+       if (unlikely(kvm_tdx->state != TD_STATE_RUNNABLE))
+               return;
+
+       lockdep_assert_held_write(&kvm->mmu_lock);
+
+       do {
+               err = tdh_mem_track(&kvm_tdx->td);
+       } while (unlikely((err & TDX_SEAMCALL_STATUS_MASK) == TDX_OPERAND_BUSY));
+
+       if (KVM_BUG_ON(err, kvm))
+               pr_tdx_error(TDH_MEM_TRACK, err);
+
+       kvm_make_all_cpus_request(kvm, KVM_REQ_OUTSIDE_GUEST_MODE);
+}
+
 static int tdx_get_capabilities(struct kvm_tdx_cmd *cmd)
 {
        const struct tdx_sys_info_td_conf *td_conf = &tdx_sysinfo->td_conf;
        return ret;
 }
 
+void tdx_flush_tlb_current(struct kvm_vcpu *vcpu)
+{
+       /*
+        * flush_tlb_current() is invoked when the first time for the vcpu to
+        * run or when root of shared EPT is invalidated.
+        * KVM only needs to flush shared EPT because the TDX module handles TLB
+        * invalidation for private EPT in tdh_vp_enter();
+        *
+        * A single context invalidation for shared EPT can be performed here.
+        * However, this single context invalidation requires the private EPTP
+        * rather than the shared EPTP to flush shared EPT, as shared EPT uses
+        * private EPTP as its ASID for TLB invalidation.
+        *
+        * To avoid reading back private EPTP, perform a global invalidation for
+        * shared EPT instead to keep this function simple.
+        */
+       ept_sync_global();
+}
+
+void tdx_flush_tlb_all(struct kvm_vcpu *vcpu)
+{
+       /*
+        * TDX has called tdx_track() in tdx_sept_remove_private_spte() to
+        * ensure that private EPT will be flushed on the next TD enter. No need
+        * to call tdx_track() here again even when this callback is a result of
+        * zapping private EPT.
+        *
+        * Due to the lack of the context to determine which EPT has been
+        * affected by zapping, invoke invept() directly here for both shared
+        * EPT and private EPT for simplicity, though it's not necessary for
+        * private EPT.
+        */
+       ept_sync_global();
+}
+
 int tdx_vm_ioctl(struct kvm *kvm, void __user *argp)
 {
        struct kvm_tdx_cmd tdx_cmd;
 
 
 int tdx_vcpu_ioctl(struct kvm_vcpu *vcpu, void __user *argp);
 
+void tdx_flush_tlb_current(struct kvm_vcpu *vcpu);
+void tdx_flush_tlb_all(struct kvm_vcpu *vcpu);
 void tdx_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa, int root_level);
 #else
 static inline int tdx_vm_init(struct kvm *kvm) { return -EOPNOTSUPP; }
 
 static inline int tdx_vcpu_ioctl(struct kvm_vcpu *vcpu, void __user *argp) { return -EOPNOTSUPP; }
 
+static inline void tdx_flush_tlb_current(struct kvm_vcpu *vcpu) {}
+static inline void tdx_flush_tlb_all(struct kvm_vcpu *vcpu) {}
 static inline void tdx_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa, int root_level) {}
 #endif