}
 
 #endif /* __ASSEMBLY__ */
+
+/*
+ * Below is used in the eBPF JIT compiler and emits the byte sequence
+ * for the following assembly:
+ *
+ * With retpolines configured:
+ *
+ *    callq do_rop
+ *  spec_trap:
+ *    pause
+ *    lfence
+ *    jmp spec_trap
+ *  do_rop:
+ *    mov %rax,(%rsp)
+ *    retq
+ *
+ * Without retpolines configured:
+ *
+ *    jmp *%rax
+ */
+#ifdef CONFIG_RETPOLINE
+# define RETPOLINE_RAX_BPF_JIT_SIZE    17
+# define RETPOLINE_RAX_BPF_JIT()                               \
+       EMIT1_off32(0xE8, 7);    /* callq do_rop */             \
+       /* spec_trap: */                                        \
+       EMIT2(0xF3, 0x90);       /* pause */                    \
+       EMIT3(0x0F, 0xAE, 0xE8); /* lfence */                   \
+       EMIT2(0xEB, 0xF9);       /* jmp spec_trap */            \
+       /* do_rop: */                                           \
+       EMIT4(0x48, 0x89, 0x04, 0x24); /* mov %rax,(%rsp) */    \
+       EMIT1(0xC3);             /* retq */
+#else
+# define RETPOLINE_RAX_BPF_JIT_SIZE    2
+# define RETPOLINE_RAX_BPF_JIT()                               \
+       EMIT2(0xFF, 0xE0);       /* jmp *%rax */
+#endif
+
 #endif /* _ASM_X86_NOSPEC_BRANCH_H_ */
 
 #include <linux/if_vlan.h>
 #include <asm/cacheflush.h>
 #include <asm/set_memory.h>
+#include <asm/nospec-branch.h>
 #include <linux/bpf.h>
 
 /*
        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
              offsetof(struct bpf_array, map.max_entries));
-#define OFFSET1 43 /* number of bytes to jump */
+#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
        EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
        label1 = cnt;
 
         */
        EMIT2_off32(0x8B, 0x85, 36);              /* mov eax, dword ptr [rbp + 36] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 32
+#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
        EMIT2(X86_JA, OFFSET2);                   /* ja out */
        label2 = cnt;
        EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
         *   goto out;
         */
        EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
-#define OFFSET3 10
+#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
        EMIT2(X86_JE, OFFSET3);                   /* je out */
        label3 = cnt;
 
         * rdi == ctx (1st arg)
         * rax == prog->bpf_func + prologue_size
         */
-       EMIT2(0xFF, 0xE0);                        /* jmp rax */
+       RETPOLINE_RAX_BPF_JIT();
 
        /* out: */
        BUILD_BUG_ON(cnt - label1 != OFFSET1);