void kvm_emulate_nested_eret(struct kvm_vcpu *vcpu)
 {
-       u64 spsr, elr, mode;
-       bool direct_eret;
+       u64 spsr, elr;
 
        /*
         * Forward this trap to the virtual EL2 if the virtual
        if (forward_traps(vcpu, HCR_NV))
                return;
 
-       /*
-        * Going through the whole put/load motions is a waste of time
-        * if this is a VHE guest hypervisor returning to its own
-        * userspace, or the hypervisor performing a local exception
-        * return. No need to save/restore registers, no need to
-        * switch S2 MMU. Just do the canonical ERET.
-        */
-       spsr = vcpu_read_sys_reg(vcpu, SPSR_EL2);
-       spsr = kvm_check_illegal_exception_return(vcpu, spsr);
-
-       mode = spsr & (PSR_MODE_MASK | PSR_MODE32_BIT);
-
-       direct_eret  = (mode == PSR_MODE_EL0t &&
-                       vcpu_el2_e2h_is_set(vcpu) &&
-                       vcpu_el2_tge_is_set(vcpu));
-       direct_eret |= (mode == PSR_MODE_EL2h || mode == PSR_MODE_EL2t);
-
-       if (direct_eret) {
-               *vcpu_pc(vcpu) = vcpu_read_sys_reg(vcpu, ELR_EL2);
-               *vcpu_cpsr(vcpu) = spsr;
-               trace_kvm_nested_eret(vcpu, *vcpu_pc(vcpu), spsr);
-               return;
-       }
-
        preempt_disable();
        kvm_arch_vcpu_put(vcpu);
 
+       spsr = __vcpu_sys_reg(vcpu, SPSR_EL2);
+       spsr = kvm_check_illegal_exception_return(vcpu, spsr);
        elr = __vcpu_sys_reg(vcpu, ELR_EL2);
 
        trace_kvm_nested_eret(vcpu, elr, spsr);
 
        __vcpu_put_switch_sysregs(vcpu);
 }
 
+static bool kvm_hyp_handle_eret(struct kvm_vcpu *vcpu, u64 *exit_code)
+{
+       u64 spsr, mode;
+
+       /*
+        * Going through the whole put/load motions is a waste of time
+        * if this is a VHE guest hypervisor returning to its own
+        * userspace, or the hypervisor performing a local exception
+        * return. No need to save/restore registers, no need to
+        * switch S2 MMU. Just do the canonical ERET.
+        *
+        * Unless the trap has to be forwarded further down the line,
+        * of course...
+        */
+       if (__vcpu_sys_reg(vcpu, HCR_EL2) & HCR_NV)
+               return false;
+
+       spsr = read_sysreg_el1(SYS_SPSR);
+       mode = spsr & (PSR_MODE_MASK | PSR_MODE32_BIT);
+
+       switch (mode) {
+       case PSR_MODE_EL0t:
+               if (!(vcpu_el2_e2h_is_set(vcpu) && vcpu_el2_tge_is_set(vcpu)))
+                       return false;
+               break;
+       case PSR_MODE_EL2t:
+               mode = PSR_MODE_EL1t;
+               break;
+       case PSR_MODE_EL2h:
+               mode = PSR_MODE_EL1h;
+               break;
+       default:
+               return false;
+       }
+
+       spsr = (spsr & ~(PSR_MODE_MASK | PSR_MODE32_BIT)) | mode;
+
+       write_sysreg_el2(spsr, SYS_SPSR);
+       write_sysreg_el2(read_sysreg_el1(SYS_ELR), SYS_ELR);
+
+       return true;
+}
+
 static const exit_handler_fn hyp_exit_handlers[] = {
        [0 ... ESR_ELx_EC_MAX]          = NULL,
        [ESR_ELx_EC_CP15_32]            = kvm_hyp_handle_cp15_32,
        [ESR_ELx_EC_DABT_LOW]           = kvm_hyp_handle_dabt_low,
        [ESR_ELx_EC_WATCHPT_LOW]        = kvm_hyp_handle_watchpt_low,
        [ESR_ELx_EC_PAC]                = kvm_hyp_handle_ptrauth,
+       [ESR_ELx_EC_ERET]               = kvm_hyp_handle_eret,
        [ESR_ELx_EC_MOPS]               = kvm_hyp_handle_mops,
 };