#include <linux/bpf.h>
 #include <linux/bpf_trace.h>
 #include <linux/bpf_lirc.h>
+#include <linux/bpf_verifier.h>
 #include <linux/btf.h>
 #include <linux/syscalls.h>
 #include <linux/slab.h>
        .fill_link_info = bpf_tracing_link_fill_link_info,
 };
 
-static int bpf_tracing_prog_attach(struct bpf_prog *prog)
+static int bpf_tracing_prog_attach(struct bpf_prog *prog,
+                                  int tgt_prog_fd,
+                                  u32 btf_id)
 {
        struct bpf_link_primer link_primer;
        struct bpf_prog *tgt_prog = NULL;
+       struct bpf_trampoline *tr = NULL;
        struct bpf_tracing_link *link;
-       struct bpf_trampoline *tr;
+       u64 key = 0;
        int err;
 
        switch (prog->type) {
                goto out_put_prog;
        }
 
+       if (!!tgt_prog_fd != !!btf_id) {
+               err = -EINVAL;
+               goto out_put_prog;
+       }
+
+       if (tgt_prog_fd) {
+               /* For now we only allow new targets for BPF_PROG_TYPE_EXT */
+               if (prog->type != BPF_PROG_TYPE_EXT) {
+                       err = -EINVAL;
+                       goto out_put_prog;
+               }
+
+               tgt_prog = bpf_prog_get(tgt_prog_fd);
+               if (IS_ERR(tgt_prog)) {
+                       err = PTR_ERR(tgt_prog);
+                       tgt_prog = NULL;
+                       goto out_put_prog;
+               }
+
+               key = bpf_trampoline_compute_key(tgt_prog, btf_id);
+       }
+
        link = kzalloc(sizeof(*link), GFP_USER);
        if (!link) {
                err = -ENOMEM;
 
        mutex_lock(&prog->aux->dst_mutex);
 
-       if (!prog->aux->dst_trampoline) {
+       /* There are a few possible cases here:
+        *
+        * - if prog->aux->dst_trampoline is set, the program was just loaded
+        *   and not yet attached to anything, so we can use the values stored
+        *   in prog->aux
+        *
+        * - if prog->aux->dst_trampoline is NULL, the program has already been
+         *   attached to a target and its initial target was cleared (below)
+        *
+        * - if tgt_prog != NULL, the caller specified tgt_prog_fd +
+        *   target_btf_id using the link_create API.
+        *
+        * - if tgt_prog == NULL when this function was called using the old
+         *   raw_tracepoint_open API, and we need a target from prog->aux
+         *
+         * The combination of no saved target in prog->aux, and no target
+         * specified on load is illegal, and we reject that here.
+        */
+       if (!prog->aux->dst_trampoline && !tgt_prog) {
                err = -ENOENT;
                goto out_unlock;
        }
-       tr = prog->aux->dst_trampoline;
-       tgt_prog = prog->aux->dst_prog;
+
+       if (!prog->aux->dst_trampoline ||
+           (key && key != prog->aux->dst_trampoline->key)) {
+               /* If there is no saved target, or the specified target is
+                * different from the destination specified at load time, we
+                * need a new trampoline and a check for compatibility
+                */
+               struct bpf_attach_target_info tgt_info = {};
+
+               err = bpf_check_attach_target(NULL, prog, tgt_prog, btf_id,
+                                             &tgt_info);
+               if (err)
+                       goto out_unlock;
+
+               tr = bpf_trampoline_get(key, &tgt_info);
+               if (!tr) {
+                       err = -ENOMEM;
+                       goto out_unlock;
+               }
+       } else {
+               /* The caller didn't specify a target, or the target was the
+                * same as the destination supplied during program load. This
+                * means we can reuse the trampoline and reference from program
+                * load time, and there is no need to allocate a new one. This
+                * can only happen once for any program, as the saved values in
+                * prog->aux are cleared below.
+                */
+               tr = prog->aux->dst_trampoline;
+               tgt_prog = prog->aux->dst_prog;
+       }
 
        err = bpf_link_prime(&link->link, &link_primer);
        if (err)
        link->tgt_prog = tgt_prog;
        link->trampoline = tr;
 
+       /* Always clear the trampoline and target prog from prog->aux to make
+        * sure the original attach destination is not kept alive after a
+        * program is (re-)attached to another target.
+        */
+       if (prog->aux->dst_prog &&
+           (tgt_prog_fd || tr != prog->aux->dst_trampoline))
+               /* got extra prog ref from syscall, or attaching to different prog */
+               bpf_prog_put(prog->aux->dst_prog);
+       if (prog->aux->dst_trampoline && tr != prog->aux->dst_trampoline)
+               /* we allocated a new trampoline, so free the old one */
+               bpf_trampoline_put(prog->aux->dst_trampoline);
+
        prog->aux->dst_prog = NULL;
        prog->aux->dst_trampoline = NULL;
        mutex_unlock(&prog->aux->dst_mutex);
 
        return bpf_link_settle(&link_primer);
 out_unlock:
+       if (tr && tr != prog->aux->dst_trampoline)
+               bpf_trampoline_put(tr);
        mutex_unlock(&prog->aux->dst_mutex);
        kfree(link);
 out_put_prog:
+       if (tgt_prog_fd && tgt_prog)
+               bpf_prog_put(tgt_prog);
        bpf_prog_put(prog);
        return err;
 }
                        tp_name = prog->aux->attach_func_name;
                        break;
                }
-               return bpf_tracing_prog_attach(prog);
+               return bpf_tracing_prog_attach(prog, 0, 0);
        case BPF_PROG_TYPE_RAW_TRACEPOINT:
        case BPF_PROG_TYPE_RAW_TRACEPOINT_WRITABLE:
                if (strncpy_from_user(buf,
 
 static int tracing_bpf_link_attach(const union bpf_attr *attr, struct bpf_prog *prog)
 {
-       if (attr->link_create.attach_type == BPF_TRACE_ITER &&
-           prog->expected_attach_type == BPF_TRACE_ITER)
-               return bpf_iter_link_attach(attr, prog);
+       if (attr->link_create.attach_type != prog->expected_attach_type)
+               return -EINVAL;
 
+       if (prog->expected_attach_type == BPF_TRACE_ITER)
+               return bpf_iter_link_attach(attr, prog);
+       else if (prog->type == BPF_PROG_TYPE_EXT)
+               return bpf_tracing_prog_attach(prog,
+                                              attr->link_create.target_fd,
+                                              attr->link_create.target_btf_id);
        return -EINVAL;
 }
 
        if (CHECK_ATTR(BPF_LINK_CREATE))
                return -EINVAL;
 
-       ptype = attach_type_to_prog_type(attr->link_create.attach_type);
-       if (ptype == BPF_PROG_TYPE_UNSPEC)
-               return -EINVAL;
-
-       prog = bpf_prog_get_type(attr->link_create.prog_fd, ptype);
+       prog = bpf_prog_get(attr->link_create.prog_fd);
        if (IS_ERR(prog))
                return PTR_ERR(prog);
 
        ret = bpf_prog_attach_check_attach_type(prog,
                                                attr->link_create.attach_type);
        if (ret)
-               goto err_out;
+               goto out;
+
+       if (prog->type == BPF_PROG_TYPE_EXT) {
+               ret = tracing_bpf_link_attach(attr, prog);
+               goto out;
+       }
+
+       ptype = attach_type_to_prog_type(attr->link_create.attach_type);
+       if (ptype == BPF_PROG_TYPE_UNSPEC || ptype != prog->type) {
+               ret = -EINVAL;
+               goto out;
+       }
 
        switch (ptype) {
        case BPF_PROG_TYPE_CGROUP_SKB:
                ret = -EINVAL;
        }
 
-err_out:
+out:
        if (ret < 0)
                bpf_prog_put(prog);
        return ret;