#define REG(reg_name) \
        (CLK_BASE.instance[0].segment[mm ## reg_name ## _BASE_IDX] + mm ## reg_name)
 
+
+/* TODO: evaluate how to lower or disable all dcn clocks in screen off case */
+int rn_get_active_display_cnt_wa(
+               struct dc *dc,
+               struct dc_state *context)
+{
+       int i, display_count;
+       bool hdmi_present = false;
+
+       display_count = 0;
+       for (i = 0; i < context->stream_count; i++) {
+               const struct dc_stream_state *stream = context->streams[i];
+
+               if (stream->signal == SIGNAL_TYPE_HDMI_TYPE_A)
+                       hdmi_present = true;
+       }
+
+       for (i = 0; i < dc->link_count; i++) {
+               const struct dc_link *link = dc->links[i];
+
+               /*
+                * Only notify active stream or virtual stream.
+                * Need to notify virtual stream to work around
+                * headless case. HPD does not fire when system is in
+                * S0i2.
+                */
+               /* abusing the fact that the dig and phy are coupled to see if the phy is enabled */
+               if (link->link_enc->funcs->is_dig_enabled(link->link_enc))
+                       display_count++;
+       }
+
+       /* WA for hang on HDMI after display off back back on*/
+       if (display_count == 0 && hdmi_present)
+               display_count = 1;
+
+       return display_count;
+}
+
 void rn_update_clocks(struct clk_mgr *clk_mgr_base,
                        struct dc_state *context,
                        bool safe_to_lower)
        int display_count;
        bool update_dppclk = false;
        bool update_dispclk = false;
-       bool enter_display_off = false;
        bool dpp_clock_lowered = false;
-       struct dmcu *dmcu = clk_mgr_base->ctx->dc->res_pool->dmcu;
 
-       display_count = clk_mgr_helper_get_active_display_cnt(dc, context);
+       struct dmcu *dmcu = clk_mgr_base->ctx->dc->res_pool->dmcu;
 
-       if (display_count == 0)
-               enter_display_off = true;
+       if (dc->work_arounds.skip_clock_update)
+               return;
 
-       if (enter_display_off == safe_to_lower) {
-               rn_vbios_smu_set_display_count(clk_mgr, display_count);
+       /*
+        * if it is safe to lower, but we are already in the lower state, we don't have to do anything
+        * also if safe to lower is false, we just go in the higher state
+        */
+       if (safe_to_lower) {
+               /* check that we're not already in lower */
+               if (clk_mgr_base->clks.pwr_state != DCN_PWR_STATE_OPTIMIZED) {
+
+                       display_count = rn_get_active_display_cnt_wa(dc, context);
+                       /* if we can go lower, go lower */
+                       if (display_count == 0) {
+                               rn_vbios_smu_set_dcn_low_power_state(clk_mgr, DCN_PWR_STATE_OPTIMIZED);
+                               /* update power state */
+                               clk_mgr_base->clks.pwr_state = DCN_PWR_STATE_OPTIMIZED;
+                       }
+               }
+       } else {
+               /* check that we're not already in the normal state */
+               if (clk_mgr_base->clks.pwr_state != DCN_PWR_STATE_NORMAL) {
+                       rn_vbios_smu_set_dcn_low_power_state(clk_mgr, DCN_PWR_STATE_NORMAL);
+                       /* update power state */
+                       clk_mgr_base->clks.pwr_state = DCN_PWR_STATE_NORMAL;
+               }
        }
 
        if (should_set_clock(safe_to_lower, new_clocks->phyclk_khz, clk_mgr_base->clks.phyclk_khz)) {
        rn_vbios_smu_enable_pme_wa(clk_mgr);
 }
 
+void rn_init_clocks(struct clk_mgr *clk_mgr)
+{
+       memset(&(clk_mgr->clks), 0, sizeof(struct dc_clocks));
+       // Assumption is that boot state always supports pstate
+       clk_mgr->clks.p_state_change_support = true;
+       clk_mgr->clks.prev_p_state_change_support = true;
+       clk_mgr->clks.pwr_state = DCN_PWR_STATE_NORMAL;
+}
+
 static struct clk_mgr_funcs dcn21_funcs = {
        .get_dp_ref_clk_frequency = dce12_get_dp_ref_freq_khz,
        .update_clocks = rn_update_clocks,
-       .init_clocks = dcn2_init_clocks,
+       .init_clocks = rn_init_clocks,
        .enable_pme_wa = rn_enable_pme_wa,
        /* .dump_clk_registers = rn_dump_clk_registers */
 };
 
  * This does not create remote sinks but will trigger DM
  * to start MST detection if a branch is detected.
  */
-bool dc_link_detect(struct dc_link *link, enum dc_detect_reason reason)
+bool dc_link_detect_helper(struct dc_link *link, enum dc_detect_reason reason)
 {
        struct dc_sink_init_data sink_init_data = { 0 };
        struct display_sink_capability sink_caps = { 0 };
        bool same_dpcd = true;
        enum dc_connection_type new_connection_type = dc_connection_none;
        bool perform_dp_seamless_boot = false;
+
        DC_LOGGER_INIT(link->ctx->logger);
 
        if (dc_is_virtual_signal(link->connector_signal))
                dc_sink_release(prev_sink);
 
        return true;
+
+}
+
+bool dc_link_detect(struct dc_link *link, enum dc_detect_reason reason)
+{
+       const struct dc *dc = link->dc;
+       bool ret;
+       /* get out of low power state */
+
+       if (dc->hwss.exit_optimized_pwr_state)
+               dc->hwss.exit_optimized_pwr_state(dc, dc->current_state);
+
+       ret = dc_link_detect_helper(link, reason);
+
+       if (dc->hwss.optimize_pwr_state)
+               dc->hwss.optimize_pwr_state(dc, dc->current_state);
+
+       return ret;
 }
 
 bool dc_link_get_hpd_state(struct dc_link *dc_link)