intel_atomic_get_new_crtc_state(state, crtc);
        enum pipe pipe = crtc->pipe;
 
+       intel_psr_pre_plane_update(state, crtc);
+
        if (hsw_pre_update_disable_ips(old_crtc_state, new_crtc_state))
                hsw_disable_ips(old_crtc_state);
 
                intel_encoders_update_prepare(state);
 
        intel_dbuf_pre_plane_update(state);
-       intel_psr_pre_plane_update(state);
 
        for_each_new_intel_crtc_in_state(state, crtc, new_crtc_state, i) {
                if (new_crtc_state->uapi.async_flip)
 
        return 0;
 }
 
-static void _intel_psr_pre_plane_update(const struct intel_atomic_state *state,
-                                       const struct intel_crtc_state *crtc_state)
+void intel_psr_pre_plane_update(struct intel_atomic_state *state,
+                               struct intel_crtc *crtc)
 {
+       struct drm_i915_private *i915 = to_i915(state->base.dev);
+       const struct intel_crtc_state *crtc_state =
+               intel_atomic_get_new_crtc_state(state, crtc);
        struct intel_encoder *encoder;
 
+       if (!HAS_PSR(i915))
+               return;
+
        for_each_intel_encoder_mask_with_psr(state->base.dev, encoder,
                                             crtc_state->uapi.encoder_mask) {
                struct intel_dp *intel_dp = enc_to_intel_dp(encoder);
                 * - All planes will go inactive
                 * - Changing between PSR versions
                 */
+               needs_to_disable |= intel_crtc_needs_modeset(crtc_state);
                needs_to_disable |= !crtc_state->has_psr;
                needs_to_disable |= !crtc_state->active_planes;
                needs_to_disable |= crtc_state->has_psr2 != psr->psr2_enabled;
        }
 }
 
-void intel_psr_pre_plane_update(const struct intel_atomic_state *state)
-{
-       struct drm_i915_private *dev_priv = to_i915(state->base.dev);
-       struct intel_crtc_state *crtc_state;
-       struct intel_crtc *crtc;
-       int i;
-
-       if (!HAS_PSR(dev_priv))
-               return;
-
-       for_each_new_intel_crtc_in_state(state, crtc, crtc_state, i)
-               _intel_psr_pre_plane_update(state, crtc_state);
-}
-
 static void _intel_psr_post_plane_update(const struct intel_atomic_state *state,
                                         const struct intel_crtc_state *crtc_state)
 {
 
 struct intel_encoder;
 
 void intel_psr_init_dpcd(struct intel_dp *intel_dp);
-void intel_psr_pre_plane_update(const struct intel_atomic_state *state);
+void intel_psr_pre_plane_update(struct intel_atomic_state *state,
+                               struct intel_crtc *crtc);
 void intel_psr_post_plane_update(const struct intel_atomic_state *state);
 void intel_psr_disable(struct intel_dp *intel_dp,
                       const struct intel_crtc_state *old_crtc_state);