#include <linux/memblock.h>
 #include <linux/iommu.h>
 #include <linux/debugfs.h>
+#include <linux/sizes.h>
 
 #include <asm/debugfs.h>
 #include <asm/tlb.h>
 #define XTS_ATSD_AVA    1
 #define XTS_ATSD_STAT   2
 
-static unsigned long get_atsd_launch_val(unsigned long pid, unsigned long psize,
-                                       bool flush)
+static unsigned long get_atsd_launch_val(unsigned long pid, unsigned long psize)
 {
        unsigned long launch = 0;
 
        /* PID */
        launch |= pid << PPC_BITLSHIFT(38);
 
-       /* No flush */
-       launch |= !flush << PPC_BITLSHIFT(39);
+       /* Leave "No flush" (bit 39) 0 so every ATSD performs a flush */
 
        return launch;
 }
 }
 
 static void mmio_invalidate_pid(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
-                               unsigned long pid, bool flush)
+                               unsigned long pid)
 {
-       unsigned long launch = get_atsd_launch_val(pid, MMU_PAGE_COUNT, flush);
+       unsigned long launch = get_atsd_launch_val(pid, MMU_PAGE_COUNT);
 
        /* Invalidating the entire process doesn't use a va */
        mmio_atsd_regs_write(mmio_atsd_reg, XTS_ATSD_LAUNCH, launch);
 }
 
-static void mmio_invalidate_va(struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS],
-                       unsigned long va, unsigned long pid, bool flush)
+static void mmio_invalidate_range(struct mmio_atsd_reg
+                       mmio_atsd_reg[NV_MAX_NPUS], unsigned long pid,
+                       unsigned long start, unsigned long psize)
 {
-       unsigned long launch;
-
-       launch = get_atsd_launch_val(pid, mmu_virtual_psize, flush);
+       unsigned long launch = get_atsd_launch_val(pid, psize);
 
        /* Write all VAs first */
-       mmio_atsd_regs_write(mmio_atsd_reg, XTS_ATSD_AVA, va);
+       mmio_atsd_regs_write(mmio_atsd_reg, XTS_ATSD_AVA, start);
 
        /* Issue one barrier for all address writes */
        eieio();
 }
 
 /*
- * Invalidate either a single address or an entire PID depending on
- * the value of va.
+ * Invalidate a virtual address range
  */
-static void mmio_invalidate(struct npu_context *npu_context, int va,
-                       unsigned long address, bool flush)
+static void mmio_invalidate(struct npu_context *npu_context,
+                       unsigned long start, unsigned long size)
 {
        struct mmio_atsd_reg mmio_atsd_reg[NV_MAX_NPUS];
        unsigned long pid = npu_context->mm->context.id;
+       unsigned long atsd_start = 0;
+       unsigned long end = start + size - 1;
+       int atsd_psize = MMU_PAGE_COUNT;
+
+       /*
+        * Convert the input range into one of the supported sizes. If the range
+        * doesn't fit, use the next larger supported size. Invalidation latency
+        * is high, so over-invalidation is preferred to issuing multiple
+        * invalidates.
+        *
+        * A 4K page size isn't supported by NPU/GPU ATS, so that case is
+        * ignored.
+        */
+       if (size == SZ_64K) {
+               atsd_start = start;
+               atsd_psize = MMU_PAGE_64K;
+       } else if (ALIGN_DOWN(start, SZ_2M) == ALIGN_DOWN(end, SZ_2M)) {
+               atsd_start = ALIGN_DOWN(start, SZ_2M);
+               atsd_psize = MMU_PAGE_2M;
+       } else if (ALIGN_DOWN(start, SZ_1G) == ALIGN_DOWN(end, SZ_1G)) {
+               atsd_start = ALIGN_DOWN(start, SZ_1G);
+               atsd_psize = MMU_PAGE_1G;
+       }
 
        if (npu_context->nmmu_flush)
                /*
         * an invalidate.
         */
        acquire_atsd_reg(npu_context, mmio_atsd_reg);
-       if (va)
-               mmio_invalidate_va(mmio_atsd_reg, address, pid, flush);
+
+       if (atsd_psize == MMU_PAGE_COUNT)
+               mmio_invalidate_pid(mmio_atsd_reg, pid);
        else
-               mmio_invalidate_pid(mmio_atsd_reg, pid, flush);
+               mmio_invalidate_range(mmio_atsd_reg, pid, atsd_start,
+                                       atsd_psize);
 
        mmio_invalidate_wait(mmio_atsd_reg);
-       if (flush) {
-               /*
-                * The GPU requires two flush ATSDs to ensure all entries have
-                * been flushed. We use PID 0 as it will never be used for a
-                * process on the GPU.
-                */
-               mmio_invalidate_pid(mmio_atsd_reg, 0, true);
-               mmio_invalidate_wait(mmio_atsd_reg);
-               mmio_invalidate_pid(mmio_atsd_reg, 0, true);
-               mmio_invalidate_wait(mmio_atsd_reg);
-       }
+
+       /*
+        * The GPU requires two flush ATSDs to ensure all entries have been
+        * flushed. We use PID 0 as it will never be used for a process on the
+        * GPU.
+        */
+       mmio_invalidate_pid(mmio_atsd_reg, 0);
+       mmio_invalidate_wait(mmio_atsd_reg);
+       mmio_invalidate_pid(mmio_atsd_reg, 0);
+       mmio_invalidate_wait(mmio_atsd_reg);
+
        release_atsd_reg(mmio_atsd_reg);
 }
 
         * There should be no more translation requests for this PID, but we
         * need to ensure any entries for it are removed from the TLB.
         */
-       mmio_invalidate(npu_context, 0, 0, true);
+       mmio_invalidate(npu_context, 0, ~0UL);
 }
 
 static void pnv_npu2_mn_change_pte(struct mmu_notifier *mn,
                                pte_t pte)
 {
        struct npu_context *npu_context = mn_to_npu_context(mn);
-
-       mmio_invalidate(npu_context, 1, address, true);
+       mmio_invalidate(npu_context, address, PAGE_SIZE);
 }
 
 static void pnv_npu2_mn_invalidate_range(struct mmu_notifier *mn,
                                        unsigned long start, unsigned long end)
 {
        struct npu_context *npu_context = mn_to_npu_context(mn);
-       unsigned long address;
-
-       if (end - start > atsd_threshold) {
-               /*
-                * Just invalidate the entire PID if the address range is too
-                * large.
-                */
-               mmio_invalidate(npu_context, 0, 0, true);
-       } else {
-               for (address = start; address < end; address += PAGE_SIZE)
-                       mmio_invalidate(npu_context, 1, address, false);
-
-               /* Do the flush only on the final addess == end */
-               mmio_invalidate(npu_context, 1, address, true);
-       }
+       mmio_invalidate(npu_context, start, end - start);
 }
 
 static const struct mmu_notifier_ops nv_nmmu_notifier_ops = {