ret = snd_sof_dsp_read_poll_timeout(sdev, HDA_DSP_BAR, MTL_DSP2CXCTL_PRIMARY_CORE, dspcxctl,
                                            (dspcxctl & cpa) == cpa, HDA_DSP_REG_POLL_INTERVAL_US,
                                            HDA_DSP_RESET_TIMEOUT_US);
-       if (ret < 0)
+       if (ret < 0) {
                dev_err(sdev->dev, "%s: timeout on MTL_DSP2CXCTL_PRIMARY_CORE read\n",
                        __func__);
+               return ret;
+       }
 
-       return ret;
+       /* set primary core mask and refcount to 1 */
+       sdev->enabled_cores_mask = BIT(SOF_DSP_PRIMARY_CORE);
+       sdev->dsp_core_ref_count[SOF_DSP_PRIMARY_CORE] = 1;
+
+       return 0;
 }
 
 static int mtl_dsp_core_power_down(struct snd_sof_dev *sdev, int core)
                                            !(dspcxctl & MTL_DSP2CXCTL_PRIMARY_CORE_CPA_MASK),
                                            HDA_DSP_REG_POLL_INTERVAL_US,
                                            HDA_DSP_PD_TIMEOUT * USEC_PER_MSEC);
-       if (ret < 0)
+       if (ret < 0) {
                dev_err(sdev->dev, "failed to power down primary core\n");
+               return ret;
+       }
 
-       return ret;
+       sdev->enabled_cores_mask = 0;
+       sdev->dsp_core_ref_count[SOF_DSP_PRIMARY_CORE] = 0;
+
+       return 0;
 }
 
 int mtl_power_down_dsp(struct snd_sof_dev *sdev)