}
 
 /* MMIO ATSD register offsets */
-#define XTS_ATSD_AVA  1
-#define XTS_ATSD_STAT 2
+#define XTS_ATSD_LAUNCH 0
+#define XTS_ATSD_AVA    1
+#define XTS_ATSD_STAT   2
 
-static void mmio_launch_invalidate(struct mmio_atsd_reg *mmio_atsd_reg,
-                               unsigned long launch, unsigned long va)
+static unsigned long get_atsd_launch_val(unsigned long pid, unsigned long psize,
+                                       bool flush)
 {
-       struct npu *npu = mmio_atsd_reg->npu;
-       int reg = mmio_atsd_reg->reg;
+       unsigned long launch = 0;
 
-       __raw_writeq_be(va, npu->mmio_atsd_regs[reg] + XTS_ATSD_AVA);
-       eieio();
-       __raw_writeq_be(launch, npu->mmio_atsd_regs[reg]);
+       if (psize == MMU_PAGE_COUNT) {
+               /* IS set to invalidate entire matching PID */
+               launch |= PPC_BIT(12);
+       } else {
+               /* AP set to invalidate region of psize */
+               launch |= (u64)mmu_get_ap(psize) << PPC_BITLSHIFT(17);
+       }
+
+       /* PRS set to process-scoped */
+       launch |= PPC_BIT(13);
+
+       /* PID */
+       launch |= pid << PPC_BITLSHIFT(38);
+
+       /* No flush */
+       launch |= !flush << PPC_BITLSHIFT(39);
+
+       return launch;
 }
 
-static void mmio_invalidate_pid(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
-                               unsigned long pid, bool flush)
+static void mmio_atsd_regs_write(struct mmio_atsd_reg
+                       mmio_atsd_reg[NV_MAX_NPUS], unsigned long offset,
+                       unsigned long val)
 {
-       int i;
-       unsigned long launch;
+       struct npu *npu;
+       int i, reg;
 
        for (i = 0; i <= max_npu2_index; i++) {
-               if (mmio_atsd_reg[i].reg < 0)
+               reg = mmio_atsd_reg[i].reg;
+               if (reg < 0)
                        continue;
 
-               /* IS set to invalidate matching PID */
-               launch = PPC_BIT(12);
-
-               /* PRS set to process-scoped */
-               launch |= PPC_BIT(13);
-
-               /* AP */
-               launch |= (u64)
-                       mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
-
-               /* PID */
-               launch |= pid << PPC_BITLSHIFT(38);
+               npu = mmio_atsd_reg[i].npu;
+               __raw_writeq_be(val, npu->mmio_atsd_regs[reg] + offset);
+       }
+}
 
-               /* No flush */
-               launch |= !flush << PPC_BITLSHIFT(39);
+static void mmio_invalidate_pid(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
+                               unsigned long pid, bool flush)
+{
+       unsigned long launch = get_atsd_launch_val(pid, MMU_PAGE_COUNT, flush);
 
-               /* Invalidating the entire process doesn't use a va */
-               mmio_launch_invalidate(&mmio_atsd_reg[i], launch, 0);
-       }
+       /* Invalidating the entire process doesn't use a va */
+       mmio_atsd_regs_write(mmio_atsd_reg, XTS_ATSD_LAUNCH, launch);
 }
 
 static void mmio_invalidate_va(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
                        unsigned long va, unsigned long pid, bool flush)
 {
-       int i;
        unsigned long launch;
 
-       for (i = 0; i <= max_npu2_index; i++) {
-               if (mmio_atsd_reg[i].reg < 0)
-                       continue;
-
-               /* IS set to invalidate target VA */
-               launch = 0;
+       launch = get_atsd_launch_val(pid, mmu_virtual_psize, flush);
 
-               /* PRS set to process scoped */
-               launch |= PPC_BIT(13);
+       /* Write all VAs first */
+       mmio_atsd_regs_write(mmio_atsd_reg, XTS_ATSD_AVA, va);
 
-               /* AP */
-               launch |= (u64)
-                       mmu_get_ap(mmu_virtual_psize) << PPC_BITLSHIFT(17);
-
-               /* PID */
-               launch |= pid << PPC_BITLSHIFT(38);
-
-               /* No flush */
-               launch |= !flush << PPC_BITLSHIFT(39);
+       /* Issue one barrier for all address writes */
+       eieio();
 
-               mmio_launch_invalidate(&mmio_atsd_reg[i], launch, va);
-       }
+       /* Launch */
+       mmio_atsd_regs_write(mmio_atsd_reg, XTS_ATSD_LAUNCH, launch);
 }
 
 #define mn_to_npu_context(x) container_of(x, struct npu_context, mn)