hba = shost_priv(sdev->host);
        scsi_deactivate_tcq(sdev, hba->nutrs);
        /* Drop the reference as it won't be needed anymore */
-       if (ufshcd_scsi_to_upiu_lun(sdev->lun) == UFS_UPIU_UFS_DEVICE_WLUN)
+       if (ufshcd_scsi_to_upiu_lun(sdev->lun) == UFS_UPIU_UFS_DEVICE_WLUN) {
+               unsigned long flags;
+
+               spin_lock_irqsave(hba->host->host_lock, flags);
                hba->sdev_ufs_device = NULL;
+               spin_unlock_irqrestore(hba->host->host_lock, flags);
+       }
 }
 
 /**
 static int ufshcd_scsi_add_wlus(struct ufs_hba *hba)
 {
        int ret = 0;
+       struct scsi_device *sdev_rpmb;
+       struct scsi_device *sdev_boot;
 
        hba->sdev_ufs_device = __scsi_add_device(hba->host, 0, 0,
                ufshcd_upiu_wlun_to_scsi_wlun(UFS_UPIU_UFS_DEVICE_WLUN), NULL);
                hba->sdev_ufs_device = NULL;
                goto out;
        }
+       scsi_device_put(hba->sdev_ufs_device);
 
-       hba->sdev_boot = __scsi_add_device(hba->host, 0, 0,
+       sdev_boot = __scsi_add_device(hba->host, 0, 0,
                ufshcd_upiu_wlun_to_scsi_wlun(UFS_UPIU_BOOT_WLUN), NULL);
-       if (IS_ERR(hba->sdev_boot)) {
-               ret = PTR_ERR(hba->sdev_boot);
-               hba->sdev_boot = NULL;
+       if (IS_ERR(sdev_boot)) {
+               ret = PTR_ERR(sdev_boot);
                goto remove_sdev_ufs_device;
        }
+       scsi_device_put(sdev_boot);
 
-       hba->sdev_rpmb = __scsi_add_device(hba->host, 0, 0,
+       sdev_rpmb = __scsi_add_device(hba->host, 0, 0,
                ufshcd_upiu_wlun_to_scsi_wlun(UFS_UPIU_RPMB_WLUN), NULL);
-       if (IS_ERR(hba->sdev_rpmb)) {
-               ret = PTR_ERR(hba->sdev_rpmb);
-               hba->sdev_rpmb = NULL;
+       if (IS_ERR(sdev_rpmb)) {
+               ret = PTR_ERR(sdev_rpmb);
                goto remove_sdev_boot;
        }
+       scsi_device_put(sdev_rpmb);
        goto out;
 
 remove_sdev_boot:
-       scsi_remove_device(hba->sdev_boot);
+       scsi_remove_device(sdev_boot);
 remove_sdev_ufs_device:
        scsi_remove_device(hba->sdev_ufs_device);
 out:
        return ret;
 }
 
-/**
- * ufshcd_scsi_remove_wlus - Removes the W-LUs which were added by
- *                          ufshcd_scsi_add_wlus()
- * @hba: per-adapter instance
- *
- */
-static void ufshcd_scsi_remove_wlus(struct ufs_hba *hba)
-{
-       if (hba->sdev_ufs_device) {
-               scsi_remove_device(hba->sdev_ufs_device);
-               hba->sdev_ufs_device = NULL;
-       }
-
-       if (hba->sdev_boot) {
-               scsi_remove_device(hba->sdev_boot);
-               hba->sdev_boot = NULL;
-       }
-
-       if (hba->sdev_rpmb) {
-               scsi_remove_device(hba->sdev_rpmb);
-               hba->sdev_rpmb = NULL;
-       }
-}
-
 /**
  * ufshcd_probe_hba - probe hba to detect device and initialize
  * @hba: per-adapter instance
 {
        unsigned char cmd[6] = { START_STOP };
        struct scsi_sense_hdr sshdr;
-       struct scsi_device *sdp = hba->sdev_ufs_device;
+       struct scsi_device *sdp;
+       unsigned long flags;
        int ret;
 
-       if (!sdp || !scsi_device_online(sdp))
-               return -ENODEV;
+       spin_lock_irqsave(hba->host->host_lock, flags);
+       sdp = hba->sdev_ufs_device;
+       if (sdp) {
+               ret = scsi_device_get(sdp);
+               if (!ret && !scsi_device_online(sdp)) {
+                       ret = -ENODEV;
+                       scsi_device_put(sdp);
+               }
+       } else {
+               ret = -ENODEV;
+       }
+       spin_unlock_irqrestore(hba->host->host_lock, flags);
+
+       if (ret)
+               return ret;
 
        /*
         * If scsi commands fail, the scsi mid-layer schedules scsi error-
        if (!ret)
                hba->curr_dev_pwr_mode = pwr_mode;
 out:
+       scsi_device_put(sdp);
        hba->host->eh_noresume = 0;
        return ret;
 }
 void ufshcd_remove(struct ufs_hba *hba)
 {
        scsi_remove_host(hba->host);
-       ufshcd_scsi_remove_wlus(hba);
        /* disable interrupts */
        ufshcd_disable_intr(hba, hba->intr_mask);
        ufshcd_hba_stop(hba);