struct bpf_prog *bpf_prog_get(u32 ufd);
 struct bpf_prog *bpf_prog_get_type(u32 ufd, enum bpf_prog_type type);
 struct bpf_prog *bpf_prog_get_type_dev(u32 ufd, enum bpf_prog_type type,
-                                      struct net_device *netdev);
+                                      bool attach_drv);
 struct bpf_prog * __must_check bpf_prog_add(struct bpf_prog *prog, int i);
 void bpf_prog_sub(struct bpf_prog *prog, int i);
 struct bpf_prog * __must_check bpf_prog_inc(struct bpf_prog *prog);
 
 static inline struct bpf_prog *bpf_prog_get_type_dev(u32 ufd,
                                                     enum bpf_prog_type type,
-                                                    struct net_device *netdev)
+                                                    bool attach_drv)
 {
        return ERR_PTR(-EOPNOTSUPP);
 }
 
 }
 EXPORT_SYMBOL_GPL(bpf_prog_inc_not_zero);
 
-static bool bpf_prog_can_attach(struct bpf_prog *prog,
-                               enum bpf_prog_type *attach_type,
-                               struct net_device *netdev)
+static bool bpf_prog_get_ok(struct bpf_prog *prog,
+                           enum bpf_prog_type *attach_type, bool attach_drv)
 {
-       struct bpf_dev_offload *offload = prog->aux->offload;
+       /* not an attachment, just a refcount inc, always allow */
+       if (!attach_type)
+               return true;
 
        if (prog->type != *attach_type)
                return false;
-       if (offload && offload->netdev != netdev)
+       if (bpf_prog_is_dev_bound(prog->aux) && !attach_drv)
                return false;
 
        return true;
 }
 
 static struct bpf_prog *__bpf_prog_get(u32 ufd, enum bpf_prog_type *attach_type,
-                                      struct net_device *netdev)
+                                      bool attach_drv)
 {
        struct fd f = fdget(ufd);
        struct bpf_prog *prog;
        prog = ____bpf_prog_get(f);
        if (IS_ERR(prog))
                return prog;
-       if (attach_type && !bpf_prog_can_attach(prog, attach_type, netdev)) {
+       if (!bpf_prog_get_ok(prog, attach_type, attach_drv)) {
                prog = ERR_PTR(-EINVAL);
                goto out;
        }
 
 struct bpf_prog *bpf_prog_get(u32 ufd)
 {
-       return __bpf_prog_get(ufd, NULL, NULL);
+       return __bpf_prog_get(ufd, NULL, false);
 }
 
 struct bpf_prog *bpf_prog_get_type(u32 ufd, enum bpf_prog_type type)
 {
-       struct bpf_prog *prog = __bpf_prog_get(ufd, &type, NULL);
+       struct bpf_prog *prog = __bpf_prog_get(ufd, &type, false);
 
        if (!IS_ERR(prog))
                trace_bpf_prog_get_type(prog);
 EXPORT_SYMBOL_GPL(bpf_prog_get_type);
 
 struct bpf_prog *bpf_prog_get_type_dev(u32 ufd, enum bpf_prog_type type,
-                                      struct net_device *netdev)
+                                      bool attach_drv)
 {
-       struct bpf_prog *prog = __bpf_prog_get(ufd, &type, netdev);
+       struct bpf_prog *prog = __bpf_prog_get(ufd, &type, attach_drv);
 
        if (!IS_ERR(prog))
                trace_bpf_prog_get_type(prog);
 
                    __dev_xdp_attached(dev, bpf_op, NULL))
                        return -EBUSY;
 
-               if (bpf_op == ops->ndo_bpf)
-                       prog = bpf_prog_get_type_dev(fd, BPF_PROG_TYPE_XDP,
-                                                    dev);
-               else
-                       prog = bpf_prog_get_type(fd, BPF_PROG_TYPE_XDP);
+               prog = bpf_prog_get_type_dev(fd, BPF_PROG_TYPE_XDP,
+                                            bpf_op == ops->ndo_bpf);
                if (IS_ERR(prog))
                        return PTR_ERR(prog);
        }
 
 {
        struct bpf_prog *fp;
        char *name = NULL;
+       bool skip_sw;
        u32 bpf_fd;
 
        bpf_fd = nla_get_u32(tb[TCA_BPF_FD]);
+       skip_sw = gen_flags & TCA_CLS_FLAGS_SKIP_SW;
 
-       if (gen_flags & TCA_CLS_FLAGS_SKIP_SW)
-               fp = bpf_prog_get_type_dev(bpf_fd, BPF_PROG_TYPE_SCHED_CLS,
-                                          qdisc_dev(tp->q));
-       else
-               fp = bpf_prog_get_type(bpf_fd, BPF_PROG_TYPE_SCHED_CLS);
+       fp = bpf_prog_get_type_dev(bpf_fd, BPF_PROG_TYPE_SCHED_CLS, skip_sw);
        if (IS_ERR(fp))
                return PTR_ERR(fp);