u32 used_map_cnt;               /* number of used maps */
        u32 id_gen;                     /* used to generate unique reg IDs */
        bool allow_ptr_leaks;
+       bool seen_direct_write;
 };
 
 #define BPF_COMPLEXITY_LIMIT_INSNS     65536
 struct bpf_call_arg_meta {
        struct bpf_map *map_ptr;
        bool raw_mode;
+       bool pkt_access;        /* copied from the helper proto's pkt_access; returned by may_access_direct_pkt_data() */
        int regno;
        int access_size;
 };
 
 #define MAX_PACKET_OFF 0xffff
 
-static bool may_write_pkt_data(enum bpf_prog_type type)
+static bool may_access_direct_pkt_data(struct verifier_env *env,
+                                      const struct bpf_call_arg_meta *meta)
 {
-       switch (type) {
+       switch (env->prog->type) {
+       case BPF_PROG_TYPE_SCHED_CLS:
+       case BPF_PROG_TYPE_SCHED_ACT:
        case BPF_PROG_TYPE_XDP:
+               if (meta)
+                       return meta->pkt_access;
+
+               env->seen_direct_write = true;
                return true;
        default:
                return false;
                        err = check_stack_read(state, off, size, value_regno);
                }
        } else if (state->regs[regno].type == PTR_TO_PACKET) {
-               if (t == BPF_WRITE && !may_write_pkt_data(env->prog->type)) {
+               if (t == BPF_WRITE && !may_access_direct_pkt_data(env, NULL)) {
                        verbose("cannot write into packet\n");
                        return -EACCES;
                }
                return 0;
        }
 
-       if (type == PTR_TO_PACKET && !may_write_pkt_data(env->prog->type)) {
-               verbose("helper access to the packet is not allowed for clsact\n");
+       if (type == PTR_TO_PACKET && !may_access_direct_pkt_data(env, meta)) {
+               verbose("helper access to the packet is not allowed\n");
                return -EACCES;
        }
 
        changes_data = bpf_helper_changes_skb_data(fn->func);
 
        memset(&meta, 0, sizeof(meta));
+       meta.pkt_access = fn->pkt_access;
 
        /* We only support one arg being in raw mode at the moment, which
         * is sufficient for the helper functions we have right now.
  */
 static int convert_ctx_accesses(struct verifier_env *env)
 {
-       struct bpf_insn *insn = env->prog->insnsi;
-       int insn_cnt = env->prog->len;
-       struct bpf_insn insn_buf[16];
+       const struct bpf_verifier_ops *ops = env->prog->aux->ops;
+       struct bpf_insn insn_buf[16], *insn;
        struct bpf_prog *new_prog;
        enum bpf_access_type type;
-       int i;
+       int i, insn_cnt, cnt;
 
-       if (!env->prog->aux->ops->convert_ctx_access)
+       if (ops->gen_prologue) {
+               cnt = ops->gen_prologue(insn_buf, env->seen_direct_write,
+                                       env->prog);
+               if (cnt >= ARRAY_SIZE(insn_buf)) {
+                       verbose("bpf verifier is misconfigured\n");
+                       return -EINVAL;
+               } else if (cnt) {
+                       new_prog = bpf_patch_insn_single(env->prog, 0,
+                                                        insn_buf, cnt);
+                       if (!new_prog)
+                               return -ENOMEM;
+                       env->prog = new_prog;
+               }
+       }
+
+       if (!ops->convert_ctx_access)
                return 0;
 
+       insn_cnt = env->prog->len;
+       insn = env->prog->insnsi;
+
        for (i = 0; i < insn_cnt; i++, insn++) {
-               u32 insn_delta, cnt;
+               u32 insn_delta;
 
                if (insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
                    insn->code == (BPF_LDX | BPF_MEM | BPF_DW))
                        continue;
                }
 
-               cnt = env->prog->aux->ops->
-                       convert_ctx_access(type, insn->dst_reg, insn->src_reg,
-                                          insn->off, insn_buf, env->prog);
+               cnt = ops->convert_ctx_access(type, insn->dst_reg, insn->src_reg,
+                                             insn->off, insn_buf, env->prog);
                if (cnt == 0 || cnt >= ARRAY_SIZE(insn_buf)) {
                        verbose("bpf verifier is misconfigured\n");
                        return -EINVAL;
 
        return err;
 }
 
+/* Ask bpf_try_make_writable() to make the first skb_headlen(skb) bytes
+ * of @skb writable and return that call's result unchanged.
+ */
+static int bpf_try_make_head_writable(struct sk_buff *skb)
+{
+       return bpf_try_make_writable(skb, skb_headlen(skb));
+}
+
 static inline void bpf_push_mac_rcsum(struct sk_buff *skb)
 {
        if (skb_at_tc_ingress(skb))
        .arg4_type      = ARG_CONST_STACK_SIZE,
 };
 
+/* BPF helper: bpf_skb_pull_data(skb, len).
+ *
+ * Forwards to bpf_try_make_writable() and returns its result. A @len of
+ * 0 selects skb_headlen(skb) as the length (GNU "?:" shorthand below).
+ */
+BPF_CALL_2(bpf_skb_pull_data, struct sk_buff *, skb, u32, len)
+{
+       /* Idea is the following: should the needed direct read/write
+        * test fail during runtime, we can pull in more data and retry,
+        * since implicitly, we invalidate previous checks here.
+        *
+        * Or, since we know how much we need to make read/writeable,
+        * this can be done once at the program beginning for direct
+        * access case. By this we overcome limitations of only current
+        * headroom being accessible.
+        */
+       return bpf_try_make_writable(skb, len ? : skb_headlen(skb));
+}
+
+/* Helper prototype for bpf_skb_pull_data(): not GPL-only, integer return,
+ * arg1 is the program context, arg2 is typed ARG_ANYTHING.
+ */
+static const struct bpf_func_proto bpf_skb_pull_data_proto = {
+       .func           = bpf_skb_pull_data,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_l3_csum_replace, struct sk_buff *, skb, u32, offset,
           u64, from, u64, to, u64, flags)
 {
 static const struct bpf_func_proto bpf_csum_diff_proto = {
        .func           = bpf_csum_diff,
        .gpl_only       = false,
+       .pkt_access     = true,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_STACK,
        .arg2_type      = ARG_CONST_STACK_SIZE_OR_ZERO,
        .arg5_type      = ARG_ANYTHING,
 };
 
+/* BPF helper: bpf_csum_update(skb, csum).
+ *
+ * When skb->ip_summed is CHECKSUM_COMPLETE, adds @csum to skb->csum via
+ * csum_add(), stores the result in skb->csum and returns that new value.
+ * For any other ip_summed state the skb is not modified and -ENOTSUPP
+ * is returned.
+ */
+BPF_CALL_2(bpf_csum_update, struct sk_buff *, skb, __wsum, csum)
+{
+       /* The interface is to be used in combination with bpf_csum_diff()
+        * for direct packet writes. csum rotation for alignment as well
+        * as emulating csum_sub() can be done from the eBPF program.
+        */
+       if (skb->ip_summed == CHECKSUM_COMPLETE)
+               return (skb->csum = csum_add(skb->csum, csum));
+
+       return -ENOTSUPP;
+}
+
+/* Helper prototype for bpf_csum_update(): not GPL-only, integer return,
+ * arg1 is the program context, arg2 is typed ARG_ANYTHING.
+ */
+static const struct bpf_func_proto bpf_csum_update_proto = {
+       .func           = bpf_csum_update,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+};
+
 static inline int __bpf_rx_skb(struct net_device *dev, struct sk_buff *skb)
 {
        return dev_forward_skb(dev, skb);
 BPF_CALL_3(bpf_clone_redirect, struct sk_buff *, skb, u32, ifindex, u64, flags)
 {
        struct net_device *dev;
+       struct sk_buff *clone;
+       int ret;
 
        if (unlikely(flags & ~(BPF_F_INGRESS)))
                return -EINVAL;
        if (unlikely(!dev))
                return -EINVAL;
 
-       skb = skb_clone(skb, GFP_ATOMIC);
-       if (unlikely(!skb))
+       clone = skb_clone(skb, GFP_ATOMIC);
+       if (unlikely(!clone))
                return -ENOMEM;
 
-       bpf_push_mac_rcsum(skb);
+       /* For direct write, we need to keep the invariant that the skbs
+        * we're dealing with need to be uncloned. Should uncloning fail
+        * here, we need to free the just generated clone to unclone once
+        * again.
+        */
+       ret = bpf_try_make_head_writable(skb);
+       if (unlikely(ret)) {
+               kfree_skb(clone);
+               return -ENOMEM;
+       }
+
+       bpf_push_mac_rcsum(clone);
 
        return flags & BPF_F_INGRESS ?
-              __bpf_rx_skb(dev, skb) : __bpf_tx_skb(dev, skb);
+              __bpf_rx_skb(dev, clone) : __bpf_tx_skb(dev, clone);
 }
 
 static const struct bpf_func_proto bpf_clone_redirect_proto = {
 
 bool bpf_helper_changes_skb_data(void *func)
 {
-       if (func == bpf_skb_vlan_push)
-               return true;
-       if (func == bpf_skb_vlan_pop)
-               return true;
-       if (func == bpf_skb_store_bytes)
-               return true;
-       if (func == bpf_skb_change_proto)
-               return true;
-       if (func == bpf_skb_change_tail)
-               return true;
-       if (func == bpf_l3_csum_replace)
-               return true;
-       if (func == bpf_l4_csum_replace)
+       if (func == bpf_skb_vlan_push ||
+           func == bpf_skb_vlan_pop ||
+           func == bpf_skb_store_bytes ||
+           func == bpf_skb_change_proto ||
+           func == bpf_skb_change_tail ||
+           func == bpf_skb_pull_data ||
+           func == bpf_l3_csum_replace ||
+           func == bpf_l4_csum_replace)
                return true;
 
        return false;
                return &bpf_skb_store_bytes_proto;
        case BPF_FUNC_skb_load_bytes:
                return &bpf_skb_load_bytes_proto;
+       case BPF_FUNC_skb_pull_data:
+               return &bpf_skb_pull_data_proto;
        case BPF_FUNC_csum_diff:
                return &bpf_csum_diff_proto;
+       case BPF_FUNC_csum_update:
+               return &bpf_csum_update_proto;
        case BPF_FUNC_l3_csum_replace:
                return &bpf_l3_csum_replace_proto;
        case BPF_FUNC_l4_csum_replace:
        return __is_valid_access(off, size, type);
 }
 
+/* Emit the tc cls/act program prologue into @insn_buf.
+ *
+ * @insn_buf:     output buffer for the generated instructions
+ * @direct_write: when false, nothing is emitted
+ * @prog:         program being patched; its first instruction is copied
+ *                to the end of the prologue
+ *
+ * Returns the number of instructions written to @insn_buf: 0 when
+ * @direct_write is false, 11 otherwise.
+ *
+ * R6 is used as scratch: it first holds the masked cloned bit and, on
+ * the slow path, carries the ctx pointer (R1) across the helper call.
+ *
+ * NOTE(review): the prologue ends with a copy of prog->insnsi[0],
+ * presumably because the caller splices the buffer in at instruction
+ * index 0 and thereby replaces that instruction - confirm against the
+ * gen_prologue call site.
+ */
+static int tc_cls_act_prologue(struct bpf_insn *insn_buf, bool direct_write,
+                              const struct bpf_prog *prog)
+{
+       struct bpf_insn *insn = insn_buf;
+
+       if (!direct_write)
+               return 0;
+
+       /* if (!skb->cloned)
+        *       goto start;
+        *
+        * (Fast-path, otherwise approximation that we might be
+        *  a clone, do the rest in helper.)
+        */
+       *insn++ = BPF_LDX_MEM(BPF_B, BPF_REG_6, BPF_REG_1, CLONED_OFFSET());
+       *insn++ = BPF_ALU32_IMM(BPF_AND, BPF_REG_6, CLONED_MASK);
+       /* Offset 7 skips the seven slow-path instructions below and
+        * lands on "start".
+        */
+       *insn++ = BPF_JMP_IMM(BPF_JEQ, BPF_REG_6, 0, 7);
+
+       /* ret = bpf_skb_pull_data(skb, 0); */
+       *insn++ = BPF_MOV64_REG(BPF_REG_6, BPF_REG_1);
+       *insn++ = BPF_ALU64_REG(BPF_XOR, BPF_REG_2, BPF_REG_2);
+       *insn++ = BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
+                              BPF_FUNC_skb_pull_data);
+       /* if (!ret)
+        *      goto restore;
+        * return TC_ACT_SHOT;
+        *
+        * Offset 2 skips the MOV/EXIT pair and lands on "restore".
+        */
+       *insn++ = BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 2);
+       *insn++ = BPF_ALU32_IMM(BPF_MOV, BPF_REG_0, TC_ACT_SHOT);
+       *insn++ = BPF_EXIT_INSN();
+
+       /* restore: */
+       *insn++ = BPF_MOV64_REG(BPF_REG_1, BPF_REG_6);
+       /* start: */
+       *insn++ = prog->insnsi[0];
+
+       return insn - insn_buf;
+}
+
 static bool tc_cls_act_is_valid_access(int off, int size,
                                       enum bpf_access_type type,
                                       enum bpf_reg_type *reg_type)
        .get_func_proto         = tc_cls_act_func_proto,
        .is_valid_access        = tc_cls_act_is_valid_access,
        .convert_ctx_access     = tc_cls_act_convert_ctx_access,
+       .gen_prologue           = tc_cls_act_prologue,
 };
 
 static const struct bpf_verifier_ops xdp_ops = {