#include <linux/printk.h>
 #include <linux/workqueue.h>
 #include <linux/sched.h>
+#include <linux/capability.h>
+
 #include <net/sch_generic.h>
 
 #include <asm/cacheflush.h>
 #define BPF_REG_X      BPF_REG_7
 #define BPF_REG_TMP    BPF_REG_8
 
+/* Kernel hidden auxiliary/helper register for hardening step.
+ * Only used by eBPF JITs. It's nothing more than a temporary
+ * register that JITs use internally, only that here it's part
+ * of eBPF instructions that have been rewritten for blinding
+ * constants. See JIT pre-step in bpf_jit_blind_constants().
+ */
+#define BPF_REG_AX             MAX_BPF_REG
+#define MAX_BPF_JIT_REG                (MAX_BPF_REG + 1)
+
 /* BPF program can access up to 512 bytes of stack space. */
 #define MAX_BPF_STACK  512
 
 
 #ifdef CONFIG_BPF_JIT
 extern int bpf_jit_enable;
+extern int bpf_jit_harden;
 
 typedef void (*bpf_jit_fill_hole_t)(void *area, unsigned int size);
 
 void bpf_jit_compile(struct bpf_prog *fp);
 void bpf_jit_free(struct bpf_prog *fp);
 
+struct bpf_prog *bpf_jit_blind_constants(struct bpf_prog *fp);
+void bpf_jit_prog_release_other(struct bpf_prog *fp, struct bpf_prog *fp_other);
+
 /* Hex-dump the emitted JIT image for debugging (flen = filter length,
  * proglen = image length in bytes, pass = JIT pass number).
  * NOTE(review): this hunk looks truncated -- upstream also pr_err()s a
  * summary line and guards the dump with `if (image)`; confirm against
  * the full file before relying on this excerpt.
  */
 static inline void bpf_jit_dump(unsigned int flen, unsigned int proglen,
                                u32 pass, void *image)
 {
                print_hex_dump(KERN_ERR, "JIT code: ", DUMP_PREFIX_OFFSET,
                               16, 1, image, proglen, false);
 }
+
+/* Compile-time answer: does this arch build an eBPF JIT
+ * (CONFIG_HAVE_EBPF_JIT) as opposed to a classic-BPF-only JIT?
+ * Blinding rewrites eBPF insns, so it is only meaningful there.
+ */
+static inline bool bpf_jit_is_ebpf(void)
+{
+# ifdef CONFIG_HAVE_EBPF_JIT
+       return true;
+# else
+       return false;
+# endif
+}
+
+/* Should constant blinding be applied for the current task?
+ * Requires an eBPF-capable JIT, the JIT being enabled, and the
+ * bpf_jit_harden sysctl being set. A harden value of 1 exempts
+ * privileged (CAP_SYS_ADMIN) callers; any other non-zero value
+ * presumably blinds unconditionally -- confirm sysctl documentation.
+ */
+static inline bool bpf_jit_blinding_enabled(void)
+{
+       /* These are the prerequisites, should someone ever have the
+        * idea to call blinding outside of them, we make sure to
+        * bail out.
+        */
+       if (!bpf_jit_is_ebpf())
+               return false;
+       if (!bpf_jit_enable)
+               return false;
+       if (!bpf_jit_harden)
+               return false;
+       if (bpf_jit_harden == 1 && capable(CAP_SYS_ADMIN))
+               return false;
+
+       return true;
+}
 #else
 static inline void bpf_jit_compile(struct bpf_prog *fp)
 {
 
 {
        module_memfree(hdr);
 }
+
+int bpf_jit_harden __read_mostly;
+
+/* Rewrite one insn that embeds an attacker-chosen immediate so the
+ * constant never appears verbatim in the JIT image: emit
+ * "mov AX, imm ^ rnd; xor AX, rnd" into the hidden BPF_REG_AX and
+ * replace the original insn with its register-source form using AX.
+ *
+ * @from:    insn to (potentially) rewrite
+ * @aux:     caller-saved copy of the original ld64 insn pair; only
+ *           consulted for the BPF_LD | BPF_IMM | BPF_DW cases
+ * @to_buff: output buffer for the rewritten sequence
+ *
+ * Returns the number of insns written to @to_buff, or 0 when @from
+ * needs no rewriting (insns without an immediate fall through the
+ * switch untouched).
+ */
+static int bpf_jit_blind_insn(const struct bpf_insn *from,
+                             const struct bpf_insn *aux,
+                             struct bpf_insn *to_buff)
+{
+       struct bpf_insn *to = to_buff;
+       u32 imm_rnd = prandom_u32();
+       s16 off;
+
+       /* AX must be the one register past the normal eBPF set. */
+       BUILD_BUG_ON(BPF_REG_AX  + 1 != MAX_BPF_JIT_REG);
+       BUILD_BUG_ON(MAX_BPF_REG + 1 != MAX_BPF_JIT_REG);
+
+       /* mov of constant 0 leaks nothing; turn it into a self-xor of
+        * the destination instead of spending a random cookie on it.
+        */
+       if (from->imm == 0 &&
+           (from->code == (BPF_ALU   | BPF_MOV | BPF_K) ||
+            from->code == (BPF_ALU64 | BPF_MOV | BPF_K))) {
+               *to++ = BPF_ALU64_REG(BPF_XOR, from->dst_reg, from->dst_reg);
+               goto out;
+       }
+
+       switch (from->code) {
+       /* 32-bit ALU with immediate -> recover imm in AX, use reg form. */
+       case BPF_ALU | BPF_ADD | BPF_K:
+       case BPF_ALU | BPF_SUB | BPF_K:
+       case BPF_ALU | BPF_AND | BPF_K:
+       case BPF_ALU | BPF_OR  | BPF_K:
+       case BPF_ALU | BPF_XOR | BPF_K:
+       case BPF_ALU | BPF_MUL | BPF_K:
+       case BPF_ALU | BPF_MOV | BPF_K:
+       case BPF_ALU | BPF_DIV | BPF_K:
+       case BPF_ALU | BPF_MOD | BPF_K:
+               *to++ = BPF_ALU32_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ from->imm);
+               *to++ = BPF_ALU32_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_ALU32_REG(from->code, from->dst_reg, BPF_REG_AX);
+               break;
+
+       /* 64-bit ALU with immediate, same pattern in 64-bit width. */
+       case BPF_ALU64 | BPF_ADD | BPF_K:
+       case BPF_ALU64 | BPF_SUB | BPF_K:
+       case BPF_ALU64 | BPF_AND | BPF_K:
+       case BPF_ALU64 | BPF_OR  | BPF_K:
+       case BPF_ALU64 | BPF_XOR | BPF_K:
+       case BPF_ALU64 | BPF_MUL | BPF_K:
+       case BPF_ALU64 | BPF_MOV | BPF_K:
+       case BPF_ALU64 | BPF_DIV | BPF_K:
+       case BPF_ALU64 | BPF_MOD | BPF_K:
+               *to++ = BPF_ALU64_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ from->imm);
+               *to++ = BPF_ALU64_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_ALU64_REG(from->code, from->dst_reg, BPF_REG_AX);
+               break;
+
+       /* Conditional jumps comparing against an immediate. */
+       case BPF_JMP | BPF_JEQ  | BPF_K:
+       case BPF_JMP | BPF_JNE  | BPF_K:
+       case BPF_JMP | BPF_JGT  | BPF_K:
+       case BPF_JMP | BPF_JGE  | BPF_K:
+       case BPF_JMP | BPF_JSGT | BPF_K:
+       case BPF_JMP | BPF_JSGE | BPF_K:
+       case BPF_JMP | BPF_JSET | BPF_K:
+               /* Accommodate for extra offset in case of a backjump. */
+               off = from->off;
+               if (off < 0)
+                       off -= 2;
+               *to++ = BPF_ALU64_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ from->imm);
+               *to++ = BPF_ALU64_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_JMP_REG(from->code, from->dst_reg, BPF_REG_AX, off);
+               break;
+
+       /* LD_ABS: turn into LD_IND with the blinded offset in AX. */
+       case BPF_LD | BPF_ABS | BPF_W:
+       case BPF_LD | BPF_ABS | BPF_H:
+       case BPF_LD | BPF_ABS | BPF_B:
+               *to++ = BPF_ALU64_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ from->imm);
+               *to++ = BPF_ALU64_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_LD_IND(from->code, BPF_REG_AX, 0);
+               break;
+
+       /* LD_IND: fold src_reg into AX, then LD_IND off AX with imm 0. */
+       case BPF_LD | BPF_IND | BPF_W:
+       case BPF_LD | BPF_IND | BPF_H:
+       case BPF_LD | BPF_IND | BPF_B:
+               *to++ = BPF_ALU64_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ from->imm);
+               *to++ = BPF_ALU64_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_ALU32_REG(BPF_ADD, BPF_REG_AX, from->src_reg);
+               *to++ = BPF_LD_IND(from->code, BPF_REG_AX, 0);
+               break;
+
+       /* Part 1 of ld64: build the upper 32 bits from aux[1].imm and
+        * shift them into place in the destination register.
+        */
+       case BPF_LD | BPF_IMM | BPF_DW:
+               *to++ = BPF_ALU64_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ aux[1].imm);
+               *to++ = BPF_ALU64_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_ALU64_IMM(BPF_LSH, BPF_REG_AX, 32);
+               *to++ = BPF_ALU64_REG(BPF_MOV, aux[0].dst_reg, BPF_REG_AX);
+               break;
+       case 0: /* Part 2 of BPF_LD | BPF_IMM | BPF_DW. */
+               *to++ = BPF_ALU32_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ aux[0].imm);
+               *to++ = BPF_ALU32_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_ALU64_REG(BPF_OR,  aux[0].dst_reg, BPF_REG_AX);
+               break;
+
+       /* Store of immediate -> recover imm in AX, store via STX. */
+       case BPF_ST | BPF_MEM | BPF_DW:
+       case BPF_ST | BPF_MEM | BPF_W:
+       case BPF_ST | BPF_MEM | BPF_H:
+       case BPF_ST | BPF_MEM | BPF_B:
+               *to++ = BPF_ALU64_IMM(BPF_MOV, BPF_REG_AX, imm_rnd ^ from->imm);
+               *to++ = BPF_ALU64_IMM(BPF_XOR, BPF_REG_AX, imm_rnd);
+               *to++ = BPF_STX_MEM(from->code, from->dst_reg, BPF_REG_AX, from->off);
+               break;
+       }
+out:
+       return to - to_buff;
+}
+
+/* Allocate a same-sized writable duplicate of @fp_other to serve as
+ * scratch for blinding. The whole image (pages * PAGE_SIZE) is memcpy'd,
+ * so the clone shares fp_other->aux, including the aux->prog
+ * back-pointer -- the caller must repoint it when promoting the clone
+ * (see bpf_jit_prog_release_other()). Returns NULL on allocation
+ * failure.
+ */
+static struct bpf_prog *bpf_prog_clone_create(struct bpf_prog *fp_other,
+                                             gfp_t gfp_extra_flags)
+{
+       gfp_t gfp_flags = GFP_KERNEL | __GFP_HIGHMEM | __GFP_ZERO |
+                         gfp_extra_flags;
+       struct bpf_prog *fp;
+
+       fp = __vmalloc(fp_other->pages * PAGE_SIZE, gfp_flags, PAGE_KERNEL);
+       if (fp != NULL) {
+               kmemcheck_annotate_bitfield(fp, meta);
+
+               /* aux->prog still points to the fp_other one, so
+                * when promoting the clone to the real program,
+                * this still needs to be adapted.
+                */
+               memcpy(fp, fp_other, fp_other->pages * PAGE_SIZE);
+       }
+
+       return fp;
+}
+
+/* Free a clone without touching the shared aux, which the surviving
+ * program still owns.
+ */
+static void bpf_prog_clone_free(struct bpf_prog *fp)
+{
+       /* aux was stolen by the other clone, so we cannot free
+        * it from this path! It will be freed eventually by the
+        * other program on release.
+        *
+        * At this point, we don't need a deferred release since
+        * clone is guaranteed to not be locked.
+        */
+       fp->aux = NULL;
+       __bpf_prog_free(fp);
+}
+
+/* Keep @fp as the live program and dispose of @fp_other: fix the shared
+ * aux->prog back-pointer to the survivor, then free the loser as a
+ * clone (its aux is not freed).
+ */
+void bpf_jit_prog_release_other(struct bpf_prog *fp, struct bpf_prog *fp_other)
+{
+       /* We have to repoint aux->prog to self, as we don't
+        * know whether fp here is the clone or the original.
+        */
+       fp->aux->prog = fp;
+       bpf_prog_clone_free(fp_other);
+}
+
+/* JIT pre-step: return a constant-blinded clone of @prog, or @prog
+ * itself untouched when blinding is disabled for this caller. On
+ * allocation failure the clone is released (aux->prog repointed back
+ * to @prog) and ERR_PTR(-ENOMEM) is returned; the caller is then
+ * expected to release the winner via bpf_jit_prog_release_other().
+ * insn_buff comfortably fits the largest rewrite (currently 4 insns).
+ */
+struct bpf_prog *bpf_jit_blind_constants(struct bpf_prog *prog)
+{
+       struct bpf_insn insn_buff[16], aux[2];
+       struct bpf_prog *clone, *tmp;
+       int insn_delta, insn_cnt;
+       struct bpf_insn *insn;
+       int i, rewritten;
+
+       if (!bpf_jit_blinding_enabled())
+               return prog;
+
+       clone = bpf_prog_clone_create(prog, GFP_USER);
+       if (!clone)
+               return ERR_PTR(-ENOMEM);
+
+       insn_cnt = clone->len;
+       insn = clone->insnsi;
+
+       for (i = 0; i < insn_cnt; i++, insn++) {
+               /* We temporarily need to hold the original ld64 insn
+                * so that we can still access the first part in the
+                * second blinding run.
+                */
+               if (insn[0].code == (BPF_LD | BPF_IMM | BPF_DW) &&
+                   insn[1].code == 0)
+                       memcpy(aux, insn, sizeof(aux));
+
+               rewritten = bpf_jit_blind_insn(insn, aux, insn_buff);
+               if (!rewritten)
+                       continue;
+
+               tmp = bpf_patch_insn_single(clone, i, insn_buff, rewritten);
+               if (!tmp) {
+                       /* Patching may have repointed aux->prog during
+                        * realloc from the original one, so we need to
+                        * fix it up here on error.
+                        */
+                       bpf_jit_prog_release_other(prog, clone);
+                       return ERR_PTR(-ENOMEM);
+               }
+
+               clone = tmp;
+               insn_delta = rewritten - 1;
+
+               /* Walk new program and skip insns we just inserted. */
+               insn = clone->insnsi + i + insn_delta;
+               insn_cnt += insn_delta;
+               i        += insn_delta;
+       }
+
+       return clone;
+}
 #endif /* CONFIG_BPF_JIT */
 
 /* Base function for offset calculation. Needs to go into .text section,