// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
 /* Copyright (c) 2010-2012 Broadcom. All rights reserved. */
 
+#include <linux/kref.h>
+#include <linux/rcupdate.h>
+
 #include "vchiq_core.h"
 
 #define VCHIQ_SLOT_HANDLER_STACK 8192
 int vchiq_core_msg_log_level = VCHIQ_LOG_DEFAULT;
 int vchiq_sync_log_level = VCHIQ_LOG_DEFAULT;
 
-static DEFINE_SPINLOCK(service_spinlock);
 DEFINE_SPINLOCK(bulk_waiter_spinlock);
 static DEFINE_SPINLOCK(quota_spinlock);
 
 {
        struct vchiq_service *service;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        service = handle_to_service(handle);
        if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
-           service->handle == handle) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
-       } else
-               service = NULL;
-       spin_unlock(&service_spinlock);
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid service handle 0x%x", handle);
-
-       return service;
+           service->handle == handle &&
+           kref_get_unless_zero(&service->ref_count)) {
+               service = rcu_pointer_handoff(service);
+               rcu_read_unlock();
+               return service;
+       }
+       rcu_read_unlock();
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid service handle 0x%x", handle);
+       return NULL;
 }
 
 struct vchiq_service *
 find_service_by_port(struct vchiq_state *state, int localport)
 {
-       struct vchiq_service *service = NULL;
 
        if ((unsigned int)localport <= VCHIQ_PORT_MAX) {
-               spin_lock(&service_spinlock);
-               service = state->services[localport];
-               if (service && service->srvstate != VCHIQ_SRVSTATE_FREE) {
-                       WARN_ON(service->ref_count == 0);
-                       service->ref_count++;
-               } else
-                       service = NULL;
-               spin_unlock(&service_spinlock);
-       }
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid port %d", localport);
+               struct vchiq_service *service;
 
-       return service;
+               rcu_read_lock();
+               service = rcu_dereference(state->services[localport]);
+               if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
+                   kref_get_unless_zero(&service->ref_count)) {
+                       service = rcu_pointer_handoff(service);
+                       rcu_read_unlock();
+                       return service;
+               }
+               rcu_read_unlock();
+       }
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid port %d", localport);
+       return NULL;
 }
 
 struct vchiq_service *
 {
        struct vchiq_service *service;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        service = handle_to_service(handle);
        if (service && service->srvstate != VCHIQ_SRVSTATE_FREE &&
            service->handle == handle &&
-           service->instance == instance) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
-       } else
-               service = NULL;
-       spin_unlock(&service_spinlock);
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid service handle 0x%x", handle);
-
-       return service;
+           service->instance == instance &&
+           kref_get_unless_zero(&service->ref_count)) {
+               service = rcu_pointer_handoff(service);
+               rcu_read_unlock();
+               return service;
+       }
+       rcu_read_unlock();
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid service handle 0x%x", handle);
+       return NULL;
 }
 
 struct vchiq_service *
 {
        struct vchiq_service *service;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        service = handle_to_service(handle);
        if (service &&
            (service->srvstate == VCHIQ_SRVSTATE_FREE ||
             service->srvstate == VCHIQ_SRVSTATE_CLOSED) &&
            service->handle == handle &&
-           service->instance == instance) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
-       } else
-               service = NULL;
-       spin_unlock(&service_spinlock);
-
-       if (!service)
-               vchiq_log_info(vchiq_core_log_level,
-                       "Invalid service handle 0x%x", handle);
-
+           service->instance == instance &&
+           kref_get_unless_zero(&service->ref_count)) {
+               service = rcu_pointer_handoff(service);
+               rcu_read_unlock();
+               return service;
+       }
+       rcu_read_unlock();
+       vchiq_log_info(vchiq_core_log_level,
+                      "Invalid service handle 0x%x", handle);
        return service;
 }
 
        struct vchiq_service *service = NULL;
        int idx = *pidx;
 
-       spin_lock(&service_spinlock);
+       rcu_read_lock();
        while (idx < state->unused_service) {
-               struct vchiq_service *srv = state->services[idx++];
+               struct vchiq_service *srv;
 
+               srv = rcu_dereference(state->services[idx++]);
                if (srv && srv->srvstate != VCHIQ_SRVSTATE_FREE &&
-                   srv->instance == instance) {
-                       service = srv;
-                       WARN_ON(service->ref_count == 0);
-                       service->ref_count++;
+                   srv->instance == instance &&
+                   kref_get_unless_zero(&srv->ref_count)) {
+                       service = rcu_pointer_handoff(srv);
                        break;
                }
        }
-       spin_unlock(&service_spinlock);
+       rcu_read_unlock();
 
        *pidx = idx;
 
 void
 lock_service(struct vchiq_service *service)
 {
-       spin_lock(&service_spinlock);
-       WARN_ON(!service);
-       if (service) {
-               WARN_ON(service->ref_count == 0);
-               service->ref_count++;
+       if (!service) {
+               WARN(1, "%s service is NULL\n", __func__);
+               return;
        }
-       spin_unlock(&service_spinlock);
+       kref_get(&service->ref_count);
+}
+
+static void service_release(struct kref *kref)
+{
+       struct vchiq_service *service =
+               container_of(kref, struct vchiq_service, ref_count);
+       struct vchiq_state *state = service->state;
+
+       WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
+       rcu_assign_pointer(state->services[service->localport], NULL);
+       if (service->userdata_term)
+               service->userdata_term(service->base.userdata);
+       kfree_rcu(service, rcu);
 }
 
 void
 unlock_service(struct vchiq_service *service)
 {
-       spin_lock(&service_spinlock);
        if (!service) {
                WARN(1, "%s: service is NULL\n", __func__);
-               goto unlock;
-       }
-       if (!service->ref_count) {
-               WARN(1, "%s: ref_count is zero\n", __func__);
-               goto unlock;
-       }
-       service->ref_count--;
-       if (!service->ref_count) {
-               struct vchiq_state *state = service->state;
-
-               WARN_ON(service->srvstate != VCHIQ_SRVSTATE_FREE);
-               state->services[service->localport] = NULL;
-       } else {
-               service = NULL;
+               return;
        }
-unlock:
-       spin_unlock(&service_spinlock);
-
-       if (service && service->userdata_term)
-               service->userdata_term(service->base.userdata);
-
-       kfree(service);
+       kref_put(&service->ref_count, service_release);
 }
 
 int
 void *
 vchiq_get_service_userdata(unsigned int handle)
 {
-       struct vchiq_service *service = handle_to_service(handle);
+       void *userdata;
+       struct vchiq_service *service;
 
-       return service ? service->base.userdata : NULL;
+       rcu_read_lock();
+       service = handle_to_service(handle);
+       userdata = service ? service->base.userdata : NULL;
+       rcu_read_unlock();
+       return userdata;
 }
 
 static void
 
        WARN_ON(fourcc == VCHIQ_FOURCC_INVALID);
 
+       rcu_read_lock();
        for (i = 0; i < state->unused_service; i++) {
-               struct vchiq_service *service = state->services[i];
+               struct vchiq_service *service;
 
+               service = rcu_dereference(state->services[i]);
                if (service &&
                    service->public_fourcc == fourcc &&
                    (service->srvstate == VCHIQ_SRVSTATE_LISTENING ||
                     (service->srvstate == VCHIQ_SRVSTATE_OPEN &&
-                     service->remoteport == VCHIQ_PORT_FREE))) {
-                       lock_service(service);
+                     service->remoteport == VCHIQ_PORT_FREE)) &&
+                   kref_get_unless_zero(&service->ref_count)) {
+                       service = rcu_pointer_handoff(service);
+                       rcu_read_unlock();
                        return service;
                }
        }
-
+       rcu_read_unlock();
        return NULL;
 }
 
 {
        int i;
 
+       rcu_read_lock();
        for (i = 0; i < state->unused_service; i++) {
-               struct vchiq_service *service = state->services[i];
+               struct vchiq_service *service =
+                       rcu_dereference(state->services[i]);
 
                if (service && service->srvstate == VCHIQ_SRVSTATE_OPEN &&
-                   service->remoteport == port) {
-                       lock_service(service);
+                   service->remoteport == port &&
+                   kref_get_unless_zero(&service->ref_count)) {
+                       service = rcu_pointer_handoff(service);
+                       rcu_read_unlock();
                        return service;
                }
        }
+       rcu_read_unlock();
        return NULL;
 }
 
                           vchiq_userdata_term userdata_term)
 {
        struct vchiq_service *service;
-       struct vchiq_service **pservice = NULL;
+       struct vchiq_service __rcu **pservice = NULL;
        struct vchiq_service_quota *service_quota;
        int i;
 
        service->base.callback = params->callback;
        service->base.userdata = params->userdata;
        service->handle        = VCHIQ_SERVICE_HANDLE_INVALID;
-       service->ref_count     = 1;
+       kref_init(&service->ref_count);
        service->srvstate      = VCHIQ_SRVSTATE_FREE;
        service->userdata_term = userdata_term;
        service->localport     = VCHIQ_PORT_FREE;
        mutex_init(&service->bulk_mutex);
        memset(&service->stats, 0, sizeof(service->stats));
 
-       /* Although it is perfectly possible to use service_spinlock
+       /* Although it is perfectly possible to use a spinlock
        ** to protect the creation of services, it is overkill as it
        ** disables interrupts while the array is searched.
        ** The only danger is of another thread trying to create a
 
        if (srvstate == VCHIQ_SRVSTATE_OPENING) {
                for (i = 0; i < state->unused_service; i++) {
-                       struct vchiq_service *srv = state->services[i];
-
-                       if (!srv) {
+                       if (!rcu_access_pointer(state->services[i])) {
                                pservice = &state->services[i];
                                break;
                        }
                }
        } else {
+               rcu_read_lock();
                for (i = (state->unused_service - 1); i >= 0; i--) {
-                       struct vchiq_service *srv = state->services[i];
+                       struct vchiq_service *srv;
 
+                       srv = rcu_dereference(state->services[i]);
                        if (!srv)
                                pservice = &state->services[i];
                        else if ((srv->public_fourcc == params->fourcc)
                                break;
                        }
                }
+               rcu_read_unlock();
        }
 
        if (pservice) {
                        (state->id * VCHIQ_MAX_SERVICES) |
                        service->localport;
                handle_seq += VCHIQ_MAX_STATES * VCHIQ_MAX_SERVICES;
-               *pservice = service;
+               rcu_assign_pointer(*pservice, service);
                if (pservice == &state->services[state->unused_service])
                        state->unused_service++;
        }
                           (service->srvstate != VCHIQ_SRVSTATE_OPENSYNC)) {
                        if (service->srvstate != VCHIQ_SRVSTATE_CLOSEWAIT)
                                vchiq_log_error(vchiq_core_log_level,
-                                               "%d: osi - srvstate = %s (ref %d)",
+                                               "%d: osi - srvstate = %s (ref %u)",
                                                service->state->id,
                                                srvstate_names[service->srvstate],
-                                               service->ref_count);
+                                               kref_read(&service->ref_count));
                        status = VCHIQ_ERROR;
                        VCHIQ_SERVICE_STATS_INC(service, error_count);
                        vchiq_release_service_internal(service);
        char buf[80];
        int len;
        int err;
+       unsigned int ref_count;
 
+       /*Don't include the lock just taken*/
+       ref_count = kref_read(&service->ref_count) - 1;
        len = scnprintf(buf, sizeof(buf), "Service %u: %s (ref %u)",
                        service->localport, srvstate_names[service->srvstate],
-                       service->ref_count - 1); /*Don't include the lock just taken*/
+                       ref_count);
 
        if (service->srvstate != VCHIQ_SRVSTATE_FREE) {
                char remoteport[30];