* @dev: Device to add the callbacks to.
  * @ops: Set of callbacks to add.
  * @td: Timing data to add to the device along with the callbacks (optional).
+ *
+ * Every call to this routine should be balanced with a call to
+ * __pm_genpd_remove_callbacks() and they must not be nested.
  */
 int pm_genpd_add_callbacks(struct device *dev, struct gpd_dev_ops *ops,
                           struct gpd_timing_data *td)
 {
-       struct pm_domain_data *pdd;
+       struct generic_pm_domain_data *gpd_data_new, *gpd_data = NULL;
        int ret = 0;
 
-       if (!(dev && dev->power.subsys_data && ops))
+       if (!(dev && ops))
                return -EINVAL;
 
+       gpd_data_new = __pm_genpd_alloc_dev_data(dev);
+       if (!gpd_data_new)
+               return -ENOMEM;
+
        pm_runtime_disable(dev);
        device_pm_lock();
 
-       pdd = dev->power.subsys_data->domain_data;
-       if (pdd) {
-               struct generic_pm_domain_data *gpd_data = to_gpd_data(pdd);
+       ret = dev_pm_get_subsys_data(dev);
+       if (ret)
+               goto out;
+
+       spin_lock_irq(&dev->power.lock);
 
-               gpd_data->ops = *ops;
-               if (td)
-                       gpd_data->td = *td;
+       if (dev->power.subsys_data->domain_data) {
+               gpd_data = to_gpd_data(dev->power.subsys_data->domain_data);
        } else {
-               ret = -EINVAL;
+               gpd_data = gpd_data_new;
+               dev->power.subsys_data->domain_data = &gpd_data->base;
        }
+       gpd_data->refcount++;
+       gpd_data->ops = *ops;
+       if (td)
+               gpd_data->td = *td;
 
+       spin_unlock_irq(&dev->power.lock);
+
+ out:
        device_pm_unlock();
        pm_runtime_enable(dev);
 
+       if (gpd_data != gpd_data_new)
+               __pm_genpd_free_dev_data(dev, gpd_data_new);
+
        return ret;
 }
 EXPORT_SYMBOL_GPL(pm_genpd_add_callbacks);
  * __pm_genpd_remove_callbacks - Remove PM domain callbacks from a given device.
  * @dev: Device to remove the callbacks from.
  * @clear_td: If set, clear the device's timing data too.
+ *
+ * This routine can only be called after pm_genpd_add_callbacks().
  */
 int __pm_genpd_remove_callbacks(struct device *dev, bool clear_td)
 {
-       struct pm_domain_data *pdd;
+       struct generic_pm_domain_data *gpd_data = NULL;
+       bool remove = false;
        int ret = 0;
 
        if (!(dev && dev->power.subsys_data))
        pm_runtime_disable(dev);
        device_pm_lock();
 
-       pdd = dev->power.subsys_data->domain_data;
-       if (pdd) {
-               struct generic_pm_domain_data *gpd_data = to_gpd_data(pdd);
+       spin_lock_irq(&dev->power.lock);
 
+       if (dev->power.subsys_data->domain_data) {
+               gpd_data = to_gpd_data(dev->power.subsys_data->domain_data);
                gpd_data->ops = (struct gpd_dev_ops){ 0 };
                if (clear_td)
                        gpd_data->td = (struct gpd_timing_data){ 0 };
+
+               if (--gpd_data->refcount == 0) {
+                       dev->power.subsys_data->domain_data = NULL;
+                       remove = true;
+               }
        } else {
                ret = -EINVAL;
        }
 
+       spin_unlock_irq(&dev->power.lock);
+
        device_pm_unlock();
        pm_runtime_enable(dev);
 
-       return ret;
+       if (ret)
+               return ret;
+
+       dev_pm_put_subsys_data(dev);
+       if (remove)
+               __pm_genpd_free_dev_data(dev, gpd_data);
+
+       return 0;
 }
 EXPORT_SYMBOL_GPL(__pm_genpd_remove_callbacks);