if (to->remap_slice)
                return false;
 
-       if (to->ppgtt) {
-               if (from == to && !test_bit(ring->id,
-                               &to->ppgtt->pd_dirty_rings))
-                       return true;
-       }
+       if (to->ppgtt && from == to &&
+           !(intel_ring_flag(ring) & to->ppgtt->pd_dirty_rings))
+               return true;
 
        return false;
 }
                        goto unpin_out;
 
                /* Doing a PD load always reloads the page dirs */
-               clear_bit(ring->id, &to->ppgtt->pd_dirty_rings);
+               to->ppgtt->pd_dirty_rings &= ~intel_ring_flag(ring);
        }
 
        if (ring != &dev_priv->ring[RCS]) {
                 * space. This means we must enforce that a page table load
                 * occur when this occurs. */
        } else if (to->ppgtt &&
-                       test_and_clear_bit(ring->id, &to->ppgtt->pd_dirty_rings))
+                  (intel_ring_flag(ring) & to->ppgtt->pd_dirty_rings)) {
                hw_flags |= MI_FORCE_RESTORE;
+               to->ppgtt->pd_dirty_rings &= ~intel_ring_flag(ring);
+       }
 
        /* We should never emit switch_mm more than once */
        WARN_ON(needs_pd_load_pre(ring, to) &&
-                       needs_pd_load_post(ring, to, hw_flags));
+               needs_pd_load_post(ring, to, hw_flags));
 
        ret = mi_set_context(ring, to, hw_flags);
        if (ret)
 
        if (ret)
                goto error;
 
-       if (ctx->ppgtt)
-               WARN(ctx->ppgtt->pd_dirty_rings & (1<<ring->id),
-                       "%s didn't clear reload\n", ring->name);
+       WARN(ctx->ppgtt && ctx->ppgtt->pd_dirty_rings & (1<<ring->id),
+            "%s didn't clear reload\n", ring->name);
 
        instp_mode = args->flags & I915_EXEC_CONSTANTS_MASK;
        instp_mask = I915_EXEC_CONSTANTS_MASK;