static int psp_load_smu_fw(struct psp_context *psp);
 static int psp_ta_unload(struct psp_context *psp, uint32_t session_id);
+static int psp_ta_load(struct psp_context *psp, struct ta_context *context);
 static int psp_rap_terminate(struct psp_context *psp);
 static int psp_securedisplay_terminate(struct psp_context *psp);
 
        return ret;
 }
 
-static void psp_prep_asd_load_cmd_buf(struct psp_gfx_cmd_resp *cmd,
-                               uint64_t asd_mc, uint32_t size)
+static int psp_asd_load(struct psp_context *psp)
 {
-       cmd->cmd_id = GFX_CMD_ID_LOAD_ASD;
-       cmd->cmd.cmd_load_ta.app_phy_addr_lo = lower_32_bits(asd_mc);
-       cmd->cmd.cmd_load_ta.app_phy_addr_hi = upper_32_bits(asd_mc);
-       cmd->cmd.cmd_load_ta.app_len = size;
-
-       cmd->cmd.cmd_load_ta.cmd_buf_phy_addr_lo = 0;
-       cmd->cmd.cmd_load_ta.cmd_buf_phy_addr_hi = 0;
-       cmd->cmd.cmd_load_ta.cmd_buf_len = 0;
+       return psp_ta_load(psp, &psp->asd_context);
 }
 
-static int psp_asd_load(struct psp_context *psp)
+static int psp_asd_initialize(struct psp_context *psp)
 {
        int ret;
-       struct psp_gfx_cmd_resp *cmd;
 
        /* If PSP version doesn't match ASD version, asd loading will be failed.
         * add workaround to bypass it for sriov now.
        if (amdgpu_sriov_vf(psp->adev) || !psp->asd_context.bin_desc.size_bytes)
                return 0;
 
-       cmd = acquire_psp_cmd_buf(psp);
+       psp->asd_context.mem_context.shared_mc_addr  = 0;
+       psp->asd_context.mem_context.shared_mem_size = PSP_ASD_SHARED_MEM_SIZE;
+       psp->asd_context.ta_load_type                = GFX_CMD_ID_LOAD_ASD;
 
-       psp_copy_fw(psp, psp->asd_context.bin_desc.start_addr,
-                   psp->asd_context.bin_desc.size_bytes);
-
-       psp_prep_asd_load_cmd_buf(cmd, psp->fw_pri_mc_addr,
-                                 psp->asd_context.bin_desc.size_bytes);
-
-       ret = psp_cmd_submit_buf(psp, NULL, cmd,
-                                psp->fence_buf_mc_addr);
-       if (!ret) {
-               psp->asd_context.asd_initialized = true;
-               psp->asd_context.session_id = cmd->resp.session_id;
-       }
-
-       release_psp_cmd_buf(psp);
+       ret = psp_asd_load(psp);
+       if (!ret)
+               psp->asd_context.initialized = true;
 
        return ret;
 }
        if (amdgpu_sriov_vf(psp->adev))
                return 0;
 
-       if (!psp->asd_context.asd_initialized)
+       if (!psp->asd_context.initialized)
                return 0;
 
        ret = psp_asd_unload(psp);
 
        if (!ret)
-               psp->asd_context.asd_initialized = false;
+               psp->asd_context.initialized = false;
 
        return ret;
 }
                                     uint64_t ta_bin_mc,
                                     struct ta_context *context)
 {
-       cmd->cmd_id                             = GFX_CMD_ID_LOAD_TA;
+       cmd->cmd_id                             = context->ta_load_type;
        cmd->cmd.cmd_load_ta.app_phy_addr_lo    = lower_32_bits(ta_bin_mc);
        cmd->cmd.cmd_load_ta.app_phy_addr_hi    = upper_32_bits(ta_bin_mc);
        cmd->cmd.cmd_load_ta.app_len            = context->bin_desc.size_bytes;
        return ret;
 }
 
-static int psp_ta_load(struct psp_context *psp,
-                          struct ta_context *context)
+static int psp_ta_load(struct psp_context *psp, struct ta_context *context)
 {
        int ret;
        struct psp_gfx_cmd_resp *cmd;
        psp_copy_fw(psp, context->bin_desc.start_addr,
                    context->bin_desc.size_bytes);
 
-       psp_prep_ta_load_cmd_buf(cmd,
-                                psp->fw_pri_mc_addr,
-                                context);
+       psp_prep_ta_load_cmd_buf(cmd, psp->fw_pri_mc_addr, context);
 
        ret = psp_cmd_submit_buf(psp, NULL, cmd,
                                 psp->fence_buf_mc_addr);
                goto invoke;
 
        psp->xgmi_context.context.mem_context.shared_mem_size = PSP_XGMI_SHARED_MEM_SIZE;
+       psp->xgmi_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
        if (!psp->xgmi_context.context.initialized) {
                ret = psp_xgmi_init_shared_buf(psp);
        }
 
        psp->ras_context.context.mem_context.shared_mem_size = PSP_RAS_SHARED_MEM_SIZE;
+       psp->ras_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
        if (!psp->ras_context.context.initialized) {
                ret = psp_ras_init_shared_buf(psp);
        }
 
        psp->hdcp_context.context.mem_context.shared_mem_size = PSP_HDCP_SHARED_MEM_SIZE;
+       psp->hdcp_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
        if (!psp->hdcp_context.context.initialized) {
                ret = psp_hdcp_init_shared_buf(psp);
        }
 
        psp->dtm_context.context.mem_context.shared_mem_size = PSP_DTM_SHARED_MEM_SIZE;
+       psp->dtm_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
        if (!psp->dtm_context.context.initialized) {
                ret = psp_dtm_init_shared_buf(psp);
        }
 
        psp->rap_context.context.mem_context.shared_mem_size = PSP_RAP_SHARED_MEM_SIZE;
+       psp->rap_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
        if (!psp->rap_context.context.initialized) {
                ret = psp_rap_init_shared_buf(psp);
 
        psp->securedisplay_context.context.mem_context.shared_mem_size =
                PSP_SECUREDISPLAY_SHARED_MEM_SIZE;
+       psp->securedisplay_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
        if (!psp->securedisplay_context.context.initialized) {
                ret = psp_securedisplay_init_shared_buf(psp);
        if (ret)
                goto failed;
 
-       ret = psp_asd_load(psp);
+       ret = psp_asd_initialize(psp);
        if (ret) {
                DRM_ERROR("PSP load asd failed!\n");
                return ret;
        if (ret)
                goto failed;
 
-       ret = psp_asd_load(psp);
+       ret = psp_asd_initialize(psp);
        if (ret) {
                DRM_ERROR("PSP load asd failed!\n");
                goto failed;