}
 }
 
+static void mlx4_restart_one_down(struct pci_dev *pdev);
+static int mlx4_restart_one_up(struct pci_dev *pdev, bool reload,
+                              struct devlink *devlink);
+
 static int mlx4_devlink_reload(struct devlink *devlink,
                               struct netlink_ext_ack *extack)
 {
 
        if (persist->num_vfs)
                mlx4_warn(persist->dev, "Reload performed on PF, will cause reset on operating Virtual Functions\n");
-       err = mlx4_restart_one(persist->pdev, true, devlink);
+       mlx4_restart_one_down(persist->pdev);
+       err = mlx4_restart_one_up(persist->pdev, true, devlink);
        if (err)
-               mlx4_err(persist->dev, "mlx4_restart_one failed, ret=%d\n", err);
+               mlx4_err(persist->dev, "mlx4_restart_one_up failed, ret=%d\n",
+                        err);
 
        return err;
 }
        return err;
 }
 
-int mlx4_restart_one(struct pci_dev *pdev, bool reload, struct devlink *devlink)
+static void mlx4_restart_one_down(struct pci_dev *pdev)
+{
+       mlx4_unload_one(pdev);
+}
+
+static int mlx4_restart_one_up(struct pci_dev *pdev, bool reload,
+                              struct devlink *devlink)
 {
        struct mlx4_dev_persistent *persist = pci_get_drvdata(pdev);
        struct mlx4_dev  *dev  = persist->dev;
        total_vfs = dev->persist->num_vfs;
        memcpy(nvfs, dev->persist->nvfs, sizeof(dev->persist->nvfs));
 
-       mlx4_unload_one(pdev);
        if (reload)
                mlx4_devlink_param_load_driverinit_values(devlink);
        err = mlx4_load_one(pdev, pci_dev_data, total_vfs, nvfs, priv, 1);
        return err;
 }
 
+int mlx4_restart_one(struct pci_dev *pdev)
+{
+       mlx4_restart_one_down(pdev);
+       return mlx4_restart_one_up(pdev, false, NULL);
+}
+
 #define MLX_SP(id) { PCI_VDEVICE(MELLANOX, id), MLX4_PCI_DEV_FORCE_SENSE_PORT }
 #define MLX_VF(id) { PCI_VDEVICE(MELLANOX, id), MLX4_PCI_DEV_IS_VF }
 #define MLX_GN(id) { PCI_VDEVICE(MELLANOX, id), 0 }
 
 void mlx4_catas_end(struct mlx4_dev *dev);
 int mlx4_crdump_init(struct mlx4_dev *dev);
 void mlx4_crdump_end(struct mlx4_dev *dev);
-int mlx4_restart_one(struct pci_dev *pdev, bool reload,
-                    struct devlink *devlink);
+int mlx4_restart_one(struct pci_dev *pdev);
 int mlx4_register_device(struct mlx4_dev *dev);
 void mlx4_unregister_device(struct mlx4_dev *dev);
 void mlx4_dispatch_event(struct mlx4_dev *dev, enum mlx4_dev_event type,