        int rc;
 
        /*
-        * We need to release a reference on the mm whenever exiting this
+        * We must release a reference on mm_users whenever exiting this
         * function (taken in the memory fault interrupt handler)
         */
        rc = copro_handle_mm_fault(fault->pe_data.mm, fault->dar, fault->dsisr,
        }
        r = RESTART;
 ack:
-       mmdrop(fault->pe_data.mm);
+       mmput(fault->pe_data.mm);
        ack_irq(spa, r);
 }
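
For context, the distinction the patch relies on: an mm_count reference (mmgrab()/mmdrop()) only keeps the struct mm_struct allocated, while an mm_users reference (mmget_not_zero()/mmput()) keeps the address space itself alive, which is what a fault handler running later from a workqueue needs. Below is a minimal sketch of that difference, using only the standard helpers from linux/sched/mm.h; the surrounding function is hypothetical, not driver code:

#include <linux/sched/mm.h>

/* Hypothetical illustration, not part of the ocxl driver. */
static void mm_refcount_sketch(struct mm_struct *mm)
{
	/*
	 * mm_count (mmgrab/mmdrop) only pins the struct itself; once the
	 * last user exits, exit_mmap() may already have torn down the
	 * page tables and VMAs.
	 */
	mmgrab(mm);
	/* ... safe to look at the struct, not to fault on it ... */
	mmdrop(mm);

	/*
	 * mm_users (mmget_not_zero/mmput) keeps the address space alive,
	 * so page faults can still be resolved against it. The "not zero"
	 * variant fails once the owning task has already exited.
	 */
	if (mmget_not_zero(mm)) {
		/* ... copro_handle_mm_fault()/handle_mm_fault() are safe ... */
		mmput(mm);
	}
}
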
 
        struct pe_data *pe_data;
        struct ocxl_process_element *pe;
        int lpid, pid, tid;
+       bool schedule = false;
 
        read_irq(spa, &dsisr, &dar, &pe_handle);
        trace_ocxl_fault(spa->spa_mem, pe_handle, dsisr, dar, -1);
        }
        WARN_ON(pe_data->mm->context.id != pid);
 
-       spa->xsl_fault.pe = pe_handle;
-       spa->xsl_fault.dar = dar;
-       spa->xsl_fault.dsisr = dsisr;
-       spa->xsl_fault.pe_data = *pe_data;
-       mmgrab(pe_data->mm); /* mm count is released by bottom half */
-
+       if (mmget_not_zero(pe_data->mm)) {
+               spa->xsl_fault.pe = pe_handle;
+               spa->xsl_fault.dar = dar;
+               spa->xsl_fault.dsisr = dsisr;
+               spa->xsl_fault.pe_data = *pe_data;
+               schedule = true;
+               /* mm_users count released by bottom half */
+       }
        rcu_read_unlock();
-       schedule_work(&spa->xsl_fault.fault_work);
+       if (schedule)
+               schedule_work(&spa->xsl_fault.fault_work);
+       else
+               ack_irq(spa, ADDRESS_ERROR);
        return IRQ_HANDLED;
 }
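
Taken in isolation, the interrupt-handler side is a take-or-bail pattern: pin mm_users before handing the fault to the bottom half, and answer with an address error if the pin fails because the process is already exiting. A minimal sketch of that pattern under the same assumptions; my_fault_ctx, my_fault_irq() and my_ack_error() are hypothetical names, not ocxl symbols:

#include <linux/interrupt.h>
#include <linux/rcupdate.h>
#include <linux/sched/mm.h>
#include <linux/workqueue.h>

/* Hypothetical fault context, standing in for the driver's spa/xsl_fault. */
struct my_fault_ctx {
	struct mm_struct *mm;
	struct work_struct fault_work;
};

static void my_ack_error(struct my_fault_ctx *ctx)
{
	/* tell the hardware the translation fault cannot be resolved */
}

static irqreturn_t my_fault_irq(int irq, void *data)
{
	struct my_fault_ctx *ctx = data;
	bool scheduled = false;

	/*
	 * In the driver, the RCU read side protects the pe_data lookup;
	 * it is kept here only to mirror the structure of the patch.
	 */
	rcu_read_lock();
	if (mmget_not_zero(ctx->mm)) {
		/* reference is dropped by the bottom half via mmput() */
		scheduled = true;
	}
	rcu_read_unlock();

	if (scheduled)
		schedule_work(&ctx->fault_work);
	else
		my_ack_error(ctx);	/* process already exited */
	return IRQ_HANDLED;
}
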