#include "netdevsim.h"
 
+static u32 nsim_dev_id;
+
 struct nsim_vf_config {
        int link_state;
        u16 min_tx_rate;
        bool rss_query_enabled;
 };
 
-static u32 nsim_dev_id;
-
 static struct dentry *nsim_ddir;
-static struct dentry *nsim_sdev_ddir;
 
 static int nsim_num_vf(struct device *dev)
 {
 
 static int nsim_init(struct net_device *dev)
 {
-       char sdev_ddir_name[10], sdev_link_name[32];
        struct netdevsim *ns = netdev_priv(dev);
+       char sdev_link_name[32];
        int err;
 
        ns->netdev = dev;
        if (IS_ERR_OR_NULL(ns->ddir))
                return -ENOMEM;
 
-       if (!ns->sdev) {
-               ns->sdev = kzalloc(sizeof(*ns->sdev), GFP_KERNEL);
-               if (!ns->sdev) {
-                       err = -ENOMEM;
-                       goto err_debugfs_destroy;
-               }
-               ns->sdev->refcnt = 1;
-               ns->sdev->switch_id = nsim_dev_id;
-               sprintf(sdev_ddir_name, "%u", ns->sdev->switch_id);
-               ns->sdev->ddir = debugfs_create_dir(sdev_ddir_name,
-                                                   nsim_sdev_ddir);
-               if (IS_ERR_OR_NULL(ns->sdev->ddir)) {
-                       err = PTR_ERR_OR_ZERO(ns->sdev->ddir) ?: -EINVAL;
-                       goto err_sdev_free;
-               }
-       } else {
-               sprintf(sdev_ddir_name, "%u", ns->sdev->switch_id);
-               ns->sdev->refcnt++;
-       }
-
-       sprintf(sdev_link_name, "../../" DRV_NAME "_sdev/%s", sdev_ddir_name);
+       sprintf(sdev_link_name, "../../" DRV_NAME "_sdev/%u",
+               ns->sdev->switch_id);
        debugfs_create_symlink("sdev", ns->ddir, sdev_link_name);
 
        err = nsim_bpf_init(ns);
        if (err)
-               goto err_sdev_destroy;
+               goto err_debugfs_destroy;
 
        ns->dev.id = nsim_dev_id++;
        ns->dev.bus = &nsim_bus;
        device_unregister(&ns->dev);
 err_bpf_uninit:
        nsim_bpf_uninit(ns);
-err_sdev_destroy:
-       if (!--ns->sdev->refcnt) {
-               debugfs_remove_recursive(ns->sdev->ddir);
-err_sdev_free:
-               kfree(ns->sdev);
-       }
 err_debugfs_destroy:
        debugfs_remove_recursive(ns->ddir);
        return err;
        nsim_devlink_teardown(ns);
        debugfs_remove_recursive(ns->ddir);
        nsim_bpf_uninit(ns);
-       if (!--ns->sdev->refcnt) {
-               debugfs_remove_recursive(ns->sdev->ddir);
-               kfree(ns->sdev);
-       }
 }
 
 static void nsim_free(struct net_device *dev)
 
        device_unregister(&ns->dev);
        /* netdev and vf state will be freed out of device_release() */
+       nsim_sdev_put(ns->sdev);
 }
 
 static netdev_tx_t nsim_start_xmit(struct sk_buff *skb, struct net_device *dev)
                        struct netlink_ext_ack *extack)
 {
        struct netdevsim *ns = netdev_priv(dev);
+       struct netdevsim *joinns = NULL;
+       int err;
 
        if (tb[IFLA_LINK]) {
                struct net_device *joindev;
-               struct netdevsim *joinns;
 
                joindev = __dev_get_by_index(src_net,
                                             nla_get_u32(tb[IFLA_LINK]));
                        return -EINVAL;
 
                joinns = netdev_priv(joindev);
-               if (!joinns->sdev || !joinns->sdev->refcnt)
-                       return -EINVAL;
-               ns->sdev = joinns->sdev;
        }
 
-       return register_netdevice(dev);
+       ns->sdev = nsim_sdev_get(joinns);
+       if (IS_ERR(ns->sdev))
+               return PTR_ERR(ns->sdev);
+
+       err = register_netdevice(dev);
+       if (err)
+               goto err_sdev_put;
+       return 0;
+
+err_sdev_put:
+       nsim_sdev_put(ns->sdev);
+       return err;
 }
 
 static struct rtnl_link_ops nsim_link_ops __read_mostly = {
        if (IS_ERR_OR_NULL(nsim_ddir))
                return -ENOMEM;
 
-       nsim_sdev_ddir = debugfs_create_dir(DRV_NAME "_sdev", NULL);
-       if (IS_ERR_OR_NULL(nsim_sdev_ddir)) {
-               err = -ENOMEM;
+       err = nsim_sdev_init();
+       if (err)
                goto err_debugfs_destroy;
-       }
 
        err = bus_register(&nsim_bus);
        if (err)
-               goto err_sdir_destroy;
+               goto err_sdev_exit;
 
        err = nsim_devlink_init();
        if (err)
        nsim_devlink_exit();
 err_unreg_bus:
        bus_unregister(&nsim_bus);
-err_sdir_destroy:
-       debugfs_remove_recursive(nsim_sdev_ddir);
+err_sdev_exit:
+       nsim_sdev_exit();
 err_debugfs_destroy:
        debugfs_remove_recursive(nsim_ddir);
        return err;
        rtnl_link_unregister(&nsim_link_ops);
        nsim_devlink_exit();
        bus_unregister(&nsim_bus);
-       debugfs_remove_recursive(nsim_sdev_ddir);
+       nsim_sdev_exit();
        debugfs_remove_recursive(nsim_ddir);
 }
 
 
--- /dev/null
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2019 Mellanox Technologies. All rights reserved */
+
+#include <linux/debugfs.h>
+#include <linux/err.h>
+#include <linux/kernel.h>
+#include <linux/slab.h>
+
+#include "netdevsim.h"
+
+static struct dentry *nsim_sdev_ddir;
+
+static u32 nsim_sdev_id;
+
+struct netdevsim_shared_dev *nsim_sdev_get(struct netdevsim *joinns)
+{
+       struct netdevsim_shared_dev *sdev;
+       char sdev_ddir_name[10];
+       int err;
+
+       if (joinns) {
+               if (WARN_ON(!joinns->sdev))
+                       return ERR_PTR(-EINVAL);
+               sdev = joinns->sdev;
+               sdev->refcnt++;
+               return sdev;
+       }
+
+       sdev = kzalloc(sizeof(*sdev), GFP_KERNEL);
+       if (!sdev)
+               return ERR_PTR(-ENOMEM);
+       sdev->refcnt = 1;
+       sdev->switch_id = nsim_sdev_id++;
+
+       sprintf(sdev_ddir_name, "%u", sdev->switch_id);
+       sdev->ddir = debugfs_create_dir(sdev_ddir_name, nsim_sdev_ddir);
+       if (IS_ERR_OR_NULL(sdev->ddir)) {
+               err = PTR_ERR_OR_ZERO(sdev->ddir) ?: -EINVAL;
+               goto err_sdev_free;
+       }
+
+       return sdev;
+
+err_sdev_free:
+       nsim_sdev_id--;
+       kfree(sdev);
+       return ERR_PTR(err);
+}
+
+void nsim_sdev_put(struct netdevsim_shared_dev *sdev)
+{
+       if (--sdev->refcnt)
+               return;
+       debugfs_remove_recursive(sdev->ddir);
+       kfree(sdev);
+}
+
+int nsim_sdev_init(void)
+{
+       nsim_sdev_ddir = debugfs_create_dir(DRV_NAME "_sdev", NULL);
+       if (IS_ERR_OR_NULL(nsim_sdev_ddir))
+               return -ENOMEM;
+       return 0;
+}
+
+void nsim_sdev_exit(void)
+{
+       debugfs_remove_recursive(nsim_sdev_ddir);
+}