In the current code, msg->data is set as sg_virt(&sg[i]) + start - offset
and msg->data_end relative to it as msg->data + bytes. Using iterator i
to point to the updated starting scatterlist element holds true for some
cases, however not for all where we'd end up pointing out of bounds. It
is /correct/ for these ones:
1) When first finding the starting scatterlist element (sge) where we
   find that the page is already privately owned by the msg and where
   the requested bytes and headroom fit into the sge's length.
However, it's /incorrect/ for the following ones:
2) After we made the requested area private and updated the newly allocated
   page into first_sg slot of the scatterlist ring; when we find that no
   shift repair of the ring is needed where we bail out updating msg->data
   and msg->data_end. At that point i will point to last_sg, which in this
   case is the next elem of first_sg in the ring. The sge at that point
   might as well be invalid (e.g. i == msg->sg_end), which we use for
   setting the range of sg_virt(&sg[i]). The correct one would have been
   first_sg.
3) Similar as in 2) but when we find that a shift repair of the ring is
   needed. In this case we fix up all sges and stop once we've reached the
   end. In this case i will point to will point to the new msg->sg_end,
   and the sge at that point will be invalid. Again here the requested
   range sits in first_sg.
Fixes: 015632bb30da ("bpf: sk_msg program helper bpf_sk_msg_pull_data")
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
        if (unlikely(start >= offset + len))
                return -EINVAL;
 
+       first_sg = i;
        /* The start may point into the sg element so we need to also
         * account for the headroom.
         */
        if (!msg->sg_copy[i] && bytes_sg_total <= len)
                goto out;
 
-       first_sg = i;
-
        /* At this point we need to linearize multiple scatterlist
         * elements or a single shared page. Either way we need to
         * copy into a linear buffer exclusively owned by BPF. Then
        if (msg->sg_end < 0)
                msg->sg_end += MAX_SKB_FRAGS;
 out:
-       msg->data = sg_virt(&sg[i]) + start - offset;
+       msg->data = sg_virt(&sg[first_sg]) + start - offset;
        msg->data_end = msg->data + bytes;
 
        return 0;