#include <linux/netdevice.h>
 #include <linux/printk.h>
 #include <linux/rtnetlink.h>
+#include <linux/rwsem.h>
 
-/* protected by RTNL */
+/* Protects bpf_prog_offload_devs and offload members of all progs.
+ * RTNL lock cannot be taken when holding this lock.
+ */
+static DECLARE_RWSEM(bpf_devs_lock);
 static LIST_HEAD(bpf_prog_offload_devs);
 
 int bpf_prog_offload_init(struct bpf_prog *prog, union bpf_attr *attr)
 {
-       struct net *net = current->nsproxy->net_ns;
        struct bpf_dev_offload *offload;
 
        if (attr->prog_type != BPF_PROG_TYPE_SCHED_CLS &&
        offload->prog = prog;
        init_waitqueue_head(&offload->verifier_done);
 
-       rtnl_lock();
-       offload->netdev = __dev_get_by_index(net, attr->prog_ifindex);
-       if (!offload->netdev) {
-               rtnl_unlock();
-               kfree(offload);
-               return -EINVAL;
-       }
+       offload->netdev = dev_get_by_index(current->nsproxy->net_ns,
+                                          attr->prog_ifindex);
+       if (!offload->netdev)
+               goto err_free;
 
+       down_write(&bpf_devs_lock);
+       if (offload->netdev->reg_state != NETREG_REGISTERED)
+               goto err_unlock;
        prog->aux->offload = offload;
        list_add_tail(&offload->offloads, &bpf_prog_offload_devs);
-       rtnl_unlock();
+       dev_put(offload->netdev);
+       up_write(&bpf_devs_lock);
 
        return 0;
+err_unlock:
+       up_write(&bpf_devs_lock);
+       dev_put(offload->netdev);
+err_free:
+       kfree(offload);
+       return -EINVAL;
 }
 
 static int __bpf_offload_ndo(struct bpf_prog *prog, enum bpf_netdev_command cmd,
        wake_up(&offload->verifier_done);
 
        rtnl_lock();
+       down_write(&bpf_devs_lock);
        __bpf_prog_offload_destroy(prog);
+       up_write(&bpf_devs_lock);
        rtnl_unlock();
 
        kfree(offload);
                if (netdev->reg_state != NETREG_UNREGISTERING)
                        break;
 
+               down_write(&bpf_devs_lock);
                list_for_each_entry_safe(offload, tmp, &bpf_prog_offload_devs,
                                         offloads) {
                        if (offload->netdev == netdev)
                                __bpf_prog_offload_destroy(offload->prog);
                }
+               up_write(&bpf_devs_lock);
                break;
        default:
                break;