 #include <linux/linkage.h>
 #include <linux/arm-smccc.h>
 
+#include <asm/alternative.h>
+
 .macro hyp_ventry
        .align 7
 1:     .rept 27
        nop
        .endr
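+/*
+ * The 27 nops above plus the 5 instructions that follow add up to 32
+ * instructions (128 bytes), exactly filling the .align 7 vector slot.
+ */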
+/*
+ * The default sequence is to branch directly to the KVM vectors,
+ * using the computed offset. This applies to VHE as well as
+ * !ARM64_HARDEN_EL2_VECTORS.
+ *
+ * For ARM64_HARDEN_EL2_VECTORS configurations, this gets replaced
+ * with:
+ *
+ * stp  x0, x1, [sp, #-16]!
+ * movz x0, #(addr & 0xffff)
+ * movk x0, #((addr >> 16) & 0xffff), lsl #16
+ * movk x0, #((addr >> 32) & 0xffff), lsl #32
+ * br   x0
+ *
+ * Where addr = kern_hyp_va(__kvm_hyp_vector) + vector-offset + 4.
+ * See kvm_patch_vector_branch for details.
+ */
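+/*
+ * The 5 instructions bracketed by alternative_cb/alternative_cb_end are
+ * the ones rewritten by the callback at patching time, hence the
+ * BUG_ON(nr_inst != 5) in kvm_patch_vector_branch.
+ */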
+alternative_cb kvm_patch_vector_branch
        b       __kvm_hyp_vector + (1b - 0b)
        nop
        nop
        nop
        nop
+alternative_cb_end
 .endm
 
 .macro generate_vectors
 
                updptr[i] = cpu_to_le32(insn);
        }
 }
+
+void kvm_patch_vector_branch(struct alt_instr *alt,
+                            __le32 *origptr, __le32 *updptr, int nr_inst)
+{
+       u64 addr;
+       u32 insn;
+
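+       /*
+        * The alternatives framework hands us the 5-instruction region
+        * bracketed by alternative_cb/alternative_cb_end in hyp_ventry:
+        * origptr points at the original instructions, updptr at the
+        * buffer receiving the replacement sequence generated below.
+        */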
+       BUG_ON(nr_inst != 5);
+
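+       /*
+        * Only !VHE with ARM64_HARDEN_EL2_VECTORS needs the far branch;
+        * everybody else keeps the default direct branch to
+        * __kvm_hyp_vector. The capability being set on a VHE system
+        * would be unexpected, hence the WARN.
+        */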
+       if (has_vhe() || !cpus_have_const_cap(ARM64_HARDEN_EL2_VECTORS)) {
+               WARN_ON_ONCE(cpus_have_const_cap(ARM64_HARDEN_EL2_VECTORS));
+               return;
+       }
+
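+       /*
+        * va_mask, tag_val and tag_lsb are set up by compute_layout();
+        * make sure this has happened before they are used below.
+        */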
+       if (!va_mask)
+               compute_layout();
+
+       /*
+        * Compute HYP VA by using the same computation as kern_hyp_va()
+        */
+       addr = (uintptr_t)kvm_ksym_ref(__kvm_hyp_vector);
+       addr &= va_mask;
+       addr |= tag_val << tag_lsb;
+
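+       /*
+        * Both the hardened vector slots and __kvm_hyp_vector are laid
+        * out as 16 entries of 128 bytes, so bits [10:7] of an address
+        * inside either table identify the vector entry.
+        */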
+       /* Use PC[10:7] to branch to the same vector in KVM */
+       addr |= ((u64)origptr & GENMASK_ULL(10, 7));
+
+       /*
+        * Branch to the second instruction in the vectors in order to
+        * avoid the initial store on the stack (which we already
+        * perform in the hardening vectors).
+        */
+       addr += AARCH64_INSN_SIZE;
+
+       /* stp x0, x1, [sp, #-16]! */
+       insn = aarch64_insn_gen_load_store_pair(AARCH64_INSN_REG_0,
+                                               AARCH64_INSN_REG_1,
+                                               AARCH64_INSN_REG_SP,
+                                               -16,
+                                               AARCH64_INSN_VARIANT_64BIT,
+                                               AARCH64_INSN_LDST_STORE_PAIR_PRE_INDEX);
+       *updptr++ = cpu_to_le32(insn);
+
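+       /*
+        * The movz/movk sequence below only patches in bits [47:0] of
+        * the target, leaving the top 16 bits as zero; this assumes the
+        * HYP VA fits in 48 bits.
+        */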
+       /* movz x0, #(addr & 0xffff) */
+       insn = aarch64_insn_gen_movewide(AARCH64_INSN_REG_0,
+                                        (u16)addr,
+                                        0,
+                                        AARCH64_INSN_VARIANT_64BIT,
+                                        AARCH64_INSN_MOVEWIDE_ZERO);
+       *updptr++ = cpu_to_le32(insn);
+
+       /* movk x0, #((addr >> 16) & 0xffff), lsl #16 */
+       insn = aarch64_insn_gen_movewide(AARCH64_INSN_REG_0,
+                                        (u16)(addr >> 16),
+                                        16,
+                                        AARCH64_INSN_VARIANT_64BIT,
+                                        AARCH64_INSN_MOVEWIDE_KEEP);
+       *updptr++ = cpu_to_le32(insn);
+
+       /* movk x0, #((addr >> 32) & 0xffff), lsl #32 */
+       insn = aarch64_insn_gen_movewide(AARCH64_INSN_REG_0,
+                                        (u16)(addr >> 32),
+                                        32,
+                                        AARCH64_INSN_VARIANT_64BIT,
+                                        AARCH64_INSN_MOVEWIDE_KEEP);
+       *updptr++ = cpu_to_le32(insn);
+
+       /* br x0 */
+       insn = aarch64_insn_gen_branch_reg(AARCH64_INSN_REG_0,
+                                          AARCH64_INSN_BRANCH_NOLINK);
+       *updptr++ = cpu_to_le32(insn);
+}