return NULL;
 }
 
+static void hmm_free_rcu(struct rcu_head *rcu)
+{
+       kfree(container_of(rcu, struct hmm, rcu));
+}
+
 static void hmm_free(struct kref *kref)
 {
        struct hmm *hmm = container_of(kref, struct hmm, kref);
                mm->hmm = NULL;
        spin_unlock(&mm->page_table_lock);
 
-       kfree(hmm);
+       mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
 }
 
 static inline void hmm_put(struct hmm *hmm)
 
 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-       struct hmm *hmm = mm_get_hmm(mm);
+       struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
        struct hmm_mirror *mirror;
        struct hmm_range *range;
 
+       /* Bail out if hmm is in the process of being freed */
+       if (!kref_get_unless_zero(&hmm->kref))
+               return;
+
        /* Report this HMM as dying. */
        hmm->dead = true;
 
 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
                        const struct mmu_notifier_range *nrange)
 {
-       struct hmm *hmm = mm_get_hmm(nrange->mm);
+       struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
        struct hmm_mirror *mirror;
        struct hmm_update update;
        struct hmm_range *range;
        int ret = 0;
 
-       VM_BUG_ON(!hmm);
+       if (!kref_get_unless_zero(&hmm->kref))
+               return 0;
 
        update.start = nrange->start;
        update.end = nrange->end;
 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
                        const struct mmu_notifier_range *nrange)
 {
-       struct hmm *hmm = mm_get_hmm(nrange->mm);
+       struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
 
-       VM_BUG_ON(!hmm);
+       if (!kref_get_unless_zero(&hmm->kref))
+               return;
 
        mutex_lock(&hmm->lock);
        hmm->notifiers--;