*/
 static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
 {
+       s64 sval;
+
        if (__is_pointer_value(false, reg))
                return -1;
 
+       sval = (s64)val;
+
        switch (opcode) {
        case BPF_JEQ:
                if (tnum_is_const(reg->var_off))
                        return 0;
                break;
        case BPF_JSGT:
-               if (reg->smin_value > (s64)val)
+               if (reg->smin_value > sval)
                        return 1;
-               else if (reg->smax_value < (s64)val)
+               else if (reg->smax_value < sval)
                        return 0;
                break;
        case BPF_JLT:
                        return 0;
                break;
        case BPF_JSLT:
-               if (reg->smax_value < (s64)val)
+               if (reg->smax_value < sval)
                        return 1;
-               else if (reg->smin_value >= (s64)val)
+               else if (reg->smin_value >= sval)
                        return 0;
                break;
        case BPF_JGE:
                        return 0;
                break;
        case BPF_JSGE:
-               if (reg->smin_value >= (s64)val)
+               if (reg->smin_value >= sval)
                        return 1;
-               else if (reg->smax_value < (s64)val)
+               else if (reg->smax_value < sval)
                        return 0;
                break;
        case BPF_JLE:
                        return 0;
                break;
        case BPF_JSLE:
-               if (reg->smax_value <= (s64)val)
+               if (reg->smax_value <= sval)
                        return 1;
-               else if (reg->smin_value > (s64)val)
+               else if (reg->smin_value > sval)
                        return 0;
                break;
        }
                            struct bpf_reg_state *false_reg, u64 val,
                            u8 opcode)
 {
+       s64 sval;
+
        /* If the dst_reg is a pointer, we can't learn anything about its
         * variable offset from the compare (unless src_reg were a pointer into
         * the same object, but we don't bother with that.
        if (__is_pointer_value(false, false_reg))
                return;
 
+       sval = (s64)val;
+
        switch (opcode) {
        case BPF_JEQ:
-               /* If this is false then we know nothing Jon Snow, but if it is
-                * true then we know for sure.
-                */
-               __mark_reg_known(true_reg, val);
-               break;
        case BPF_JNE:
-               /* If this is true we know nothing Jon Snow, but if it is false
-                * we know the value for sure;
+       {
+               struct bpf_reg_state *reg =
+                       opcode == BPF_JEQ ? true_reg : false_reg;
+
+               /* For BPF_JEQ, if this is false we know nothing Jon Snow, but
+                * if it is true we know the value for sure. Likewise for
+                * BPF_JNE.
                 */
-               __mark_reg_known(false_reg, val);
+               __mark_reg_known(reg, val);
                break;
+       }
        case BPF_JSET:
                false_reg->var_off = tnum_and(false_reg->var_off,
                                              tnum_const(~val));
                        true_reg->var_off = tnum_or(true_reg->var_off,
                                                    tnum_const(val));
                break;
-       case BPF_JGT:
-               false_reg->umax_value = min(false_reg->umax_value, val);
-               true_reg->umin_value = max(true_reg->umin_value, val + 1);
-               break;
-       case BPF_JSGT:
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val);
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 1);
-               break;
-       case BPF_JLT:
-               false_reg->umin_value = max(false_reg->umin_value, val);
-               true_reg->umax_value = min(true_reg->umax_value, val - 1);
-               break;
-       case BPF_JSLT:
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val);
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 1);
-               break;
        case BPF_JGE:
-               false_reg->umax_value = min(false_reg->umax_value, val - 1);
-               true_reg->umin_value = max(true_reg->umin_value, val);
+       case BPF_JGT:
+       {
+               u64 false_umax = opcode == BPF_JGT ? val    : val - 1;
+               u64 true_umin = opcode == BPF_JGT ? val + 1 : val;
+
+               false_reg->umax_value = min(false_reg->umax_value, false_umax);
+               true_reg->umin_value = max(true_reg->umin_value, true_umin);
                break;
+       }
        case BPF_JSGE:
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 1);
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val);
+       case BPF_JSGT:
+       {
+               s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
+               s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
+
+               false_reg->smax_value = min(false_reg->smax_value, false_smax);
+               true_reg->smin_value = max(true_reg->smin_value, true_smin);
                break;
+       }
        case BPF_JLE:
-               false_reg->umin_value = max(false_reg->umin_value, val + 1);
-               true_reg->umax_value = min(true_reg->umax_value, val);
+       case BPF_JLT:
+       {
+               u64 false_umin = opcode == BPF_JLT ? val    : val + 1;
+               u64 true_umax = opcode == BPF_JLT ? val - 1 : val;
+
+               false_reg->umin_value = max(false_reg->umin_value, false_umin);
+               true_reg->umax_value = min(true_reg->umax_value, true_umax);
                break;
+       }
        case BPF_JSLE:
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 1);
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val);
+       case BPF_JSLT:
+       {
+               s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
+               s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
+
+               false_reg->smin_value = max(false_reg->smin_value, false_smin);
+               true_reg->smax_value = min(true_reg->smax_value, true_smax);
                break;
+       }
        default:
                break;
        }
                                struct bpf_reg_state *false_reg, u64 val,
                                u8 opcode)
 {
+       s64 sval;
+
        if (__is_pointer_value(false, false_reg))
                return;
 
+       sval = (s64)val;
+
        switch (opcode) {
        case BPF_JEQ:
-               /* If this is false then we know nothing Jon Snow, but if it is
-                * true then we know for sure.
-                */
-               __mark_reg_known(true_reg, val);
-               break;
        case BPF_JNE:
-               /* If this is true we know nothing Jon Snow, but if it is false
-                * we know the value for sure;
-                */
-               __mark_reg_known(false_reg, val);
+       {
+               struct bpf_reg_state *reg =
+                       opcode == BPF_JEQ ? true_reg : false_reg;
+
+               __mark_reg_known(reg, val);
                break;
+       }
        case BPF_JSET:
                false_reg->var_off = tnum_and(false_reg->var_off,
                                              tnum_const(~val));
                        true_reg->var_off = tnum_or(true_reg->var_off,
                                                    tnum_const(val));
                break;
-       case BPF_JGT:
-               true_reg->umax_value = min(true_reg->umax_value, val - 1);
-               false_reg->umin_value = max(false_reg->umin_value, val);
-               break;
-       case BPF_JSGT:
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 1);
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val);
-               break;
-       case BPF_JLT:
-               true_reg->umin_value = max(true_reg->umin_value, val + 1);
-               false_reg->umax_value = min(false_reg->umax_value, val);
-               break;
-       case BPF_JSLT:
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 1);
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val);
-               break;
        case BPF_JGE:
-               true_reg->umax_value = min(true_reg->umax_value, val);
-               false_reg->umin_value = max(false_reg->umin_value, val + 1);
+       case BPF_JGT:
+       {
+               u64 false_umin = opcode == BPF_JGT ? val    : val + 1;
+               u64 true_umax = opcode == BPF_JGT ? val - 1 : val;
+
+               false_reg->umin_value = max(false_reg->umin_value, false_umin);
+               true_reg->umax_value = min(true_reg->umax_value, true_umax);
                break;
+       }
        case BPF_JSGE:
-               true_reg->smax_value = min_t(s64, true_reg->smax_value, val);
-               false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 1);
+       case BPF_JSGT:
+       {
+               s64 false_smin = opcode == BPF_JSGT ? sval    : sval + 1;
+               s64 true_smax = opcode == BPF_JSGT ? sval - 1 : sval;
+
+               false_reg->smin_value = max(false_reg->smin_value, false_smin);
+               true_reg->smax_value = min(true_reg->smax_value, true_smax);
                break;
+       }
        case BPF_JLE:
-               true_reg->umin_value = max(true_reg->umin_value, val);
-               false_reg->umax_value = min(false_reg->umax_value, val - 1);
+       case BPF_JLT:
+       {
+               u64 false_umax = opcode == BPF_JLT ? val    : val - 1;
+               u64 true_umin = opcode == BPF_JLT ? val + 1 : val;
+
+               false_reg->umax_value = min(false_reg->umax_value, false_umax);
+               true_reg->umin_value = max(true_reg->umin_value, true_umin);
                break;
+       }
        case BPF_JSLE:
-               true_reg->smin_value = max_t(s64, true_reg->smin_value, val);
-               false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 1);
+       case BPF_JSLT:
+       {
+               s64 false_smax = opcode == BPF_JSLT ? sval    : sval - 1;
+               s64 true_smin = opcode == BPF_JSLT ? sval + 1 : sval;
+
+               false_reg->smax_value = min(false_reg->smax_value, false_smax);
+               true_reg->smin_value = max(true_reg->smin_value, true_smin);
                break;
+       }
        default:
                break;
        }