}
 EXPORT_SYMBOL(dev_change_proto_down_generic);
 
-u32 __dev_xdp_query(struct net_device *dev, bpf_op_t bpf_op,
-                   enum bpf_netdev_command cmd)
+static enum bpf_xdp_mode dev_xdp_mode(u32 flags)
 {
-       struct netdev_bpf xdp;
+       if (flags & XDP_FLAGS_HW_MODE)
+               return XDP_MODE_HW;
+       if (flags & XDP_FLAGS_DRV_MODE)
+               return XDP_MODE_DRV;
+       return XDP_MODE_SKB;
+}
 
-       if (!bpf_op)
-               return 0;
+static bpf_op_t dev_xdp_bpf_op(struct net_device *dev, enum bpf_xdp_mode mode)
+{
+       switch (mode) {
+       case XDP_MODE_SKB:
+               return generic_xdp_install;
+       case XDP_MODE_DRV:
+       case XDP_MODE_HW:
+               return dev->netdev_ops->ndo_bpf;
+       default:
+               return NULL;
+       };
+}
 
-       memset(&xdp, 0, sizeof(xdp));
-       xdp.command = cmd;
+static struct bpf_prog *dev_xdp_prog(struct net_device *dev,
+                                    enum bpf_xdp_mode mode)
+{
+       return dev->xdp_state[mode].prog;
+}
+
+u32 dev_xdp_prog_id(struct net_device *dev, enum bpf_xdp_mode mode)
+{
+       struct bpf_prog *prog = dev_xdp_prog(dev, mode);
 
-       /* Query must always succeed. */
-       WARN_ON(bpf_op(dev, &xdp) < 0 && cmd == XDP_QUERY_PROG);
+       return prog ? prog->aux->id : 0;
+}
 
-       return xdp.prog_id;
+static void dev_xdp_set_prog(struct net_device *dev, enum bpf_xdp_mode mode,
+                            struct bpf_prog *prog)
+{
+       dev->xdp_state[mode].prog = prog;
 }
 
-static int dev_xdp_install(struct net_device *dev, bpf_op_t bpf_op,
-                          struct netlink_ext_ack *extack, u32 flags,
-                          struct bpf_prog *prog)
+static int dev_xdp_install(struct net_device *dev, enum bpf_xdp_mode mode,
+                          bpf_op_t bpf_op, struct netlink_ext_ack *extack,
+                          u32 flags, struct bpf_prog *prog)
 {
-       bool non_hw = !(flags & XDP_FLAGS_HW_MODE);
-       struct bpf_prog *prev_prog = NULL;
        struct netdev_bpf xdp;
        int err;
 
-       if (non_hw) {
-               prev_prog = bpf_prog_by_id(__dev_xdp_query(dev, bpf_op,
-                                                          XDP_QUERY_PROG));
-               if (IS_ERR(prev_prog))
-                       prev_prog = NULL;
-       }
-
        memset(&xdp, 0, sizeof(xdp));
-       if (flags & XDP_FLAGS_HW_MODE)
-               xdp.command = XDP_SETUP_PROG_HW;
-       else
-               xdp.command = XDP_SETUP_PROG;
+       xdp.command = mode == XDP_MODE_HW ? XDP_SETUP_PROG_HW : XDP_SETUP_PROG;
        xdp.extack = extack;
        xdp.flags = flags;
        xdp.prog = prog;
 
+       /* Drivers assume refcnt is already incremented (i.e, prog pointer is
+        * "moved" into driver), so they don't increment it on their own, but
+        * they do decrement refcnt when program is detached or replaced.
+        * Given net_device also owns link/prog, we need to bump refcnt here
+        * to prevent drivers from underflowing it.
+        */
+       if (prog)
+               bpf_prog_inc(prog);
        err = bpf_op(dev, &xdp);
-       if (!err && non_hw)
-               bpf_prog_change_xdp(prev_prog, prog);
+       if (err) {
+               if (prog)
+                       bpf_prog_put(prog);
+               return err;
+       }
 
-       if (prev_prog)
-               bpf_prog_put(prev_prog);
+       if (mode != XDP_MODE_HW)
+               bpf_prog_change_xdp(dev_xdp_prog(dev, mode), prog);
 
-       return err;
+       return 0;
 }
 
 static void dev_xdp_uninstall(struct net_device *dev)
 {
-       struct netdev_bpf xdp;
-       bpf_op_t ndo_bpf;
+       struct bpf_prog *prog;
+       enum bpf_xdp_mode mode;
+       bpf_op_t bpf_op;
 
-       /* Remove generic XDP */
-       WARN_ON(dev_xdp_install(dev, generic_xdp_install, NULL, 0, NULL));
+       ASSERT_RTNL();
 
-       /* Remove from the driver */
-       ndo_bpf = dev->netdev_ops->ndo_bpf;
-       if (!ndo_bpf)
-               return;
+       for (mode = XDP_MODE_SKB; mode < __MAX_XDP_MODE; mode++) {
+               prog = dev_xdp_prog(dev, mode);
+               if (!prog)
+                       continue;
 
-       memset(&xdp, 0, sizeof(xdp));
-       xdp.command = XDP_QUERY_PROG;
-       WARN_ON(ndo_bpf(dev, &xdp));
-       if (xdp.prog_id)
-               WARN_ON(dev_xdp_install(dev, ndo_bpf, NULL, xdp.prog_flags,
-                                       NULL));
+               bpf_op = dev_xdp_bpf_op(dev, mode);
+               if (!bpf_op)
+                       continue;
 
-       /* Remove HW offload */
-       memset(&xdp, 0, sizeof(xdp));
-       xdp.command = XDP_QUERY_PROG_HW;
-       if (!ndo_bpf(dev, &xdp) && xdp.prog_id)
-               WARN_ON(dev_xdp_install(dev, ndo_bpf, NULL, xdp.prog_flags,
-                                       NULL));
+               WARN_ON(dev_xdp_install(dev, mode, bpf_op, NULL, 0, NULL));
+
+               bpf_prog_put(prog);
+               dev_xdp_set_prog(dev, mode, NULL);
+       }
 }
 
 /**
                      int fd, int expected_fd, u32 flags)
 {
        const struct net_device_ops *ops = dev->netdev_ops;
-       enum bpf_netdev_command query;
+       enum bpf_xdp_mode mode = dev_xdp_mode(flags);
+       bool offload = mode == XDP_MODE_HW;
        u32 prog_id, expected_id = 0;
-       bpf_op_t bpf_op, bpf_chk;
        struct bpf_prog *prog;
-       bool offload;
+       bpf_op_t bpf_op;
        int err;
 
        ASSERT_RTNL();
 
-       offload = flags & XDP_FLAGS_HW_MODE;
-       query = offload ? XDP_QUERY_PROG_HW : XDP_QUERY_PROG;
-
-       bpf_op = bpf_chk = ops->ndo_bpf;
-       if (!bpf_op && (flags & (XDP_FLAGS_DRV_MODE | XDP_FLAGS_HW_MODE))) {
+       bpf_op = dev_xdp_bpf_op(dev, mode);
+       if (!bpf_op) {
                NL_SET_ERR_MSG(extack, "underlying driver does not support XDP in native mode");
                return -EOPNOTSUPP;
        }
-       if (!bpf_op || (flags & XDP_FLAGS_SKB_MODE))
-               bpf_op = generic_xdp_install;
-       if (bpf_op == bpf_chk)
-               bpf_chk = generic_xdp_install;
 
-       prog_id = __dev_xdp_query(dev, bpf_op, query);
+       prog_id = dev_xdp_prog_id(dev, mode);
        if (flags & XDP_FLAGS_REPLACE) {
                if (expected_fd >= 0) {
                        prog = bpf_prog_get_type_dev(expected_fd,
                }
        }
        if (fd >= 0) {
-               if (!offload && __dev_xdp_query(dev, bpf_chk, XDP_QUERY_PROG)) {
-                       NL_SET_ERR_MSG(extack, "native and generic XDP can't be active at the same time");
+               enum bpf_xdp_mode other_mode = mode == XDP_MODE_SKB
+                                              ? XDP_MODE_DRV : XDP_MODE_SKB;
+
+               if (!offload && dev_xdp_prog_id(dev, other_mode)) {
+                       NL_SET_ERR_MSG(extack, "Native and generic XDP can't be active at the same time");
                        return -EEXIST;
                }
 
                        return PTR_ERR(prog);
 
                if (!offload && bpf_prog_is_dev_bound(prog->aux)) {
-                       NL_SET_ERR_MSG(extack, "using device-bound program without HW_MODE flag is not supported");
+                       NL_SET_ERR_MSG(extack, "Using device-bound program without HW_MODE flag is not supported");
                        bpf_prog_put(prog);
                        return -EINVAL;
                }
                prog = NULL;
        }
 
-       err = dev_xdp_install(dev, bpf_op, extack, flags, prog);
-       if (err < 0 && prog)
+       err = dev_xdp_install(dev, mode, bpf_op, extack, flags, prog);
+       if (err < 0 && prog) {
                bpf_prog_put(prog);
+               return err;
+       }
+       dev_xdp_set_prog(dev, mode, prog);
 
-       return err;
+       return 0;
 }
 
 /**