#include <net/net_namespace.h>
 #include <net/flow_offload.h>
 #include <uapi/linux/devlink.h>
+#include <linux/xarray.h>
 
 struct devlink_ops;
 
        struct list_head resource_list;
        struct list_head param_list;
        struct list_head region_list;
-       u32 snapshot_id;
        struct list_head reporter_list;
        struct mutex reporters_lock; /* protects reporter_list */
        struct devlink_dpipe_headers *dpipe_headers;
        struct list_head trap_list;
        struct list_head trap_group_list;
        const struct devlink_ops *ops;
+       struct xarray snapshot_ids;
        struct device *dev;
        possible_net_t _net;
        struct mutex lock;
                      u32 region_max_snapshots, u64 region_size);
 void devlink_region_destroy(struct devlink_region *region);
 int devlink_region_snapshot_id_get(struct devlink *devlink, u32 *id);
+void devlink_region_snapshot_id_put(struct devlink *devlink, u32 id);
 int devlink_region_snapshot_create(struct devlink_region *region,
                                   u8 *data, u32 snapshot_id);
 int devlink_info_serial_number_put(struct devlink_info_req *req,
 
        nlmsg_free(msg);
 }
 
+/**
+ * __devlink_snapshot_id_increment - Increment number of snapshots using an id
+ *     @devlink: devlink instance
+ *     @id: the snapshot id
+ *
+ *     Track when a new snapshot begins using an id. Load the count for the
+ *     given id from the snapshot xarray, increment it, and store it back.
+ *
+ *     Called when a new snapshot is created with the given id.
+ *
+ *     The id *must* have been previously allocated by
+ *     devlink_region_snapshot_id_get().
+ *
+ *     Returns 0 on success, or an error on failure.
+ */
+static int __devlink_snapshot_id_increment(struct devlink *devlink, u32 id)
+{
+       unsigned long count;
+       void *p;
+
+       lockdep_assert_held(&devlink->lock);
+
+       p = xa_load(&devlink->snapshot_ids, id);
+       if (WARN_ON(!p))
+               return -EINVAL;
+
+       if (WARN_ON(!xa_is_value(p)))
+               return -EINVAL;
+
+       count = xa_to_value(p);
+       count++;
+
+       return xa_err(xa_store(&devlink->snapshot_ids, id, xa_mk_value(count),
+                              GFP_KERNEL));
+}
+
+/**
+ * __devlink_snapshot_id_decrement - Decrease number of snapshots using an id
+ *     @devlink: devlink instance
+ *     @id: the snapshot id
+ *
+ *     Track when a snapshot is deleted and stops using an id. Load the count
+ *     for the given id from the snapshot xarray, decrement it, and store it
+ *     back.
+ *
+ *     If the count reaches zero, erase this id from the xarray, freeing it
+ *     up for future re-use by devlink_region_snapshot_id_get().
+ *
+ *     Called when a snapshot using the given id is deleted, and when the
+ *     initial allocator of the id is finished using it.
+ */
+static void __devlink_snapshot_id_decrement(struct devlink *devlink, u32 id)
+{
+       unsigned long count;
+       void *p;
+
+       lockdep_assert_held(&devlink->lock);
+
+       p = xa_load(&devlink->snapshot_ids, id);
+       if (WARN_ON(!p))
+               return;
+
+       if (WARN_ON(!xa_is_value(p)))
+               return;
+
+       count = xa_to_value(p);
+
+       if (count > 1) {
+               count--;
+               xa_store(&devlink->snapshot_ids, id, xa_mk_value(count),
+                        GFP_KERNEL);
+       } else {
+               /* If this was the last user, we can erase this id */
+               xa_erase(&devlink->snapshot_ids, id);
+       }
+}
+
 /**
  *     __devlink_region_snapshot_id_get - get snapshot ID
  *     @devlink: devlink instance
  *     Allocates a new snapshot id. Returns zero on success, or a negative
  *     error on failure. Must be called while holding the devlink instance
  *     lock.
+ *
+ *     Snapshot IDs are tracked using an xarray which stores the number of
+ *     users of the snapshot id.
+ *
+ *     Note that the caller of this function counts as a 'user', in order to
+ *     avoid race conditions. The caller must release its hold on the
+ *     snapshot by using devlink_region_snapshot_id_put.
  */
 static int __devlink_region_snapshot_id_get(struct devlink *devlink, u32 *id)
 {
        lockdep_assert_held(&devlink->lock);
 
-       if (devlink->snapshot_id >= U32_MAX)
-               return -ENOSPC;
-
-       *id = ++devlink->snapshot_id;
-
-       return 0;
+       return xa_alloc(&devlink->snapshot_ids, id, xa_mk_value(1),
+                       xa_limit_32b, GFP_KERNEL);
 }
 
 /**
 {
        struct devlink *devlink = region->devlink;
        struct devlink_snapshot *snapshot;
+       int err;
 
        lockdep_assert_held(&devlink->lock);
 
        if (!snapshot)
                return -ENOMEM;
 
+       err = __devlink_snapshot_id_increment(devlink, snapshot_id);
+       if (err)
+               goto err_snapshot_id_increment;
+
        snapshot->id = snapshot_id;
        snapshot->region = region;
        snapshot->data = data;
 
        devlink_nl_region_notify(region, snapshot, DEVLINK_CMD_REGION_NEW);
        return 0;
+
+err_snapshot_id_increment:
+       kfree(snapshot);
+       return err;
 }
 
 static void devlink_region_snapshot_del(struct devlink_region *region,
                                        struct devlink_snapshot *snapshot)
 {
+       struct devlink *devlink = region->devlink;
+
+       lockdep_assert_held(&devlink->lock);
+
        devlink_nl_region_notify(region, snapshot, DEVLINK_CMD_REGION_DEL);
        region->cur_snapshots--;
        list_del(&snapshot->list);
        region->ops->destructor(snapshot->data);
+       __devlink_snapshot_id_decrement(devlink, snapshot->id);
        kfree(snapshot);
 }
 
        if (!devlink)
                return NULL;
        devlink->ops = ops;
+       xa_init_flags(&devlink->snapshot_ids, XA_FLAGS_ALLOC);
        __devlink_net_set(devlink, &init_net);
        INIT_LIST_HEAD(&devlink->port_list);
        INIT_LIST_HEAD(&devlink->sb_list);
        WARN_ON(!list_empty(&devlink->sb_list));
        WARN_ON(!list_empty(&devlink->port_list));
 
+       xa_destroy(&devlink->snapshot_ids);
+
        kfree(devlink);
 }
 EXPORT_SYMBOL_GPL(devlink_free);
  *     Driver should use the same id for multiple snapshots taken
  *     on multiple regions at the same time/by the same trigger.
  *
+ *     The caller of this function must use devlink_region_snapshot_id_put
+ *     when finished creating regions using this id.
+ *
  *     Returns zero on success, or a negative error code on failure.
  *
  *     @devlink: devlink
 }
 EXPORT_SYMBOL_GPL(devlink_region_snapshot_id_get);
 
+/**
+ *     devlink_region_snapshot_id_put - put snapshot ID reference
+ *
+ *     This should be called by a driver after finishing creating snapshots
+ *     with an id. Doing so ensures that the ID can later be released in the
+ *     event that all snapshots using it have been destroyed.
+ *
+ *     @devlink: devlink
+ *     @id: id to release reference on
+ */
+void devlink_region_snapshot_id_put(struct devlink *devlink, u32 id)
+{
+       mutex_lock(&devlink->lock);
+       __devlink_snapshot_id_decrement(devlink, id);
+       mutex_unlock(&devlink->lock);
+}
+EXPORT_SYMBOL_GPL(devlink_region_snapshot_id_put);
+
 /**
  *     devlink_region_snapshot_create - create a new snapshot
  *     This will add a new snapshot of a region. The snapshot