#define SEEN_FUNC      16      /* calls C functions */
 #define SEEN_TAIL_CALL 32      /* code uses tail calls */
 #define SEEN_SKB_CHANGE        64      /* code changes skb data */
+#define SEEN_REG_AX    128     /* code uses constant blinding */
 #define SEEN_STACK     (SEEN_FUNC | SEEN_MEM | SEEN_SKB)
 
 /*
  * s390 registers
  */
-#define REG_W0         (__MAX_BPF_REG+0)       /* Work register 1 (even) */
-#define REG_W1         (__MAX_BPF_REG+1)       /* Work register 2 (odd) */
-#define REG_SKB_DATA   (__MAX_BPF_REG+2)       /* SKB data register */
-#define REG_L          (__MAX_BPF_REG+3)       /* Literal pool register */
-#define REG_15         (__MAX_BPF_REG+4)       /* Register 15 */
+#define REG_W0         (MAX_BPF_JIT_REG + 0)   /* Work register 1 (even) */
+#define REG_W1         (MAX_BPF_JIT_REG + 1)   /* Work register 2 (odd) */
+#define REG_SKB_DATA   (MAX_BPF_JIT_REG + 2)   /* SKB data register */
+#define REG_L          (MAX_BPF_JIT_REG + 3)   /* Literal pool register */
+#define REG_15         (MAX_BPF_JIT_REG + 4)   /* Register 15 */
 #define REG_0          REG_W0                  /* Register 0 */
 #define REG_1          REG_W1                  /* Register 1 */
 #define REG_2          BPF_REG_1               /* Register 2 */
        [BPF_REG_9]     = 10,
        /* BPF stack pointer */
        [BPF_REG_FP]    = 13,
+       /* Register for blinding (shared with REG_SKB_DATA) */
+       [BPF_REG_AX]    = 12,
        /* SKB data pointer */
        [REG_SKB_DATA]  = 12,
        /* Work registers for s390x backend */
 /*
  * For SKB access %b1 contains the SKB pointer. For "bpf_jit.S"
  * we store the SKB header length on the stack and the SKB data
- * pointer in REG_SKB_DATA.
+ * pointer in REG_SKB_DATA if BPF_REG_AX is not used.
  */
 static void emit_load_skb_data_hlen(struct bpf_jit *jit)
 {
                   offsetof(struct sk_buff, data_len));
        /* stg %w1,ST_OFF_HLEN(%r0,%r15) */
        EMIT6_DISP_LH(0xe3000000, 0x0024, REG_W1, REG_0, REG_15, STK_OFF_HLEN);
-       /* lg %skb_data,data_off(%b1) */
-       EMIT6_DISP_LH(0xe3000000, 0x0004, REG_SKB_DATA, REG_0,
-                     BPF_REG_1, offsetof(struct sk_buff, data));
+       if (!(jit->seen & SEEN_REG_AX))
+               /* lg %skb_data,data_off(%b1) */
+               EMIT6_DISP_LH(0xe3000000, 0x0004, REG_SKB_DATA, REG_0,
+                             BPF_REG_1, offsetof(struct sk_buff, data));
 }
 
 /*
        s32 imm = insn->imm;
        s16 off = insn->off;
 
+       if (dst_reg == BPF_REG_AX || src_reg == BPF_REG_AX)
+               jit->seen |= SEEN_REG_AX;
        switch (insn->code) {
        /*
         * BPF_MOV
                /*
                 * Implicit input:
                 *  BPF_REG_6    (R7) : skb pointer
-                *  REG_SKB_DATA (R12): skb data pointer
+                *  REG_SKB_DATA (R12): skb data pointer (if no BPF_REG_AX)
                 *
                 * Calculated input:
                 *  BPF_REG_2    (R3) : offset of byte(s) to fetch in skb
                        /* agfr %b2,%src (%src is s32 here) */
                        EMIT4(0xb9180000, BPF_REG_2, src_reg);
 
+               /* Reload REG_SKB_DATA if BPF_REG_AX is used */
+               if (jit->seen & SEEN_REG_AX)
+                       /* lg %skb_data,data_off(%b6) */
+                       EMIT6_DISP_LH(0xe3000000, 0x0004, REG_SKB_DATA, REG_0,
+                                     BPF_REG_6, offsetof(struct sk_buff, data));
                /* basr %b5,%w1 (%b5 is call saved) */
                EMIT2(0x0d00, BPF_REG_5, REG_W1);
 
  */
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *fp)
 {
+       struct bpf_prog *tmp, *orig_fp = fp;
        struct bpf_binary_header *header;
+       bool tmp_blinded = false;
        struct bpf_jit jit;
        int pass;
 
        if (!bpf_jit_enable)
-               return fp;
+               return orig_fp;
+
+       tmp = bpf_jit_blind_constants(fp);
+       /*
+        * If blinding was requested and we failed during blinding,
+        * we must fall back to the interpreter.
+        */
+       if (IS_ERR(tmp))
+               return orig_fp;
+       if (tmp != fp) {
+               tmp_blinded = true;
+               fp = tmp;
+       }
 
        memset(&jit, 0, sizeof(jit));
        jit.addrs = kcalloc(fp->len + 1, sizeof(*jit.addrs), GFP_KERNEL);
-       if (jit.addrs == NULL)
-               return fp;
+       if (jit.addrs == NULL) {
+               fp = orig_fp;
+               goto out;
+       }
        /*
         * Three initial passes:
         *   - 1/2: Determine clobbered registers
         *   - 3:   Calculate program size and addrs arrray
         */
        for (pass = 1; pass <= 3; pass++) {
-               if (bpf_jit_prog(&jit, fp))
+               if (bpf_jit_prog(&jit, fp)) {
+                       fp = orig_fp;
                        goto free_addrs;
+               }
        }
        /*
         * Final pass: Allocate and generate program
         */
-       if (jit.size >= BPF_SIZE_MAX)
+       if (jit.size >= BPF_SIZE_MAX) {
+               fp = orig_fp;
                goto free_addrs;
+       }
        header = bpf_jit_binary_alloc(jit.size, &jit.prg_buf, 2, jit_fill_hole);
-       if (!header)
+       if (!header) {
+               fp = orig_fp;
                goto free_addrs;
-       if (bpf_jit_prog(&jit, fp))
+       }
+       if (bpf_jit_prog(&jit, fp)) {
+               fp = orig_fp;
                goto free_addrs;
+       }
        if (bpf_jit_enable > 1) {
                bpf_jit_dump(fp->len, jit.size, pass, jit.prg_buf);
                if (jit.prg_buf)
        }
 free_addrs:
        kfree(jit.addrs);
+out:
+       if (tmp_blinded)
+               bpf_jit_prog_release_other(fp, fp == orig_fp ?
+                                          tmp : orig_fp);
        return fp;
 }