return NULL;
 }
 
-static int
-update_notifier_cb(struct ffa_device *dev, int notify_id, void *cb,
-                  void *cb_data, bool is_registration, bool is_framework)
+static int update_notifier_cb(struct ffa_device *dev, int notify_id,
+                             struct notifier_cb_info *cb, bool is_framework)
 {
        struct notifier_cb_info *cb_info = NULL;
        enum notify_type type = ffa_notify_type_get(dev->vm_id);
-       bool cb_found;
+       bool cb_found, is_registration = !!cb;
 
        if (is_framework)
                cb_info = notifier_hnode_get_by_vmid_uuid(notify_id, dev->vm_id,
                return -EINVAL;
 
        if (is_registration) {
-               cb_info = kzalloc(sizeof(*cb_info), GFP_KERNEL);
-               if (!cb_info)
-                       return -ENOMEM;
-
-               cb_info->dev = dev;
-               cb_info->cb_data = cb_data;
-               if (is_framework)
-                       cb_info->fwk_cb = cb;
-               else
-                       cb_info->cb = cb;
-
-               hash_add(drv_info->notifier_hash, &cb_info->hnode, notify_id);
+               hash_add(drv_info->notifier_hash, &cb->hnode, notify_id);
        } else {
                hash_del(&cb_info->hnode);
                kfree(cb_info);
 
        mutex_lock(&drv_info->notify_lock);
 
-       rc = update_notifier_cb(dev, notify_id, NULL, NULL, false,
-                               is_framework);
+       rc = update_notifier_cb(dev, notify_id, NULL, is_framework);
        if (rc) {
                pr_err("Could not unregister notification callback\n");
                mutex_unlock(&drv_info->notify_lock);
 {
        int rc;
        u32 flags = 0;
+       struct notifier_cb_info *cb_info = NULL;
 
        if (ffa_notifications_disabled())
                return -EOPNOTSUPP;
        if (notify_id >= FFA_MAX_NOTIFICATIONS)
                return -EINVAL;
 
+       cb_info = kzalloc(sizeof(*cb_info), GFP_KERNEL);
+       if (!cb_info)
+               return -ENOMEM;
+
+       cb_info->dev = dev;
+       cb_info->cb_data = cb_data;
+       if (is_framework)
+               cb_info->fwk_cb = cb;
+       else
+               cb_info->cb = cb;
+
        mutex_lock(&drv_info->notify_lock);
 
        if (!is_framework) {
                        flags = PER_VCPU_NOTIFICATION_FLAG;
 
                rc = ffa_notification_bind(dev->vm_id, BIT(notify_id), flags);
-               if (rc) {
-                       mutex_unlock(&drv_info->notify_lock);
-                       return rc;
-               }
+               if (rc)
+                       goto out_unlock_free;
        }
 
-       rc = update_notifier_cb(dev, notify_id, cb, cb_data, true,
-                               is_framework);
+       rc = update_notifier_cb(dev, notify_id, cb_info, is_framework);
        if (rc) {
                pr_err("Failed to register callback for %d - %d\n",
                       notify_id, rc);
                if (!is_framework)
                        ffa_notification_unbind(dev->vm_id, BIT(notify_id));
        }
+
+out_unlock_free:
        mutex_unlock(&drv_info->notify_lock);
+       if (rc)
+               kfree(cb_info);
 
        return rc;
 }