int xfrm_register_type_offload(const struct xfrm_type_offload *type, unsigned short family);
 void xfrm_unregister_type_offload(const struct xfrm_type_offload *type, unsigned short family);
+void xfrm_set_type_offload(struct xfrm_state *x);
+static inline void xfrm_unset_type_offload(struct xfrm_state *x)
+{
+       if (!x->type_offload)
+               return;
+
+       module_put(x->type_offload->owner);
+       x->type_offload = NULL;
+}
 
 /**
  * struct xfrm_mode_cbs - XFRM mode callbacks
 u32 xfrm_replay_seqhi(struct xfrm_state *x, __be32 net_seq);
 int xfrm_init_replay(struct xfrm_state *x, struct netlink_ext_ack *extack);
 u32 xfrm_state_mtu(struct xfrm_state *x, int mtu);
-int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload,
+int __xfrm_init_state(struct xfrm_state *x, bool init_replay,
                      struct netlink_ext_ack *extack);
 int xfrm_init_state(struct xfrm_state *x);
 int xfrm_input(struct sk_buff *skb, int nexthdr, __be32 spi, int encap_type);
 
        xfrm_address_t *daddr;
        bool is_packet_offload;
 
-       if (!x->type_offload) {
-               NL_SET_ERR_MSG(extack, "Type doesn't support offload");
-               return -EINVAL;
-       }
-
        if (xuo->flags &
            ~(XFRM_OFFLOAD_IPV6 | XFRM_OFFLOAD_INBOUND | XFRM_OFFLOAD_PACKET)) {
                NL_SET_ERR_MSG(extack, "Unrecognized flags in offload request");
                return -EINVAL;
        }
 
+       xfrm_set_type_offload(x);
+       if (!x->type_offload) {
+               NL_SET_ERR_MSG(extack, "Type doesn't support offload");
+               dev_put(dev);
+               return -EINVAL;
+       }
+
        xso->dev = dev;
        netdev_tracker_alloc(dev, &xso->dev_tracker, GFP_ATOMIC);
        xso->real_dev = dev;
                netdev_put(dev, &xso->dev_tracker);
                xso->type = XFRM_DEV_OFFLOAD_UNSPECIFIED;
 
+               xfrm_unset_type_offload(x);
                /* User explicitly requested packet offload mode and configured
                 * policy in addition to the XFRM state. So be civil to users,
                 * and return an error instead of taking fallback path.
 
 }
 EXPORT_SYMBOL(xfrm_unregister_type_offload);
 
-static const struct xfrm_type_offload *
-xfrm_get_type_offload(u8 proto, unsigned short family, bool try_load)
+void xfrm_set_type_offload(struct xfrm_state *x)
 {
        const struct xfrm_type_offload *type = NULL;
        struct xfrm_state_afinfo *afinfo;
+       bool try_load = true;
 
 retry:
-       afinfo = xfrm_state_get_afinfo(family);
+       afinfo = xfrm_state_get_afinfo(x->props.family);
        if (unlikely(afinfo == NULL))
-               return NULL;
+               goto out;
 
-       switch (proto) {
+       switch (x->id.proto) {
        case IPPROTO_ESP:
                type = afinfo->type_offload_esp;
                break;
        rcu_read_unlock();
 
        if (!type && try_load) {
-               request_module("xfrm-offload-%d-%d", family, proto);
+               request_module("xfrm-offload-%d-%d", x->props.family,
+                              x->id.proto);
                try_load = false;
                goto retry;
        }
 
-       return type;
-}
-
-static void xfrm_put_type_offload(const struct xfrm_type_offload *type)
-{
-       module_put(type->owner);
+out:
+       x->type_offload = type;
 }
+EXPORT_SYMBOL(xfrm_set_type_offload);
 
 static const struct xfrm_mode xfrm4_mode_map[XFRM_MODE_MAX] = {
        [XFRM_MODE_BEET] = {
        kfree(x->coaddr);
        kfree(x->replay_esn);
        kfree(x->preplay_esn);
-       if (x->type_offload)
-               xfrm_put_type_offload(x->type_offload);
        if (x->type) {
                x->type->destructor(x);
                xfrm_put_type(x->type);
        struct xfrm_dev_offload *xso = &x->xso;
        struct net_device *dev = READ_ONCE(xso->dev);
 
+       xfrm_unset_type_offload(x);
+
        if (dev && dev->xfrmdev_ops) {
                spin_lock_bh(&xfrm_state_dev_gc_lock);
                if (!hlist_unhashed(&x->dev_gclist))
 }
 EXPORT_SYMBOL_GPL(xfrm_state_mtu);
 
-int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload,
+int __xfrm_init_state(struct xfrm_state *x, bool init_replay,
                      struct netlink_ext_ack *extack)
 {
        const struct xfrm_mode *inner_mode;
                goto error;
        }
 
-       x->type_offload = xfrm_get_type_offload(x->id.proto, family, offload);
-
        err = x->type->init_state(x, extack);
        if (err)
                goto error;
 {
        int err;
 
-       err = __xfrm_init_state(x, true, false, NULL);
+       err = __xfrm_init_state(x, true, NULL);
        if (!err)
                x->km.state = XFRM_STATE_VALID;