* along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+#include <asm/kvm_asm.h>
+
 #include "hyp.h"
 
 static bool __hyp_text __fpsimd_enabled_nvhe(void)
        return __fpsimd_is_enabled()();
 }
 
+static void __hyp_text __activate_traps_vhe(void)
+{
+       u64 val;
+
+       val = read_sysreg(cpacr_el1);
+       val |= CPACR_EL1_TTA;
+       val &= ~CPACR_EL1_FPEN;
+       write_sysreg(val, cpacr_el1);
+
+       write_sysreg(__kvm_hyp_vector, vbar_el1);
+}
+
+static void __hyp_text __activate_traps_nvhe(void)
+{
+       u64 val;
+
+       val = CPTR_EL2_DEFAULT;
+       val |= CPTR_EL2_TTA | CPTR_EL2_TFP;
+       write_sysreg(val, cptr_el2);
+}
+
+static hyp_alternate_select(__activate_traps_arch,
+                           __activate_traps_nvhe, __activate_traps_vhe,
+                           ARM64_HAS_VIRT_HOST_EXTN);
+
 static void __hyp_text __activate_traps(struct kvm_vcpu *vcpu)
 {
        u64 val;
        write_sysreg(val, hcr_el2);
        /* Trap on AArch32 cp15 c15 accesses (EL1 or EL0) */
        write_sysreg(1 << 15, hstr_el2);
+       write_sysreg(vcpu->arch.mdcr_el2, mdcr_el2);
+       __activate_traps_arch()();
+}
 
-       val = CPTR_EL2_DEFAULT;
-       val |= CPTR_EL2_TTA | CPTR_EL2_TFP;
-       write_sysreg(val, cptr_el2);
+static void __hyp_text __deactivate_traps_vhe(void)
+{
+       extern char vectors[];  /* kernel exception vectors */
 
-       write_sysreg(vcpu->arch.mdcr_el2, mdcr_el2);
+       write_sysreg(HCR_HOST_VHE_FLAGS, hcr_el2);
+       write_sysreg(CPACR_EL1_FPEN, cpacr_el1);
+       write_sysreg(vectors, vbar_el1);
 }
 
-static void __hyp_text __deactivate_traps(struct kvm_vcpu *vcpu)
+static void __hyp_text __deactivate_traps_nvhe(void)
 {
        write_sysreg(HCR_RW, hcr_el2);
+       write_sysreg(CPTR_EL2_DEFAULT, cptr_el2);
+}
+
+static hyp_alternate_select(__deactivate_traps_arch,
+                           __deactivate_traps_nvhe, __deactivate_traps_vhe,
+                           ARM64_HAS_VIRT_HOST_EXTN);
+
+static void __hyp_text __deactivate_traps(struct kvm_vcpu *vcpu)
+{
+       __deactivate_traps_arch()();
        write_sysreg(0, hstr_el2);
        write_sysreg(read_sysreg(mdcr_el2) & MDCR_EL2_HPMN_MASK, mdcr_el2);
-       write_sysreg(CPTR_EL2_DEFAULT, cptr_el2);
 }
 
 static void __hyp_text __activate_vm(struct kvm_vcpu *vcpu)