return data;
 }
 
+static void *bpf_ctx_init(const union bpf_attr *kattr, u32 max_size)
+{
+       void __user *data_in = u64_to_user_ptr(kattr->test.ctx_in);
+       void __user *data_out = u64_to_user_ptr(kattr->test.ctx_out);
+       u32 size = kattr->test.ctx_size_in;
+       void *data;
+       int err;
+
+       if (!data_in && !data_out)
+               return NULL;
+
+       data = kzalloc(max_size, GFP_USER);
+       if (!data)
+               return ERR_PTR(-ENOMEM);
+
+       if (data_in) {
+               err = bpf_check_uarg_tail_zero(data_in, max_size, size);
+               if (err) {
+                       kfree(data);
+                       return ERR_PTR(err);
+               }
+
+               size = min_t(u32, max_size, size);
+               if (copy_from_user(data, data_in, size)) {
+                       kfree(data);
+                       return ERR_PTR(-EFAULT);
+               }
+       }
+       return data;
+}
+
+static int bpf_ctx_finish(const union bpf_attr *kattr,
+                         union bpf_attr __user *uattr, const void *data,
+                         u32 size)
+{
+       void __user *data_out = u64_to_user_ptr(kattr->test.ctx_out);
+       int err = -EFAULT;
+       u32 copy_size = size;
+
+       if (!data || !data_out)
+               return 0;
+
+       if (copy_size > kattr->test.ctx_size_out) {
+               copy_size = kattr->test.ctx_size_out;
+               err = -ENOSPC;
+       }
+
+       if (copy_to_user(data_out, data, copy_size))
+               goto out;
+       if (copy_to_user(&uattr->test.ctx_size_out, &size, sizeof(size)))
+               goto out;
+       if (err != -ENOSPC)
+               err = 0;
+out:
+       return err;
+}
+
+/**
+ * range_is_zero - test whether buffer is initialized
+ * @buf: buffer to check
+ * @from: check from this position
+ * @to: check up until (excluding) this position
+ *
+ * This function returns true if the there is a non-zero byte
+ * in the buf in the range [from,to).
+ */
+static inline bool range_is_zero(void *buf, size_t from, size_t to)
+{
+       return !memchr_inv((u8 *)buf + from, 0, to - from);
+}
+
+static int convert___skb_to_skb(struct sk_buff *skb, struct __sk_buff *__skb)
+{
+       struct qdisc_skb_cb *cb = (struct qdisc_skb_cb *)skb->cb;
+
+       if (!__skb)
+               return 0;
+
+       /* make sure the fields we don't use are zeroed */
+       if (!range_is_zero(__skb, 0, offsetof(struct __sk_buff, priority)))
+               return -EINVAL;
+
+       /* priority is allowed */
+
+       if (!range_is_zero(__skb, offsetof(struct __sk_buff, priority) +
+                          FIELD_SIZEOF(struct __sk_buff, priority),
+                          offsetof(struct __sk_buff, cb)))
+               return -EINVAL;
+
+       /* cb is allowed */
+
+       if (!range_is_zero(__skb, offsetof(struct __sk_buff, cb) +
+                          FIELD_SIZEOF(struct __sk_buff, cb),
+                          sizeof(struct __sk_buff)))
+               return -EINVAL;
+
+       skb->priority = __skb->priority;
+       memcpy(&cb->data, __skb->cb, QDISC_CB_PRIV_LEN);
+
+       return 0;
+}
+
+static void convert_skb_to___skb(struct sk_buff *skb, struct __sk_buff *__skb)
+{
+       struct qdisc_skb_cb *cb = (struct qdisc_skb_cb *)skb->cb;
+
+       if (!__skb)
+               return;
+
+       __skb->priority = skb->priority;
+       memcpy(__skb->cb, &cb->data, QDISC_CB_PRIV_LEN);
+}
+
 int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr,
                          union bpf_attr __user *uattr)
 {
        bool is_l2 = false, is_direct_pkt_access = false;
        u32 size = kattr->test.data_size_in;
        u32 repeat = kattr->test.repeat;
+       struct __sk_buff *ctx = NULL;
        u32 retval, duration;
        int hh_len = ETH_HLEN;
        struct sk_buff *skb;
        if (IS_ERR(data))
                return PTR_ERR(data);
 
+       ctx = bpf_ctx_init(kattr, sizeof(struct __sk_buff));
+       if (IS_ERR(ctx)) {
+               kfree(data);
+               return PTR_ERR(ctx);
+       }
+
        switch (prog->type) {
        case BPF_PROG_TYPE_SCHED_CLS:
        case BPF_PROG_TYPE_SCHED_ACT:
        sk = kzalloc(sizeof(struct sock), GFP_USER);
        if (!sk) {
                kfree(data);
+               kfree(ctx);
                return -ENOMEM;
        }
        sock_net_set(sk, current->nsproxy->net_ns);
        skb = build_skb(data, 0);
        if (!skb) {
                kfree(data);
+               kfree(ctx);
                kfree(sk);
                return -ENOMEM;
        }
                __skb_push(skb, hh_len);
        if (is_direct_pkt_access)
                bpf_compute_data_pointers(skb);
+       ret = convert___skb_to_skb(skb, ctx);
+       if (ret)
+               goto out;
        ret = bpf_test_run(prog, skb, repeat, &retval, &duration);
-       if (ret) {
-               kfree_skb(skb);
-               kfree(sk);
-               return ret;
-       }
+       if (ret)
+               goto out;
        if (!is_l2) {
                if (skb_headroom(skb) < hh_len) {
                        int nhead = HH_DATA_ALIGN(hh_len - skb_headroom(skb));
 
                        if (pskb_expand_head(skb, nhead, 0, GFP_USER)) {
-                               kfree_skb(skb);
-                               kfree(sk);
-                               return -ENOMEM;
+                               ret = -ENOMEM;
+                               goto out;
                        }
                }
                memset(__skb_push(skb, hh_len), 0, hh_len);
        }
+       convert_skb_to___skb(skb, ctx);
 
        size = skb->len;
        /* bpf program can never convert linear skb to non-linear */
        if (WARN_ON_ONCE(skb_is_nonlinear(skb)))
                size = skb_headlen(skb);
        ret = bpf_test_finish(kattr, uattr, skb->data, size, retval, duration);
+       if (!ret)
+               ret = bpf_ctx_finish(kattr, uattr, ctx,
+                                    sizeof(struct __sk_buff));
+out:
        kfree_skb(skb);
        kfree(sk);
+       kfree(ctx);
        return ret;
 }