#include <linux/etherdevice.h>
 #include <linux/filter.h>
 #include <linux/interrupt.h>
+#include <linux/irq.h>
 #include <linux/module.h>
 #include <linux/pci.h>
 #include <linux/sched.h>
        return IRQ_HANDLED;
 }
 
+static int gve_is_napi_on_home_cpu(struct gve_priv *priv, u32 irq)
+{
+       int cpu_curr = smp_processor_id();
+       const struct cpumask *aff_mask;
+
+       aff_mask = irq_get_effective_affinity_mask(irq);
+       if (unlikely(!aff_mask))
+               return 1;
+
+       return cpumask_test_cpu(cpu_curr, aff_mask);
+}
+
 int gve_napi_poll(struct napi_struct *napi, int budget)
 {
        struct gve_notify_block *block;
                reschedule |= work_done == budget;
        }
 
-       if (reschedule)
-               return budget;
+       if (reschedule) {
+               /* Reschedule by returning budget only if already on the correct
+                * cpu.
+                */
+               if (likely(gve_is_napi_on_home_cpu(priv, block->irq)))
+                       return budget;
+
+               /* If not on the cpu with which this queue's irq has affinity
+                * with, we avoid rescheduling napi and arm the irq instead so
+                * that napi gets rescheduled back eventually onto the right
+                * cpu.
+                */
+               if (work_done == budget)
+                       work_done--;
+       }
 
        if (likely(napi_complete_done(napi, work_done))) {
                /* Enable interrupts again.
                                "Failed to receive msix vector %d\n", i);
                        goto abort_with_some_ntfy_blocks;
                }
+               block->irq = priv->msix_vectors[msix_idx].vector;
                irq_set_affinity_hint(priv->msix_vectors[msix_idx].vector,
                                      get_cpu_mask(i % active_cpus));
                block->irq_db_index = &priv->irq_db_indices[i].index;
                irq_set_affinity_hint(priv->msix_vectors[msix_idx].vector,
                                      NULL);
                free_irq(priv->msix_vectors[msix_idx].vector, block);
+               block->irq = 0;
        }
        kvfree(priv->ntfy_blocks);
        priv->ntfy_blocks = NULL;
                irq_set_affinity_hint(priv->msix_vectors[msix_idx].vector,
                                      NULL);
                free_irq(priv->msix_vectors[msix_idx].vector, block);
+               block->irq = 0;
        }
        free_irq(priv->msix_vectors[priv->mgmt_msix_idx].vector, priv);
        kvfree(priv->ntfy_blocks);