}
 }
 
+static bool idset_contains(struct bpf_idset *s, u32 id)
+{
+       u32 i;
+
+       for (i = 0; i < s->count; ++i)
+               if (s->ids[i] == id)
+                       return true;
+
+       return false;
+}
+
+static int idset_push(struct bpf_idset *s, u32 id)
+{
+       if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
+               return -EFAULT;
+       s->ids[s->count++] = id;
+       return 0;
+}
+
+static void idset_reset(struct bpf_idset *s)
+{
+       s->count = 0;
+}
+
+/* Collect a set of IDs for all registers currently marked as precise in env->bt.
+ * Mark all registers with these IDs as precise.
+ */
+static int mark_precise_scalar_ids(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
+{
+       struct bpf_idset *precise_ids = &env->idset_scratch;
+       struct backtrack_state *bt = &env->bt;
+       struct bpf_func_state *func;
+       struct bpf_reg_state *reg;
+       DECLARE_BITMAP(mask, 64);
+       int i, fr;
+
+       idset_reset(precise_ids);
+
+       for (fr = bt->frame; fr >= 0; fr--) {
+               func = st->frame[fr];
+
+               bitmap_from_u64(mask, bt_frame_reg_mask(bt, fr));
+               for_each_set_bit(i, mask, 32) {
+                       reg = &func->regs[i];
+                       if (!reg->id || reg->type != SCALAR_VALUE)
+                               continue;
+                       if (idset_push(precise_ids, reg->id))
+                               return -EFAULT;
+               }
+
+               bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
+               for_each_set_bit(i, mask, 64) {
+                       if (i >= func->allocated_stack / BPF_REG_SIZE)
+                               break;
+                       if (!is_spilled_scalar_reg(&func->stack[i]))
+                               continue;
+                       reg = &func->stack[i].spilled_ptr;
+                       if (!reg->id)
+                               continue;
+                       if (idset_push(precise_ids, reg->id))
+                               return -EFAULT;
+               }
+       }
+
+       for (fr = 0; fr <= st->curframe; ++fr) {
+               func = st->frame[fr];
+
+               for (i = BPF_REG_0; i < BPF_REG_10; ++i) {
+                       reg = &func->regs[i];
+                       if (!reg->id)
+                               continue;
+                       if (!idset_contains(precise_ids, reg->id))
+                               continue;
+                       bt_set_frame_reg(bt, fr, i);
+               }
+               for (i = 0; i < func->allocated_stack / BPF_REG_SIZE; ++i) {
+                       if (!is_spilled_scalar_reg(&func->stack[i]))
+                               continue;
+                       reg = &func->stack[i].spilled_ptr;
+                       if (!reg->id)
+                               continue;
+                       if (!idset_contains(precise_ids, reg->id))
+                               continue;
+                       bt_set_frame_slot(bt, fr, i);
+               }
+       }
+
+       return 0;
+}
+
 /*
  * __mark_chain_precision() backtracks BPF program instruction sequence and
  * chain of verifier states making sure that register *regno* (if regno >= 0)
                                bt->frame, last_idx, first_idx, subseq_idx);
                }
 
+               /* If some register with scalar ID is marked as precise,
+                * make sure that all registers sharing this ID are also precise.
+                * This is needed to estimate effect of find_equal_scalars().
+                * Do this at the last instruction of each state,
+                * bpf_reg_state::id fields are valid for these instructions.
+                *
+                * Allows to track precision in situation like below:
+                *
+                *     r2 = unknown value
+                *     ...
+                *   --- state #0 ---
+                *     ...
+                *     r1 = r2                 // r1 and r2 now share the same ID
+                *     ...
+                *   --- state #1 {r1.id = A, r2.id = A} ---
+                *     ...
+                *     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
+                *     ...
+                *   --- state #2 {r1.id = A, r2.id = A} ---
+                *     r3 = r10
+                *     r3 += r1                // need to mark both r1 and r2
+                */
+               if (mark_precise_scalar_ids(env, st))
+                       return -EFAULT;
+
                if (last_idx < 0) {
                        /* we are at the entry into subprog, which
                         * is expected for global funcs, but only if
 
        mark_precise: frame0: regs=r2 stack= before 20\
        mark_precise: frame0: parent state regs=r2 stack=:\
        mark_precise: frame0: last_idx 19 first_idx 10\
-       mark_precise: frame0: regs=r2 stack= before 19\
+       mark_precise: frame0: regs=r2,r9 stack= before 19\
        mark_precise: frame0: regs=r9 stack= before 18\
        mark_precise: frame0: regs=r8,r9 stack= before 17\
        mark_precise: frame0: regs=r0,r9 stack= before 15\
        mark_precise: frame0: regs=r2 stack= before 22\
        mark_precise: frame0: parent state regs=r2 stack=:\
        mark_precise: frame0: last_idx 20 first_idx 20\
-       mark_precise: frame0: regs=r2 stack= before 20\
-       mark_precise: frame0: parent state regs=r2 stack=:\
+       mark_precise: frame0: regs=r2,r9 stack= before 20\
+       mark_precise: frame0: parent state regs=r2,r9 stack=:\
        mark_precise: frame0: last_idx 19 first_idx 17\
-       mark_precise: frame0: regs=r2 stack= before 19\
+       mark_precise: frame0: regs=r2,r9 stack= before 19\
        mark_precise: frame0: regs=r9 stack= before 18\
        mark_precise: frame0: regs=r8,r9 stack= before 17\
        mark_precise: frame0: parent state regs= stack=:",