/*
  * Register whitelists, sorted by increasing register offset.
+ */
+
+/*
+ * An individual whitelist entry granting access to register addr.  If
+ * mask is non-zero the argument of immediate register writes will be
+ * AND-ed with mask, and the command will be rejected if the result
+ * doesn't match value.
+ *
+ * Registers with non-zero mask are only allowed to be written using
+ * LRI.
+ */
+struct drm_i915_reg_descriptor {
+       u32 addr;
+       u32 mask;
+       u32 value;
+};
+
+/* Convenience macro for adding 32-bit registers. */
+#define REG32(address, ...)                             \
+       { .addr = address, __VA_ARGS__ }
+
+/*
+ * Convenience macro for adding 64-bit registers.
  *
  * Some registers that userspace accesses are 64 bits. The register
  * access commands only allow 32-bit accesses. Hence, we have to include
  * entries for both halves of the 64-bit registers.
  */
+#define REG64(addr)                                     \
+       REG32(addr), REG32(addr + sizeof(u32))
 
-/* Convenience macro for adding 64-bit registers */
-#define REG64(addr) (addr), (addr + sizeof(u32))
-
-static const u32 gen7_render_regs[] = {
+static const struct drm_i915_reg_descriptor gen7_render_regs[] = {
        REG64(GPGPU_THREADS_DISPATCHED),
        REG64(HS_INVOCATION_COUNT),
        REG64(DS_INVOCATION_COUNT),
        REG64(CL_PRIMITIVES_COUNT),
        REG64(PS_INVOCATION_COUNT),
        REG64(PS_DEPTH_COUNT),
-       OACONTROL, /* Only allowed for LRI and SRM. See below. */
+       REG32(OACONTROL), /* Only allowed for LRI and SRM. See below. */
        REG64(MI_PREDICATE_SRC0),
        REG64(MI_PREDICATE_SRC1),
-       GEN7_3DPRIM_END_OFFSET,
-       GEN7_3DPRIM_START_VERTEX,
-       GEN7_3DPRIM_VERTEX_COUNT,
-       GEN7_3DPRIM_INSTANCE_COUNT,
-       GEN7_3DPRIM_START_INSTANCE,
-       GEN7_3DPRIM_BASE_VERTEX,
+       REG32(GEN7_3DPRIM_END_OFFSET),
+       REG32(GEN7_3DPRIM_START_VERTEX),
+       REG32(GEN7_3DPRIM_VERTEX_COUNT),
+       REG32(GEN7_3DPRIM_INSTANCE_COUNT),
+       REG32(GEN7_3DPRIM_START_INSTANCE),
+       REG32(GEN7_3DPRIM_BASE_VERTEX),
        REG64(GEN7_SO_NUM_PRIMS_WRITTEN(0)),
        REG64(GEN7_SO_NUM_PRIMS_WRITTEN(1)),
        REG64(GEN7_SO_NUM_PRIMS_WRITTEN(2)),
        REG64(GEN7_SO_PRIM_STORAGE_NEEDED(1)),
        REG64(GEN7_SO_PRIM_STORAGE_NEEDED(2)),
        REG64(GEN7_SO_PRIM_STORAGE_NEEDED(3)),
-       GEN7_SO_WRITE_OFFSET(0),
-       GEN7_SO_WRITE_OFFSET(1),
-       GEN7_SO_WRITE_OFFSET(2),
-       GEN7_SO_WRITE_OFFSET(3),
-       GEN7_L3SQCREG1,
-       GEN7_L3CNTLREG2,
-       GEN7_L3CNTLREG3,
+       REG32(GEN7_SO_WRITE_OFFSET(0)),
+       REG32(GEN7_SO_WRITE_OFFSET(1)),
+       REG32(GEN7_SO_WRITE_OFFSET(2)),
+       REG32(GEN7_SO_WRITE_OFFSET(3)),
+       REG32(GEN7_L3SQCREG1),
+       REG32(GEN7_L3CNTLREG2),
+       REG32(GEN7_L3CNTLREG3),
 };
 
-static const u32 gen7_blt_regs[] = {
-       BCS_SWCTRL,
+static const struct drm_i915_reg_descriptor gen7_blt_regs[] = {
+       REG32(BCS_SWCTRL),
 };
 
-static const u32 ivb_master_regs[] = {
-       FORCEWAKE_MT,
-       DERRMR,
-       GEN7_PIPE_DE_LOAD_SL(PIPE_A),
-       GEN7_PIPE_DE_LOAD_SL(PIPE_B),
-       GEN7_PIPE_DE_LOAD_SL(PIPE_C),
+static const struct drm_i915_reg_descriptor ivb_master_regs[] = {
+       REG32(FORCEWAKE_MT),
+       REG32(DERRMR),
+       REG32(GEN7_PIPE_DE_LOAD_SL(PIPE_A)),
+       REG32(GEN7_PIPE_DE_LOAD_SL(PIPE_B)),
+       REG32(GEN7_PIPE_DE_LOAD_SL(PIPE_C)),
 };
 
-static const u32 hsw_master_regs[] = {
-       FORCEWAKE_MT,
-       DERRMR,
+static const struct drm_i915_reg_descriptor hsw_master_regs[] = {
+       REG32(FORCEWAKE_MT),
+       REG32(DERRMR),
 };
 
 #undef REG64
+#undef REG32
 
 static u32 gen7_render_get_cmd_length_mask(u32 cmd_header)
 {
        return ret;
 }
 
-static bool check_sorted(int ring_id, const u32 *reg_table, int reg_count)
+static bool check_sorted(int ring_id,
+                        const struct drm_i915_reg_descriptor *reg_table,
+                        int reg_count)
 {
        int i;
        u32 previous = 0;
        bool ret = true;
 
        for (i = 0; i < reg_count; i++) {
-               u32 curr = reg_table[i];
+               u32 curr = reg_table[i].addr;
 
                if (curr < previous) {
                        DRM_ERROR("CMD: table not sorted ring=%d entry=%d reg=0x%08X prev=0x%08X\n",
        return default_desc;
 }
 
-static bool valid_reg(const u32 *table, int count, u32 addr)
+static const struct drm_i915_reg_descriptor *
+find_reg(const struct drm_i915_reg_descriptor *table,
+        int count, u32 addr)
 {
-       if (table && count != 0) {
+       if (table) {
                int i;
 
                for (i = 0; i < count; i++) {
-                       if (table[i] == addr)
-                               return true;
+                       if (table[i].addr == addr)
+                               return &table[i];
                }
        }
 
-       return false;
+       return NULL;
 }
 
 static u32 *vmap_batch(struct drm_i915_gem_object *obj,
                for (offset = desc->reg.offset; offset < length;
                     offset += step) {
                        const u32 reg_addr = cmd[offset] & desc->reg.mask;
+                       const struct drm_i915_reg_descriptor *reg =
+                               find_reg(ring->reg_table, ring->reg_count,
+                                        reg_addr);
+
+                       if (!reg && is_master)
+                               reg = find_reg(ring->master_reg_table,
+                                              ring->master_reg_count,
+                                              reg_addr);
+
+                       if (!reg) {
+                               DRM_DEBUG_DRIVER("CMD: Rejected register 0x%08X in command: 0x%08X (ring=%d)\n",
+                                                reg_addr, *cmd, ring->id);
+                               return false;
+                       }
 
                        /*
                         * OACONTROL requires some special handling for
                                        *oacontrol_set = (cmd[offset + 1] != 0);
                        }
 
-                       if (!valid_reg(ring->reg_table,
-                                      ring->reg_count, reg_addr)) {
-                               if (!is_master ||
-                                   !valid_reg(ring->master_reg_table,
-                                              ring->master_reg_count,
-                                              reg_addr)) {
-                                       DRM_DEBUG_DRIVER("CMD: Rejected register 0x%08X in command: 0x%08X (ring=%d)\n",
-                                                        reg_addr, *cmd,
-                                                        ring->id);
+                       /*
+                        * Check the value written to the register against the
+                        * allowed mask/value pair given in the whitelist entry.
+                        */
+                       if (reg->mask) {
+                               if (desc->cmd.value == MI_LOAD_REGISTER_MEM) {
+                                       DRM_DEBUG_DRIVER("CMD: Rejected LRM to masked register 0x%08X\n",
+                                                        reg_addr);
+                                       return false;
+                               }
+
+                               if (desc->cmd.value == MI_LOAD_REGISTER_IMM(1) &&
+                                   (offset + 2 > length ||
+                                    (cmd[offset + 1] & reg->mask) != reg->value)) {
+                                       DRM_DEBUG_DRIVER("CMD: Rejected LRI to masked register 0x%08X\n",
+                                                        reg_addr);
                                        return false;
                                }
                        }