#define npu_to_phb(x) container_of(x, struct pnv_phb, npu)
 
+/*
+ * spinlock to protect initialisation of an npu_context for a particular
+ * mm_struct.
+ */
+static DEFINE_SPINLOCK(npu_context_lock);
+
 /*
  * Other types of TCE cache invalidation are not functional in the
  * hardware.
  * Returns an error if there no contexts are currently available or a
  * npu_context which should be passed to pnv_npu2_handle_fault().
  *
- * mmap_sem must be held in write mode.
+ * mmap_sem must be held in write mode and must not be called from interrupt
+ * context.
  */
 struct npu_context *pnv_npu2_init_context(struct pci_dev *gpdev,
                        unsigned long flags,
        /*
         * Setup the NPU context table for a particular GPU. These need to be
         * per-GPU as we need the tables to filter ATSDs when there are no
-        * active contexts on a particular GPU.
+        * active contexts on a particular GPU. It is safe for these to be
+        * called concurrently with destroy as the OPAL call takes appropriate
+        * locks and refcounts on init/destroy.
         */
        rc = opal_npu_init_context(nphb->opal_id, mm->context.id, flags,
                                PCI_DEVID(gpdev->bus->number, gpdev->devfn));
         * We store the npu pci device so we can more easily get at the
         * associated npus.
         */
+       spin_lock(&npu_context_lock);
        npu_context = mm->context.npu_context;
+       if (npu_context)
+               WARN_ON(!kref_get_unless_zero(&npu_context->kref));
+       spin_unlock(&npu_context_lock);
+
        if (!npu_context) {
+               /*
+                * We can set up these fields without holding the
+                * npu_context_lock as the npu_context hasn't been returned to
+                * the caller meaning it can't be destroyed. Parallel allocation
+                * is protected against by mmap_sem.
+                */
                rc = -ENOMEM;
                npu_context = kzalloc(sizeof(struct npu_context), GFP_KERNEL);
                if (npu_context) {
                }
 
                mm->context.npu_context = npu_context;
-       } else {
-               WARN_ON(!kref_get_unless_zero(&npu_context->kref));
        }
 
        npu_context->release_cb = cb;
                mm_context_remove_copro(npu_context->mm);
 
        npu_context->mm->context.npu_context = NULL;
-       mmu_notifier_unregister(&npu_context->mn,
-                               npu_context->mm);
-
-       kfree(npu_context);
 }
 
+/*
+ * Destroy a context on the given GPU. May free the npu_context if it is no
+ * longer active on any GPUs. Must not be called from interrupt context.
+ */
 void pnv_npu2_destroy_context(struct npu_context *npu_context,
                        struct pci_dev *gpdev)
 {
+       int removed;
        struct pnv_phb *nphb;
        struct npu *npu;
        struct pci_dev *npdev = pnv_pci_get_npu_dev(gpdev, 0);
        WRITE_ONCE(npu_context->npdev[npu->index][nvlink_index], NULL);
        opal_npu_destroy_context(nphb->opal_id, npu_context->mm->context.id,
                                PCI_DEVID(gpdev->bus->number, gpdev->devfn));
-       kref_put(&npu_context->kref, pnv_npu2_release_context);
+       spin_lock(&npu_context_lock);
+       removed = kref_put(&npu_context->kref, pnv_npu2_release_context);
+       spin_unlock(&npu_context_lock);
+
+       /*
+        * We need to do this outside of pnv_npu2_release_context so that it is
+        * outside the spinlock as mmu_notifier_destroy uses SRCU.
+        */
+       if (removed) {
+               mmu_notifier_unregister(&npu_context->mn,
+                                       npu_context->mm);
+
+               kfree(npu_context);
+       }
+
 }
 EXPORT_SYMBOL(pnv_npu2_destroy_context);