struct qcom_spmi_pmic pmic;
 };
 
+static DEFINE_MUTEX(pmic_spmi_revid_lock);
+
 #define N_USIDS(n)             ((void *)n)
 
 static const struct of_device_id pmic_spmi_id_table[] = {
  *
  * This only supports PMICs with 1 or 2 USIDs.
  */
-static struct spmi_device *qcom_pmic_get_base_usid(struct device *dev)
+static struct spmi_device *qcom_pmic_get_base_usid(struct spmi_device *sdev, struct qcom_spmi_dev *ctx)
 {
-       struct spmi_device *sdev;
-       struct qcom_spmi_dev *ctx;
        struct device_node *spmi_bus;
        struct device_node *child;
        int function_parent_usid, ret;
        u32 pmic_addr;
 
-       sdev = to_spmi_device(dev);
-       ctx = dev_get_drvdata(&sdev->dev);
-
        /*
         * Quick return if the function device is already in the base
         * USID. This will always be hit for PMICs with only 1 USID.
         */
-       if (sdev->usid % ctx->num_usids == 0)
+       if (sdev->usid % ctx->num_usids == 0) {
+               get_device(&sdev->dev);
                return sdev;
+       }
 
        function_parent_usid = sdev->usid;
 
                        sdev = spmi_device_from_of(child);
                        if (!sdev) {
                                /*
-                                * If the base USID for this PMIC hasn't probed yet
-                                * but the secondary USID has, then we need to defer
-                                * the function driver so that it will attempt to
-                                * probe again when the base USID is ready.
+                                * If the base USID for this PMIC hasn't been
+                                * registered yet then we need to defer.
                                 */
                                sdev = ERR_PTR(-EPROBE_DEFER);
                        }
        return sdev;
 }
 
+static int pmic_spmi_get_base_revid(struct spmi_device *sdev, struct qcom_spmi_dev *ctx)
+{
+       struct qcom_spmi_dev *base_ctx;
+       struct spmi_device *base;
+       int ret = 0;
+
+       base = qcom_pmic_get_base_usid(sdev, ctx);
+       if (IS_ERR(base))
+               return PTR_ERR(base);
+
+       /*
+        * Copy revid info from base device if it has probed and is still
+        * bound to its driver.
+        */
+       mutex_lock(&pmic_spmi_revid_lock);
+       base_ctx = spmi_device_get_drvdata(base);
+       if (!base_ctx) {
+               ret = -EPROBE_DEFER;
+               goto out_unlock;
+       }
+       memcpy(&ctx->pmic, &base_ctx->pmic, sizeof(ctx->pmic));
+out_unlock:
+       mutex_unlock(&pmic_spmi_revid_lock);
+
+       put_device(&base->dev);
+
+       return ret;
+}
+
 static int pmic_spmi_load_revid(struct regmap *map, struct device *dev,
                                 struct qcom_spmi_pmic *pmic)
 {
        if (!of_match_device(pmic_spmi_id_table, dev->parent))
                return ERR_PTR(-EINVAL);
 
-       sdev = qcom_pmic_get_base_usid(dev->parent);
-
-       if (IS_ERR(sdev))
-               return ERR_CAST(sdev);
-
+       sdev = to_spmi_device(dev->parent);
        spmi = dev_get_drvdata(&sdev->dev);
 
        return &spmi->pmic;
                ret = pmic_spmi_load_revid(regmap, &sdev->dev, &ctx->pmic);
                if (ret < 0)
                        return ret;
+       } else {
+               ret = pmic_spmi_get_base_revid(sdev, ctx);
+               if (ret)
+                       return ret;
        }
+
+       mutex_lock(&pmic_spmi_revid_lock);
        spmi_device_set_drvdata(sdev, ctx);
+       mutex_unlock(&pmic_spmi_revid_lock);
 
        return devm_of_platform_populate(&sdev->dev);
 }
 
+static void pmic_spmi_remove(struct spmi_device *sdev)
+{
+       mutex_lock(&pmic_spmi_revid_lock);
+       spmi_device_set_drvdata(sdev, NULL);
+       mutex_unlock(&pmic_spmi_revid_lock);
+}
+
 MODULE_DEVICE_TABLE(of, pmic_spmi_id_table);
 
 static struct spmi_driver pmic_spmi_driver = {
        .probe = pmic_spmi_probe,
+       .remove = pmic_spmi_remove,
        .driver = {
                .name = "pmic-spmi",
                .of_match_table = pmic_spmi_id_table,