#include <generated/xe_wa_oob.h>
 
+#include "instructions/xe_alu_commands.h"
 #include "instructions/xe_gfxpipe_commands.h"
 #include "instructions/xe_mi_commands.h"
+#include "regs/xe_engine_regs.h"
 #include "regs/xe_gt_regs.h"
 #include "xe_assert.h"
 #include "xe_bb.h"
        return 0;
 }
 
-/*
- * Convert back from encoded value to type-safe, only to be used when reg.mcr
- * is true
- */
-static struct xe_reg_mcr to_xe_reg_mcr(const struct xe_reg reg)
-{
-       return (const struct xe_reg_mcr){.__reg.raw = reg.raw };
-}
-
 static int emit_wa_job(struct xe_gt *gt, struct xe_exec_queue *q)
 {
        struct xe_reg_sr *sr = &q->hwe->reg_lrc;
        struct xe_bb *bb;
        struct dma_fence *fence;
        long timeout;
+       int count_rmw = 0;
        int count = 0;
 
        if (q->hwe->class == XE_ENGINE_CLASS_RENDER)
        if (IS_ERR(bb))
                return PTR_ERR(bb);
 
-       xa_for_each(&sr->xa, idx, entry)
-               ++count;
+       /* count RMW registers as those will be handled separately */
+       xa_for_each(&sr->xa, idx, entry) {
+               if (entry->reg.masked || entry->clr_bits == ~0)
+                       ++count;
+               else
+                       ++count_rmw;
+       }
 
-       if (count) {
+       if (count || count_rmw)
                xe_gt_dbg(gt, "LRC WA %s save-restore batch\n", sr->name);
 
+       if (count) {
+               /* emit single LRI with all non RMW regs */
+
                bb->cs[bb->len++] = MI_LOAD_REGISTER_IMM | MI_LRI_NUM_REGS(count);
 
                xa_for_each(&sr->xa, idx, entry) {
                        struct xe_reg reg = entry->reg;
-                       struct xe_reg_mcr reg_mcr = to_xe_reg_mcr(reg);
                        u32 val;
 
-                       /*
-                        * Skip reading the register if it's not really needed
-                        */
                        if (reg.masked)
                                val = entry->clr_bits << 16;
-                       else if (entry->clr_bits + 1)
-                               val = (reg.mcr ?
-                                      xe_gt_mcr_unicast_read_any(gt, reg_mcr) :
-                                      xe_mmio_read32(&gt->mmio, reg)) & (~entry->clr_bits);
-                       else
+                       else if (entry->clr_bits == ~0)
                                val = 0;
+                       else
+                               continue;
 
                        val |= entry->set_bits;
 
                }
        }
 
+       if (count_rmw) {
+               /* emit MI_MATH for each RMW reg */
+
+               xa_for_each(&sr->xa, idx, entry) {
+                       if (entry->reg.masked || entry->clr_bits == ~0)
+                               continue;
+
+                       bb->cs[bb->len++] = MI_LOAD_REGISTER_REG | MI_LRR_DST_CS_MMIO;
+                       bb->cs[bb->len++] = entry->reg.addr;
+                       bb->cs[bb->len++] = CS_GPR_REG(0, 0).addr;
+
+                       bb->cs[bb->len++] = MI_LOAD_REGISTER_IMM | MI_LRI_NUM_REGS(2) |
+                                           MI_LRI_LRM_CS_MMIO;
+                       bb->cs[bb->len++] = CS_GPR_REG(0, 1).addr;
+                       bb->cs[bb->len++] = entry->clr_bits;
+                       bb->cs[bb->len++] = CS_GPR_REG(0, 2).addr;
+                       bb->cs[bb->len++] = entry->set_bits;
+
+                       bb->cs[bb->len++] = MI_MATH(8);
+                       bb->cs[bb->len++] = CS_ALU_INSTR_LOAD(SRCA, REG0);
+                       bb->cs[bb->len++] = CS_ALU_INSTR_LOADINV(SRCB, REG1);
+                       bb->cs[bb->len++] = CS_ALU_INSTR_AND;
+                       bb->cs[bb->len++] = CS_ALU_INSTR_STORE(REG0, ACCU);
+                       bb->cs[bb->len++] = CS_ALU_INSTR_LOAD(SRCA, REG0);
+                       bb->cs[bb->len++] = CS_ALU_INSTR_LOAD(SRCB, REG2);
+                       bb->cs[bb->len++] = CS_ALU_INSTR_OR;
+                       bb->cs[bb->len++] = CS_ALU_INSTR_STORE(REG0, ACCU);
+
+                       bb->cs[bb->len++] = MI_LOAD_REGISTER_REG | MI_LRR_SRC_CS_MMIO;
+                       bb->cs[bb->len++] = CS_GPR_REG(0, 0).addr;
+                       bb->cs[bb->len++] = entry->reg.addr;
+
+                       xe_gt_dbg(gt, "REG[%#x] = ~%#x|%#x\n",
+                                 entry->reg.addr, entry->clr_bits, entry->set_bits);
+               }
+
+               /* reset used GPR */
+               bb->cs[bb->len++] = MI_LOAD_REGISTER_IMM | MI_LRI_NUM_REGS(3) | MI_LRI_LRM_CS_MMIO;
+               bb->cs[bb->len++] = CS_GPR_REG(0, 0).addr;
+               bb->cs[bb->len++] = 0;
+               bb->cs[bb->len++] = CS_GPR_REG(0, 1).addr;
+               bb->cs[bb->len++] = 0;
+               bb->cs[bb->len++] = CS_GPR_REG(0, 2).addr;
+               bb->cs[bb->len++] = 0;
+       }
+
        xe_lrc_emit_hwe_state_instructions(q, bb);
 
        job = xe_bb_create_job(q, bb);