return ret;
 }
 
+static DEFINE_MUTEX(ta_refcount_mutex);
+static struct list_head ta_list = LIST_HEAD_INIT(ta_list);
+
+static u32 get_ta_refcount(u32 ta_handle)
+{
+       struct amdtee_ta_data *ta_data;
+       u32 count = 0;
+
+       /* Caller must hold a mutex */
+       list_for_each_entry(ta_data, &ta_list, list_node)
+               if (ta_data->ta_handle == ta_handle)
+                       return ++ta_data->refcount;
+
+       ta_data = kzalloc(sizeof(*ta_data), GFP_KERNEL);
+       if (ta_data) {
+               ta_data->ta_handle = ta_handle;
+               ta_data->refcount = 1;
+               count = ta_data->refcount;
+               list_add(&ta_data->list_node, &ta_list);
+       }
+
+       return count;
+}
+
+static u32 put_ta_refcount(u32 ta_handle)
+{
+       struct amdtee_ta_data *ta_data;
+       u32 count = 0;
+
+       /* Caller must hold a mutex */
+       list_for_each_entry(ta_data, &ta_list, list_node)
+               if (ta_data->ta_handle == ta_handle) {
+                       count = --ta_data->refcount;
+                       if (count == 0) {
+                               list_del(&ta_data->list_node);
+                               kfree(ta_data);
+                               break;
+                       }
+               }
+
+       return count;
+}
+
 int handle_unload_ta(u32 ta_handle)
 {
        struct tee_cmd_unload_ta cmd = {0};
-       u32 status;
+       u32 status, count;
        int ret;
 
        if (!ta_handle)
                return -EINVAL;
 
+       mutex_lock(&ta_refcount_mutex);
+
+       count = put_ta_refcount(ta_handle);
+
+       if (count) {
+               pr_debug("unload ta: not unloading %u count %u\n",
+                        ta_handle, count);
+               ret = -EBUSY;
+               goto unlock;
+       }
+
        cmd.ta_handle = ta_handle;
 
        ret = psp_tee_process_cmd(TEE_CMD_ID_UNLOAD_TA, (void *)&cmd,
        if (!ret && status != 0) {
                pr_err("unload ta: status = 0x%x\n", status);
                ret = -EBUSY;
+       } else {
+               pr_debug("unloaded ta handle %u\n", ta_handle);
        }
 
+unlock:
+       mutex_unlock(&ta_refcount_mutex);
        return ret;
 }
 
 
 int handle_load_ta(void *data, u32 size, struct tee_ioctl_open_session_arg *arg)
 {
-       struct tee_cmd_load_ta cmd = {0};
+       struct tee_cmd_unload_ta unload_cmd = {};
+       struct tee_cmd_load_ta load_cmd = {};
        phys_addr_t blob;
        int ret;
 
                return -EINVAL;
        }
 
-       cmd.hi_addr = upper_32_bits(blob);
-       cmd.low_addr = lower_32_bits(blob);
-       cmd.size = size;
+       load_cmd.hi_addr = upper_32_bits(blob);
+       load_cmd.low_addr = lower_32_bits(blob);
+       load_cmd.size = size;
 
-       ret = psp_tee_process_cmd(TEE_CMD_ID_LOAD_TA, (void *)&cmd,
-                                 sizeof(cmd), &arg->ret);
+       mutex_lock(&ta_refcount_mutex);
+
+       ret = psp_tee_process_cmd(TEE_CMD_ID_LOAD_TA, (void *)&load_cmd,
+                                 sizeof(load_cmd), &arg->ret);
        if (ret) {
                arg->ret_origin = TEEC_ORIGIN_COMMS;
                arg->ret = TEEC_ERROR_COMMUNICATION;
-       } else {
-               set_session_id(cmd.ta_handle, 0, &arg->session);
+       } else if (arg->ret == TEEC_SUCCESS) {
+               ret = get_ta_refcount(load_cmd.ta_handle);
+               if (!ret) {
+                       arg->ret_origin = TEEC_ORIGIN_COMMS;
+                       arg->ret = TEEC_ERROR_OUT_OF_MEMORY;
+
+                       /* Unload the TA on error */
+                       unload_cmd.ta_handle = load_cmd.ta_handle;
+                       psp_tee_process_cmd(TEE_CMD_ID_UNLOAD_TA,
+                                           (void *)&unload_cmd,
+                                           sizeof(unload_cmd), &ret);
+               } else {
+                       set_session_id(load_cmd.ta_handle, 0, &arg->session);
+               }
        }
+       mutex_unlock(&ta_refcount_mutex);
 
        pr_debug("load TA: TA handle = 0x%x, RO = 0x%x, ret = 0x%x\n",
-                cmd.ta_handle, arg->ret_origin, arg->ret);
+                load_cmd.ta_handle, arg->ret_origin, arg->ret);
 
        return 0;
 }
 
                        continue;
 
                handle_close_session(sess->ta_handle, sess->session_info[i]);
+               handle_unload_ta(sess->ta_handle);
        }
 
-       /* Unload Trusted Application once all sessions are closed */
-       handle_unload_ta(sess->ta_handle);
        kfree(sess);
 }
 
        struct amdtee_session *sess = container_of(ref, struct amdtee_session,
                                                   refcount);
 
-       /* Unload the TA from TEE */
-       handle_unload_ta(sess->ta_handle);
        mutex_lock(&session_list_mutex);
        list_del(&sess->list_node);
        mutex_unlock(&session_list_mutex);
 {
        struct amdtee_context_data *ctxdata = ctx->data;
        struct amdtee_session *sess = NULL;
-       u32 session_info;
+       u32 session_info, ta_handle;
        size_t ta_size;
        int rc, i;
        void *ta;
        if (arg->ret != TEEC_SUCCESS)
                goto out;
 
+       ta_handle = get_ta_handle(arg->session);
+
        mutex_lock(&session_list_mutex);
        sess = alloc_session(ctxdata, arg->session);
        mutex_unlock(&session_list_mutex);
 
        if (!sess) {
+               handle_unload_ta(ta_handle);
                rc = -ENOMEM;
                goto out;
        }
 
        if (i >= TEE_NUM_SESSIONS) {
                pr_err("reached maximum session count %d\n", TEE_NUM_SESSIONS);
+               handle_unload_ta(ta_handle);
                kref_put(&sess->refcount, destroy_session);
                rc = -ENOMEM;
                goto out;
                spin_lock(&sess->lock);
                clear_bit(i, sess->sess_mask);
                spin_unlock(&sess->lock);
+               handle_unload_ta(ta_handle);
                kref_put(&sess->refcount, destroy_session);
                goto out;
        }
 
        sess->session_info[i] = session_info;
-       set_session_id(sess->ta_handle, i, &arg->session);
+       set_session_id(ta_handle, i, &arg->session);
 out:
        free_pages((u64)ta, get_order(ta_size));
        return rc;
 
        /* Close the session */
        handle_close_session(ta_handle, session_info);
+       handle_unload_ta(ta_handle);
 
        kref_put(&sess->refcount, destroy_session);