return 0;
 }
 
+static int vlv_num_wm_levels(struct drm_i915_private *dev_priv)
+{
+       return dev_priv->wm.max_level + 1;
+}
+
+/* mark all levels starting from 'level' as invalid */
+static void vlv_invalidate_wms(struct intel_crtc *crtc,
+                              struct vlv_wm_state *wm_state, int level)
+{
+       struct drm_i915_private *dev_priv = to_i915(crtc->base.dev);
+
+       for (; level < vlv_num_wm_levels(dev_priv); level++) {
+               enum plane_id plane_id;
+
+               for_each_plane_id_on_crtc(crtc, plane_id)
+                       wm_state->wm[level].plane[plane_id] = USHRT_MAX;
+
+               wm_state->sr[level].cursor = USHRT_MAX;
+               wm_state->sr[level].plane = USHRT_MAX;
+       }
+}
+
 static u16 vlv_invert_wm_value(u16 wm, u16 fifo_size)
 {
        if (wm > fifo_size)
                return fifo_size - wm;
 }
 
-static void vlv_invert_wms(struct intel_crtc_state *crtc_state)
+/*
+ * Starting from 'level' set all higher
+ * levels to 'value' in the "raw" watermarks.
+ */
+static void vlv_raw_plane_wm_set(struct intel_crtc_state *crtc_state,
+                                int level, enum plane_id plane_id, u16 value)
 {
-       struct intel_crtc *crtc = to_intel_crtc(crtc_state->base.crtc);
-       struct vlv_wm_state *wm_state = &crtc_state->wm.vlv.optimal;
-       const struct vlv_fifo_state *fifo_state =
-               &crtc_state->wm.vlv.fifo_state;
-       int level;
-
-       for (level = 0; level < wm_state->num_levels; level++) {
-               struct drm_i915_private *dev_priv = to_i915(crtc->base.dev);
-               const int sr_fifo_size =
-                       INTEL_INFO(dev_priv)->num_pipes * 512 - 1;
-               enum plane_id plane_id;
+       struct drm_i915_private *dev_priv = to_i915(crtc_state->base.crtc->dev);
+       int num_levels = vlv_num_wm_levels(dev_priv);
 
-               wm_state->sr[level].plane =
-                       vlv_invert_wm_value(wm_state->sr[level].plane,
-                                           sr_fifo_size);
-               wm_state->sr[level].cursor =
-                       vlv_invert_wm_value(wm_state->sr[level].cursor,
-                                           63);
+       for (; level < num_levels; level++) {
+               struct vlv_pipe_wm *raw = &crtc_state->wm.vlv.raw[level];
 
-               for_each_plane_id_on_crtc(crtc, plane_id) {
-                       wm_state->wm[level].plane[plane_id] =
-                               vlv_invert_wm_value(wm_state->wm[level].plane[plane_id],
-                                                   fifo_state->plane[plane_id]);
-               }
+               raw->plane[plane_id] = value;
        }
 }
 
-static void vlv_compute_wm(struct intel_crtc_state *crtc_state)
+static void vlv_plane_wm_compute(struct intel_crtc_state *crtc_state,
+                                const struct intel_plane_state *plane_state)
 {
-       struct intel_crtc *crtc = to_intel_crtc(crtc_state->base.crtc);
-       struct drm_i915_private *dev_priv = to_i915(crtc->base.dev);
-       struct vlv_wm_state *wm_state = &crtc_state->wm.vlv.optimal;
-       struct intel_plane *plane;
+       struct intel_plane *plane = to_intel_plane(plane_state->base.plane);
+       enum plane_id plane_id = plane->id;
+       int num_levels = vlv_num_wm_levels(to_i915(plane->base.dev));
        int level;
 
-       memset(wm_state, 0, sizeof(*wm_state));
-       memset(&crtc_state->wm.vlv.raw, 0, sizeof(crtc_state->wm.vlv.raw));
+       if (!plane_state->base.visible) {
+               vlv_raw_plane_wm_set(crtc_state, 0, plane_id, 0);
+               return;
+       }
 
-       wm_state->cxsr = crtc->pipe != PIPE_C && crtc->wm.cxsr_allowed;
-       wm_state->num_levels = dev_priv->wm.max_level + 1;
+       for (level = 0; level < num_levels; level++) {
+               struct vlv_pipe_wm *raw = &crtc_state->wm.vlv.raw[level];
+               int wm = vlv_compute_wm_level(crtc_state, plane_state, level);
+               int max_wm = plane_id == PLANE_CURSOR ? 63 : 511;
 
-       wm_state->num_active_planes = 0;
+               /* FIXME just bail */
+               if (WARN_ON(level == 0 && wm > max_wm))
+                       wm = max_wm;
 
-       if (wm_state->num_active_planes != 1)
-               wm_state->cxsr = false;
+               if (wm > max_wm)
+                       break;
 
-       for_each_intel_plane_on_crtc(&dev_priv->drm, crtc, plane) {
-               struct intel_plane_state *state =
-                       to_intel_plane_state(plane->base.state);
+               raw->plane[plane_id] = wm;
+       }
 
-               if (!state->base.visible)
-                       continue;
+       /* mark all higher levels as invalid */
+       vlv_raw_plane_wm_set(crtc_state, level, plane_id, USHRT_MAX);
 
-               for (level = 0; level < wm_state->num_levels; level++) {
-                       struct vlv_pipe_wm *raw = &crtc_state->wm.vlv.raw[level];
-                       int wm = vlv_compute_wm_level(crtc_state, state, level);
-                       int max_wm = plane->id == PLANE_CURSOR ? 63 : 511;
+       DRM_DEBUG_KMS("%s wms: [0]=%d,[1]=%d,[2]=%d\n",
+                     plane->base.name,
+                     crtc_state->wm.vlv.raw[VLV_WM_LEVEL_PM2].plane[plane_id],
+                     crtc_state->wm.vlv.raw[VLV_WM_LEVEL_PM5].plane[plane_id],
+                     crtc_state->wm.vlv.raw[VLV_WM_LEVEL_DDR_DVFS].plane[plane_id]);
+}
 
-                       /* hack */
-                       if (WARN_ON(level == 0 && wm > max_wm))
-                               wm = max_wm;
+static bool vlv_plane_wm_is_valid(const struct intel_crtc_state *crtc_state,
+                                 enum plane_id plane_id, int level)
+{
+       const struct vlv_pipe_wm *raw =
+               &crtc_state->wm.vlv.raw[level];
+       const struct vlv_fifo_state *fifo_state =
+               &crtc_state->wm.vlv.fifo_state;
 
-                       if (wm > max_wm)
-                               break;
+       return raw->plane[plane_id] <= fifo_state->plane[plane_id];
+}
 
-                       raw->plane[plane->id] = wm;
-               }
+static bool vlv_crtc_wm_is_valid(const struct intel_crtc_state *crtc_state, int level)
+{
+       return vlv_plane_wm_is_valid(crtc_state, PLANE_PRIMARY, level) &&
+               vlv_plane_wm_is_valid(crtc_state, PLANE_SPRITE0, level) &&
+               vlv_plane_wm_is_valid(crtc_state, PLANE_SPRITE1, level) &&
+               vlv_plane_wm_is_valid(crtc_state, PLANE_CURSOR, level);
+}
+
+static int vlv_compute_pipe_wm(struct intel_crtc_state *crtc_state)
+{
+       struct intel_crtc *crtc = to_intel_crtc(crtc_state->base.crtc);
+       struct drm_i915_private *dev_priv = to_i915(crtc->base.dev);
+       struct intel_atomic_state *state =
+               to_intel_atomic_state(crtc_state->base.state);
+       struct vlv_wm_state *wm_state = &crtc_state->wm.vlv.optimal;
+       const struct vlv_fifo_state *fifo_state =
+               &crtc_state->wm.vlv.fifo_state;
+       int num_active_planes = hweight32(crtc_state->active_planes &
+                                         ~BIT(PLANE_CURSOR));
+       struct intel_plane_state *plane_state;
+       struct intel_plane *plane;
+       enum plane_id plane_id;
+       int level, ret, i;
+
+       for_each_intel_plane_in_state(state, plane, plane_state, i) {
+               const struct intel_plane_state *old_plane_state =
+                       to_intel_plane_state(plane->base.state);
+
+               if (plane_state->base.crtc != &crtc->base &&
+                   old_plane_state->base.crtc != &crtc->base)
+                       continue;
 
-               wm_state->num_levels = level;
+               vlv_plane_wm_compute(crtc_state, plane_state);
        }
 
-       vlv_compute_fifo(crtc_state);
+       /* initially allow all levels */
+       wm_state->num_levels = vlv_num_wm_levels(dev_priv);
+       /*
+        * Note that enabling cxsr with no primary/sprite planes
+        * enabled can wedge the pipe. Hence we only allow cxsr
+        * with exactly one enabled primary/sprite plane.
+        */
+       wm_state->cxsr = crtc->pipe != PIPE_C &&
+               crtc->wm.cxsr_allowed && num_active_planes == 1;
+
+       ret = vlv_compute_fifo(crtc_state);
+       if (ret)
+               return ret;
 
        for (level = 0; level < wm_state->num_levels; level++) {
-               struct vlv_pipe_wm *raw = &crtc_state->wm.vlv.raw[level];
+               const struct vlv_pipe_wm *raw = &crtc_state->wm.vlv.raw[level];
+               const int sr_fifo_size = INTEL_INFO(dev_priv)->num_pipes * 512 - 1;
 
-               wm_state->wm[level] = *raw;
+               if (!vlv_crtc_wm_is_valid(crtc_state, level))
+                       break;
 
-               wm_state->sr[level].plane = max3(raw->plane[PLANE_PRIMARY],
+               for_each_plane_id_on_crtc(crtc, plane_id) {
+                       wm_state->wm[level].plane[plane_id] =
+                               vlv_invert_wm_value(raw->plane[plane_id],
+                                                   fifo_state->plane[plane_id]);
+               }
+
+               wm_state->sr[level].plane =
+                       vlv_invert_wm_value(max3(raw->plane[PLANE_PRIMARY],
                                                 raw->plane[PLANE_SPRITE0],
-                                                raw->plane[PLANE_SPRITE1]);
-               wm_state->sr[level].cursor = raw->plane[PLANE_CURSOR];
-       }
+                                                raw->plane[PLANE_SPRITE1]),
+                                           sr_fifo_size);
 
-       /* clear any (partially) filled invalid levels */
-       for (level = wm_state->num_levels; level < dev_priv->wm.max_level + 1; level++) {
-               memset(&wm_state->wm[level], 0, sizeof(wm_state->wm[level]));
-               memset(&wm_state->sr[level], 0, sizeof(wm_state->sr[level]));
+               wm_state->sr[level].cursor =
+                       vlv_invert_wm_value(raw->plane[PLANE_CURSOR],
+                                           63);
        }
 
-       vlv_invert_wms(crtc_state);
+       if (level == 0)
+               return -EINVAL;
+
+       /* limit to only levels we can actually handle */
+       wm_state->num_levels = level;
+
+       /* invalidate the higher levels */
+       vlv_invalidate_wms(crtc, wm_state, level);
+
+       return 0;
 }
 
 #define VLV_FIFO(plane, value) \
        (((value) << DSPARB_ ## plane ## _SHIFT_VLV) & DSPARB_ ## plane ## _MASK_VLV)
 
-static void vlv_pipe_set_fifo_size(const struct intel_crtc_state *crtc_state)
+static void vlv_atomic_update_fifo(struct intel_atomic_state *state,
+                                  struct intel_crtc_state *crtc_state)
 {
        struct intel_crtc *crtc = to_intel_crtc(crtc_state->base.crtc);
        struct drm_i915_private *dev_priv = to_i915(crtc->base.dev);
        WARN_ON(fifo_state->plane[PLANE_CURSOR] != 63);
        WARN_ON(fifo_size != 511);
 
-       DRM_DEBUG_KMS("Pipe %c FIFO split %d / %d / %d\n",
-                     pipe_name(crtc->pipe), sprite0_start,
-                     sprite1_start, fifo_size);
-
        spin_lock(&dev_priv->wm.dsparb_lock);
 
        switch (crtc->pipe) {
                const struct vlv_wm_state *wm_state = &crtc->wm.active.vlv;
                enum pipe pipe = crtc->pipe;
 
-               if (!crtc->active)
-                       continue;
-
                wm->pipe[pipe] = wm_state->wm[wm->level];
-               if (wm->cxsr)
+               if (crtc->active && wm->cxsr)
                        wm->sr = wm_state->sr[wm->level];
 
                wm->ddl[pipe].plane[PLANE_PRIMARY] = DDL_PRECISION_HIGH | 2;
        return old < threshold && new >= threshold;
 }
 
-static void vlv_update_wm(struct intel_crtc *crtc)
+static void vlv_program_watermarks(struct drm_i915_private *dev_priv)
 {
-       struct drm_i915_private *dev_priv = to_i915(crtc->base.dev);
-       struct intel_crtc_state *crtc_state =
-               to_intel_crtc_state(crtc->base.state);
-       enum pipe pipe = crtc->pipe;
        struct vlv_wm_values *old_wm = &dev_priv->wm.vlv;
        struct vlv_wm_values new_wm = {};
 
-       vlv_compute_wm(crtc_state);
-       crtc->wm.active.vlv = crtc_state->wm.vlv.optimal;
        vlv_merge_wm(dev_priv, &new_wm);
 
-       if (memcmp(old_wm, &new_wm, sizeof(new_wm)) == 0) {
-               /* FIXME should be part of crtc atomic commit */
-               vlv_pipe_set_fifo_size(crtc_state);
+       if (memcmp(old_wm, &new_wm, sizeof(new_wm)) == 0)
                return;
-       }
 
        if (is_disabling(old_wm->level, new_wm.level, VLV_WM_LEVEL_DDR_DVFS))
                chv_set_memory_dvfs(dev_priv, false);
        if (is_disabling(old_wm->cxsr, new_wm.cxsr, true))
                _intel_set_memory_cxsr(dev_priv, false);
 
-       /* FIXME should be part of crtc atomic commit */
-       vlv_pipe_set_fifo_size(crtc_state);
-
        vlv_write_wm_values(dev_priv, &new_wm);
 
-       DRM_DEBUG_KMS("Setting FIFO watermarks - %c: plane=%d, cursor=%d, "
-                     "sprite0=%d, sprite1=%d, SR: plane=%d, cursor=%d level=%d cxsr=%d\n",
-                     pipe_name(pipe), new_wm.pipe[pipe].plane[PLANE_PRIMARY], new_wm.pipe[pipe].plane[PLANE_CURSOR],
-                     new_wm.pipe[pipe].plane[PLANE_SPRITE0], new_wm.pipe[pipe].plane[PLANE_SPRITE1],
-                     new_wm.sr.plane, new_wm.sr.cursor, new_wm.level, new_wm.cxsr);
-
        if (is_enabling(old_wm->cxsr, new_wm.cxsr, true))
                _intel_set_memory_cxsr(dev_priv, true);
 
        *old_wm = new_wm;
 }
 
+static void vlv_initial_watermarks(struct intel_atomic_state *state,
+                                  struct intel_crtc_state *crtc_state)
+{
+       struct drm_i915_private *dev_priv = to_i915(crtc_state->base.crtc->dev);
+       struct intel_crtc *crtc = to_intel_crtc(crtc_state->base.crtc);
+
+       mutex_lock(&dev_priv->wm.wm_mutex);
+       crtc->wm.active.vlv = crtc_state->wm.vlv.optimal;
+       vlv_program_watermarks(dev_priv);
+       mutex_unlock(&dev_priv->wm.wm_mutex);
+}
+
 #define single_plane_enabled(mask) is_power_of_2(mask)
 
 static void g4x_update_wm(struct intel_crtc *crtc)
        struct drm_i915_private *dev_priv = to_i915(dev);
        struct vlv_wm_values *wm = &dev_priv->wm.vlv;
        struct intel_crtc *crtc;
-       enum pipe pipe;
        u32 val;
 
        vlv_read_wm_values(dev_priv, wm);
 
-       for_each_intel_crtc(dev, crtc)
-               vlv_get_fifo_size(to_intel_crtc_state(crtc->base.state));
-
        wm->cxsr = I915_READ(FW_BLC_SELF_VLV) & FW_CSPWRDWNEN;
        wm->level = VLV_WM_LEVEL_PM2;
 
                mutex_unlock(&dev_priv->rps.hw_lock);
        }
 
-       for_each_pipe(dev_priv, pipe)
+       for_each_intel_crtc(dev, crtc) {
+               struct intel_crtc_state *crtc_state =
+                       to_intel_crtc_state(crtc->base.state);
+               struct vlv_wm_state *active = &crtc->wm.active.vlv;
+               const struct vlv_fifo_state *fifo_state =
+                       &crtc_state->wm.vlv.fifo_state;
+               enum pipe pipe = crtc->pipe;
+               enum plane_id plane_id;
+               int level;
+
+               vlv_get_fifo_size(crtc_state);
+
+               active->num_levels = wm->level + 1;
+               active->cxsr = wm->cxsr;
+
+               /* FIXME sanitize things more */
+               for (level = 0; level < active->num_levels; level++) {
+                       struct vlv_pipe_wm *raw =
+                               &crtc_state->wm.vlv.raw[level];
+
+                       active->sr[level].plane = wm->sr.plane;
+                       active->sr[level].cursor = wm->sr.cursor;
+
+                       for_each_plane_id_on_crtc(crtc, plane_id) {
+                               active->wm[level].plane[plane_id] =
+                                       wm->pipe[pipe].plane[plane_id];
+
+                               raw->plane[plane_id] =
+                                       vlv_invert_wm_value(active->wm[level].plane[plane_id],
+                                                           fifo_state->plane[plane_id]);
+                       }
+               }
+
+               for_each_plane_id_on_crtc(crtc, plane_id)
+                       vlv_raw_plane_wm_set(crtc_state, level,
+                                            plane_id, USHRT_MAX);
+               vlv_invalidate_wms(crtc, active, level);
+
+               crtc_state->wm.vlv.optimal = *active;
+
                DRM_DEBUG_KMS("Initial watermarks: pipe %c, plane=%d, cursor=%d, sprite0=%d, sprite1=%d\n",
                              pipe_name(pipe),
                              wm->pipe[pipe].plane[PLANE_PRIMARY],
                              wm->pipe[pipe].plane[PLANE_CURSOR],
                              wm->pipe[pipe].plane[PLANE_SPRITE0],
                              wm->pipe[pipe].plane[PLANE_SPRITE1]);
+       }
 
        DRM_DEBUG_KMS("Initial watermarks: SR plane=%d, SR cursor=%d level=%d cxsr=%d\n",
                      wm->sr.plane, wm->sr.cursor, wm->level, wm->cxsr);
                }
        } else if (IS_VALLEYVIEW(dev_priv) || IS_CHERRYVIEW(dev_priv)) {
                vlv_setup_wm_latency(dev_priv);
-               dev_priv->display.update_wm = vlv_update_wm;
+               dev_priv->display.compute_pipe_wm = vlv_compute_pipe_wm;
+               dev_priv->display.initial_watermarks = vlv_initial_watermarks;
+               dev_priv->display.atomic_update_watermarks = vlv_atomic_update_fifo;
        } else if (IS_PINEVIEW(dev_priv)) {
                if (!intel_get_cxsr_latency(IS_PINEVIEW_G(dev_priv),
                                            dev_priv->is_ddr3,