#include "posted_intr.h"
 #include "tdx.h"
 
+static __init int vt_hardware_setup(void)
+{
+       int ret;
+
+       ret = vmx_hardware_setup();
+       if (ret)
+               return ret;
+
+       /*
+        * Update vt_x86_ops::vm_size here so it is ready before
+        * kvm_ops_update() is called in kvm_x86_vendor_init().
+        *
+        * Note, the actual bringing up of TDX must be done after
+        * kvm_ops_update() because enabling TDX requires enabling
+        * hardware virtualization first, i.e., all online CPUs must
+        * be in post-VMXON state.  This means the @vm_size here
+        * may be updated to TDX's size but TDX may fail to enable
+        * at later time.
+        *
+        * The VMX/VT code could update kvm_x86_ops::vm_size again
+        * after bringing up TDX, but this would require exporting
+        * either kvm_x86_ops or kvm_ops_update() from the base KVM
+        * module, which looks overkill.  Anyway, the worst case here
+        * is KVM may allocate couple of more bytes than needed for
+        * each VM.
+        */
+       if (enable_tdx)
+               vt_x86_ops.vm_size = max_t(unsigned int, vt_x86_ops.vm_size,
+                               sizeof(struct kvm_tdx));
+
+       return 0;
+}
+
 #define VMX_REQUIRED_APICV_INHIBITS                            \
        (BIT(APICV_INHIBIT_REASON_DISABLED) |                   \
         BIT(APICV_INHIBIT_REASON_ABSENT) |                     \
 };
 
 struct kvm_x86_init_ops vt_init_ops __initdata = {
-       .hardware_setup = vmx_hardware_setup,
+       .hardware_setup = vt_hardware_setup,
        .handle_intel_pt_intr = NULL,
 
        .runtime_ops = &vt_x86_ops,
 
 static int __init vt_init(void)
 {
+       unsigned vcpu_size, vcpu_align;
        int r;
 
        r = vmx_init();
        if (r)
                goto err_tdx_bringup;
 
+       /*
+        * TDX and VMX have different vCPU structures.  Calculate the
+        * maximum size/align so that kvm_init() can use the larger
+        * values to create the kmem_vcpu_cache.
+        */
+       vcpu_size = sizeof(struct vcpu_vmx);
+       vcpu_align = __alignof__(struct vcpu_vmx);
+       if (enable_tdx) {
+               vcpu_size = max_t(unsigned, vcpu_size,
+                               sizeof(struct vcpu_tdx));
+               vcpu_align = max_t(unsigned, vcpu_align,
+                               __alignof__(struct vcpu_tdx));
+       }
+
        /*
         * Common KVM initialization _must_ come last, after this, /dev/kvm is
         * exposed to userspace!
         */
-       r = kvm_init(sizeof(struct vcpu_vmx), __alignof__(struct vcpu_vmx),
-                    THIS_MODULE);
+       r = kvm_init(vcpu_size, vcpu_align, THIS_MODULE);
        if (r)
                goto err_kvm_init;
 
 
 #include "capabilities.h"
 #include "tdx.h"
 
+#pragma GCC poison to_vmx
+
 #undef pr_fmt
 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
 
-static bool enable_tdx __ro_after_init;
+bool enable_tdx __ro_after_init;
 module_param_named(tdx, enable_tdx, bool, 0444);
 
 static enum cpuhp_state tdx_cpuhp_state;
 
 static const struct tdx_sys_info *tdx_sysinfo;
 
+static __always_inline struct kvm_tdx *to_kvm_tdx(struct kvm *kvm)
+{
+       return container_of(kvm, struct kvm_tdx, kvm);
+}
+
+static __always_inline struct vcpu_tdx *to_tdx(struct kvm_vcpu *vcpu)
+{
+       return container_of(vcpu, struct vcpu_tdx, vcpu);
+}
+
 static int tdx_online_cpu(unsigned int cpu)
 {
        unsigned long flags;
 
 #ifdef CONFIG_KVM_INTEL_TDX
 int tdx_bringup(void);
 void tdx_cleanup(void);
+
+extern bool enable_tdx;
+
+struct kvm_tdx {
+       struct kvm kvm;
+       /* TDX specific members follow. */
+};
+
+struct vcpu_tdx {
+       struct kvm_vcpu vcpu;
+       /* TDX specific members follow. */
+};
+
+static inline bool is_td(struct kvm *kvm)
+{
+       return kvm->arch.vm_type == KVM_X86_TDX_VM;
+}
+
+static inline bool is_td_vcpu(struct kvm_vcpu *vcpu)
+{
+       return is_td(vcpu->kvm);
+}
+
 #else
 static inline int tdx_bringup(void) { return 0; }
 static inline void tdx_cleanup(void) {}
+
+#define enable_tdx     0
+
+struct kvm_tdx {
+       struct kvm kvm;
+};
+
+struct vcpu_tdx {
+       struct kvm_vcpu vcpu;
+};
+
+static inline bool is_td(struct kvm *kvm) { return false; }
+static inline bool is_td_vcpu(struct kvm_vcpu *vcpu) { return false; }
+
 #endif
 
 #endif