wm[7] = (val >> GEN9_MEM_LATENCY_LEVEL_3_7_SHIFT) &
                                GEN9_MEM_LATENCY_LEVEL_MASK;
 
+               /*
+                * If a level n (n > 1) has a 0us latency, all levels m (m >= n)
+                * need to be disabled. We make sure to sanitize the values out
+                * of the punit to satisfy this requirement.
+                */
+               for (level = 1; level <= max_level; level++) {
+                       if (wm[level] == 0) {
+                               for (i = level + 1; i <= max_level; i++)
+                                       wm[i] = 0;
+                               break;
+                       }
+               }
+
                /*
                 * WaWmMemoryReadLatency:skl
                 *
                 * punit doesn't take into account the read latency so we need
-                * to add 2us to the various latency levels we retrieve from
-                * the punit.
-                *   - W0 is a bit special in that it's the only level that
-                *   can't be disabled if we want to have display working, so
-                *   we always add 2us there.
-                *   - For levels >=1, punit returns 0us latency when they are
-                *   disabled, so we respect that and don't add 2us then
-                *
-                * Additionally, if a level n (n > 1) has a 0us latency, all
-                * levels m (m >= n) need to be disabled. We make sure to
-                * sanitize the values out of the punit to satisfy this
-                * requirement.
+                * to add 2us to the various latency levels we retrieve from the
+                * punit when level 0 response data us 0us.
                 */
-               wm[0] += 2;
-               for (level = 1; level <= max_level; level++)
-                       if (wm[level] != 0)
+               if (wm[0] == 0) {
+                       wm[0] += 2;
+                       for (level = 1; level <= max_level; level++) {
+                               if (wm[level] == 0)
+                                       break;
                                wm[level] += 2;
-                       else {
-                               for (i = level + 1; i <= max_level; i++)
-                                       wm[i] = 0;
-
-                               break;
                        }
+               }
+
        } else if (IS_HASWELL(dev) || IS_BROADWELL(dev)) {
                uint64_t sskpd = I915_READ64(MCH_SSKPD);