static int psp_ta_invoke(struct psp_context *psp,
                  uint32_t ta_cmd_id,
-                 uint32_t session_id)
+                 struct ta_context *context)
 {
        int ret;
        struct psp_gfx_cmd_resp *cmd = acquire_psp_cmd_buf(psp);
 
-       psp_prep_ta_invoke_cmd_buf(cmd, ta_cmd_id, session_id);
+       psp_prep_ta_invoke_cmd_buf(cmd, ta_cmd_id, context->session_id);
 
        ret = psp_cmd_submit_buf(psp, NULL, cmd,
                                 psp->fence_buf_mc_addr);
 
 int psp_xgmi_invoke(struct psp_context *psp, uint32_t ta_cmd_id)
 {
-       return psp_ta_invoke(psp, ta_cmd_id, psp->xgmi_context.context.session_id);
+       return psp_ta_invoke(psp, ta_cmd_id, &psp->xgmi_context.context);
 }
 
 int psp_xgmi_terminate(struct psp_context *psp)
        if (amdgpu_sriov_vf(psp->adev))
                return 0;
 
-       ret = psp_ta_invoke(psp, ta_cmd_id, psp->ras_context.context.session_id);
+       ret = psp_ta_invoke(psp, ta_cmd_id, &psp->ras_context.context);
 
        if (amdgpu_ras_intr_triggered())
                return ret;
        if (amdgpu_sriov_vf(psp->adev))
                return 0;
 
-       return psp_ta_invoke(psp, ta_cmd_id, psp->hdcp_context.context.session_id);
+       return psp_ta_invoke(psp, ta_cmd_id, &psp->hdcp_context.context);
 }
 
 static int psp_hdcp_terminate(struct psp_context *psp)
        if (amdgpu_sriov_vf(psp->adev))
                return 0;
 
-       return psp_ta_invoke(psp, ta_cmd_id, psp->dtm_context.context.session_id);
+       return psp_ta_invoke(psp, ta_cmd_id, &psp->dtm_context.context);
 }
 
 static int psp_dtm_terminate(struct psp_context *psp)
        rap_cmd->cmd_id = ta_cmd_id;
        rap_cmd->validation_method_id = METHOD_A;
 
-       ret = psp_ta_invoke(psp, rap_cmd->cmd_id, psp->rap_context.context.session_id);
+       ret = psp_ta_invoke(psp, rap_cmd->cmd_id, &psp->rap_context.context);
        if (ret)
                goto out_unlock;
 
 
        mutex_lock(&psp->securedisplay_context.mutex);
 
-       ret = psp_ta_invoke(psp, ta_cmd_id, psp->securedisplay_context.context.session_id);
+       ret = psp_ta_invoke(psp, ta_cmd_id, &psp->securedisplay_context.context);
 
        mutex_unlock(&psp->securedisplay_context.mutex);