memset(psp->fw_pri_buf, 0, PSP_1_MEG);
        memcpy(psp->fw_pri_buf, psp->ta_ras_start_addr, psp->ta_ras_ucode_size);
 
+       ras_cmd = (struct ta_ras_shared_memory *)psp->ras.ras_shared_buf;
+
+       if (psp->adev->gmc.xgmi.connected_to_cpu)
+               ras_cmd->ras_in_message.init_flags.poison_mode_en = 1;
+       else
+               ras_cmd->ras_in_message.init_flags.dgpu_mode = 1;
+
        psp_prep_ta_load_cmd_buf(cmd,
                                 psp->fw_pri_mc_addr,
                                 psp->ta_ras_ucode_size,
        ret = psp_cmd_submit_buf(psp, NULL, cmd,
                        psp->fence_buf_mc_addr);
 
-       ras_cmd = (struct ta_ras_shared_memory *)psp->ras.ras_shared_buf;
-
        if (!ret) {
                psp->ras.session_id = cmd->resp.session_id;
 
 
        uint64_t                value;                  // method if error injection. i.e persistent, coherent etc.
 };
 
+struct ta_ras_init_flags
+{
+    uint8_t     poison_mode_en;
+    uint8_t     dgpu_mode;
+};
+
 struct ta_ras_output_flags
 {
        uint8_t    ras_init_success_flag;
 /* Common input structure for RAS callbacks */
 /**********************************************************/
 union ta_ras_cmd_input {
+       struct ta_ras_init_flags                init_flags;
        struct ta_ras_enable_features_input     enable_features;
        struct ta_ras_disable_features_input    disable_features;
        struct ta_ras_trigger_error_input       trigger_error;