}
 }
 
-static void skl_disable_async_flip_wa(struct intel_atomic_state *state,
-                                     struct intel_crtc *crtc,
-                                     const struct intel_crtc_state *new_crtc_state)
+static void intel_crtc_async_flip_disable_wa(struct intel_atomic_state *state,
+                                            struct intel_crtc *crtc)
 {
-       struct drm_i915_private *dev_priv = to_i915(state->base.dev);
+       struct drm_i915_private *i915 = to_i915(state->base.dev);
+       const struct intel_crtc_state *old_crtc_state =
+               intel_atomic_get_old_crtc_state(state, crtc);
+       const struct intel_crtc_state *new_crtc_state =
+               intel_atomic_get_new_crtc_state(state, crtc);
+       u8 update_planes = new_crtc_state->update_planes;
+       const struct intel_plane_state *old_plane_state;
        struct intel_plane *plane;
-       struct intel_plane_state *new_plane_state;
+       bool need_vbl_wait = false;
        int i;
 
-       for_each_new_intel_plane_in_state(state, plane, new_plane_state, i) {
-               u32 update_mask = new_crtc_state->update_planes;
-               u32 plane_ctl, surf_addr;
-               enum plane_id plane_id;
-               unsigned long irqflags;
-               enum pipe pipe;
-
-               if (crtc->pipe != plane->pipe ||
-                   !(update_mask & BIT(plane->id)))
-                       continue;
-
-               plane_id = plane->id;
-               pipe = plane->pipe;
-
-               spin_lock_irqsave(&dev_priv->uncore.lock, irqflags);
-               plane_ctl = intel_de_read_fw(dev_priv, PLANE_CTL(pipe, plane_id));
-               surf_addr = intel_de_read_fw(dev_priv, PLANE_SURF(pipe, plane_id));
-
-               plane_ctl &= ~PLANE_CTL_ASYNC_FLIP;
-
-               intel_de_write_fw(dev_priv, PLANE_CTL(pipe, plane_id), plane_ctl);
-               intel_de_write_fw(dev_priv, PLANE_SURF(pipe, plane_id), surf_addr);
-               spin_unlock_irqrestore(&dev_priv->uncore.lock, irqflags);
+       for_each_old_intel_plane_in_state(state, plane, old_plane_state, i) {
+               if (plane->need_async_flip_disable_wa &&
+                   plane->pipe == crtc->pipe &&
+                   update_planes & BIT(plane->id)) {
+                       /*
+                        * Apart from the async flip bit we want to
+                        * preserve the old state for the plane.
+                        */
+                       plane->async_flip(plane, old_crtc_state,
+                                         old_plane_state, false);
+                       need_vbl_wait = true;
+               }
        }
 
-       intel_wait_for_vblank(dev_priv, crtc->pipe);
+       if (need_vbl_wait)
+               intel_wait_for_vblank(i915, crtc->pipe);
 }
 
 static void intel_pre_plane_update(struct intel_atomic_state *state,
         * WA for platforms where async address update enable bit
         * is double buffered and only latched at start of vblank.
         */
-       if (old_crtc_state->uapi.async_flip &&
-           !new_crtc_state->uapi.async_flip &&
-           IS_GEN_RANGE(dev_priv, 9, 10))
-               skl_disable_async_flip_wa(state, crtc, new_crtc_state);
+       if (old_crtc_state->uapi.async_flip && !new_crtc_state->uapi.async_flip)
+               intel_crtc_async_flip_disable_wa(state, crtc);
 }
 
 static void intel_crtc_disable_planes(struct intel_atomic_state *state,
 
 static void
 skl_plane_async_flip(struct intel_plane *plane,
                     const struct intel_crtc_state *crtc_state,
-                    const struct intel_plane_state *plane_state)
+                    const struct intel_plane_state *plane_state,
+                    bool async_flip)
 {
        struct drm_i915_private *dev_priv = to_i915(plane->base.dev);
        unsigned long irqflags;
 
        plane_ctl |= skl_plane_ctl_crtc(crtc_state);
 
-       plane_ctl |= PLANE_CTL_ASYNC_FLIP;
+       if (async_flip)
+               plane_ctl |= PLANE_CTL_ASYNC_FLIP;
 
        spin_lock_irqsave(&dev_priv->uncore.lock, irqflags);
 
        plane->min_cdclk = skl_plane_min_cdclk;
 
        if (plane_id == PLANE_PRIMARY) {
+               plane->need_async_flip_disable_wa = IS_GEN_RANGE(dev_priv, 9, 10);
                plane->async_flip = skl_plane_async_flip;
                plane->enable_flip_done = skl_plane_enable_flip_done;
                plane->disable_flip_done = skl_plane_disable_flip_done;