*   3 * 4096 * 8192  * 4 < 2^32
  */
 static unsigned int
-skl_get_total_relative_data_rate(struct intel_crtc_state *cstate)
+skl_get_total_relative_data_rate(struct intel_crtc_state *intel_cstate)
 {
-       struct intel_crtc *intel_crtc = to_intel_crtc(cstate->base.crtc);
-       struct drm_device *dev = intel_crtc->base.dev;
+       struct drm_crtc_state *cstate = &intel_cstate->base;
+       struct drm_atomic_state *state = cstate->state;
+       struct drm_crtc *crtc = cstate->crtc;
+       struct drm_device *dev = crtc->dev;
+       struct intel_crtc *intel_crtc = to_intel_crtc(crtc);
        const struct intel_plane *intel_plane;
        unsigned int rate, total_data_rate = 0;
+       int id;
 
        /* Calculate and cache data rate for each plane */
-       for_each_intel_plane_on_crtc(dev, intel_crtc, intel_plane) {
-               const struct drm_plane_state *pstate = intel_plane->base.state;
-               int id = skl_wm_plane_id(intel_plane);
+       /*
+        * FIXME: At the moment this function can be called on either an
+        * in-flight or a committed state object.  If it's in-flight then we
+        * only want to re-calculate the plane data rate for planes that are
+        * part of the transaction (i.e., we don't want to grab any additional
+        * plane states if we don't have to).  If we're operating on committed
+        * state, we'll just go ahead and recalculate the plane data rate for
+        * all planes.
+        *
+        * Once we finish moving our DDB allocation to the atomic check phase,
+        * we'll only be calling this function on in-flight state objects, so
+        * the 'else' branch here will go away.
+        */
+       if (state) {
+               struct drm_plane *plane;
+               struct drm_plane_state *pstate;
+               int i;
+
+               for_each_plane_in_state(state, plane, pstate, i) {
+                       intel_plane = to_intel_plane(plane);
+                       id = skl_wm_plane_id(intel_plane);
+
+                       if (intel_plane->pipe != intel_crtc->pipe)
+                               continue;
+
+                       /* packed/uv */
+                       rate = skl_plane_relative_data_rate(intel_cstate,
+                                                           pstate, 0);
+                       intel_cstate->wm.skl.plane_data_rate[id] = rate;
+
+                       /* y-plane */
+                       rate = skl_plane_relative_data_rate(intel_cstate,
+                                                           pstate, 1);
+                       intel_cstate->wm.skl.plane_y_data_rate[id] = rate;
+               }
+       } else {
+               for_each_intel_plane_on_crtc(dev, intel_crtc, intel_plane) {
+                       const struct drm_plane_state *pstate =
+                               intel_plane->base.state;
+                       int id = skl_wm_plane_id(intel_plane);
 
-               /* packed/uv */
-               rate = skl_plane_relative_data_rate(cstate, pstate, 0);
-               cstate->wm.skl.plane_data_rate[id] = rate;
+                       /* packed/uv */
+                       rate = skl_plane_relative_data_rate(intel_cstate,
+                                                           pstate, 0);
+                       intel_cstate->wm.skl.plane_data_rate[id] = rate;
 
-               /* y-plane */
-               rate = skl_plane_relative_data_rate(cstate, pstate, 1);
-               cstate->wm.skl.plane_y_data_rate[id] = rate;
+                       /* y-plane */
+                       rate = skl_plane_relative_data_rate(intel_cstate,
+                                                           pstate, 1);
+                       intel_cstate->wm.skl.plane_y_data_rate[id] = rate;
+               }
        }
 
        /* Calculate CRTC's total data rate from cached values */
                int id = skl_wm_plane_id(intel_plane);
 
                /* packed/uv */
-               total_data_rate += cstate->wm.skl.plane_data_rate[id];
-               total_data_rate += cstate->wm.skl.plane_y_data_rate[id];
+               total_data_rate += intel_cstate->wm.skl.plane_data_rate[id];
+               total_data_rate += intel_cstate->wm.skl.plane_y_data_rate[id];
        }
 
+       WARN_ON(cstate->plane_mask && total_data_rate == 0);
+
        return total_data_rate;
 }