struct mm_struct *,
                                            unsigned long, unsigned long);
 static void mmu_notifier_mem_invalidate(struct mmu_notifier *,
+                                       struct mm_struct *,
                                        unsigned long, unsigned long);
 static struct mmu_rb_node *__mmu_rb_search(struct mmu_rb_handler *,
                                           unsigned long, unsigned long);
                        rbnode = rb_entry(node, struct mmu_rb_node, node);
                        rb_erase(node, root);
                        if (handler->ops->remove)
-                               handler->ops->remove(root, rbnode, false);
+                               handler->ops->remove(root, rbnode, NULL);
                }
        }
 
 }
 
 static void __mmu_rb_remove(struct mmu_rb_handler *handler,
-                           struct mmu_rb_node *node, bool arg)
+                           struct mmu_rb_node *node, struct mm_struct *mm)
 {
        /* Validity of handler and node pointers has been checked by caller. */
        hfi1_cdbg(MMU, "Removing node addr 0x%llx, len %u", node->addr,
                  node->len);
        __mmu_int_rb_remove(node, handler->root);
        if (handler->ops->remove)
-               handler->ops->remove(handler->root, node, arg);
+               handler->ops->remove(handler->root, node, mm);
 }
 
 struct mmu_rb_node *hfi1_mmu_rb_search(struct rb_root *root, unsigned long addr,
                return;
 
        spin_lock_irqsave(&handler->lock, flags);
-       __mmu_rb_remove(handler, node, false);
+       __mmu_rb_remove(handler, node, NULL);
        spin_unlock_irqrestore(&handler->lock, flags);
 }
 
 static inline void mmu_notifier_page(struct mmu_notifier *mn,
                                     struct mm_struct *mm, unsigned long addr)
 {
-       mmu_notifier_mem_invalidate(mn, addr, addr + PAGE_SIZE);
+       mmu_notifier_mem_invalidate(mn, mm, addr, addr + PAGE_SIZE);
 }
 
 static inline void mmu_notifier_range_start(struct mmu_notifier *mn,
                                            unsigned long start,
                                            unsigned long end)
 {
-       mmu_notifier_mem_invalidate(mn, start, end);
+       mmu_notifier_mem_invalidate(mn, mm, start, end);
 }
 
 static void mmu_notifier_mem_invalidate(struct mmu_notifier *mn,
+                                       struct mm_struct *mm,
                                        unsigned long start, unsigned long end)
 {
        struct mmu_rb_handler *handler =
                container_of(mn, struct mmu_rb_handler, mn);
        struct rb_root *root = handler->root;
-       struct mmu_rb_node *node;
+       struct mmu_rb_node *node, *ptr = NULL;
        unsigned long flags;
 
        spin_lock_irqsave(&handler->lock, flags);
-       for (node = __mmu_int_rb_iter_first(root, start, end - 1); node;
-            node = __mmu_int_rb_iter_next(node, start, end - 1)) {
+       for (node = __mmu_int_rb_iter_first(root, start, end - 1);
+            node; node = ptr) {
+               /* Guard against node removal. */
+               ptr = __mmu_int_rb_iter_next(node, start, end - 1);
                hfi1_cdbg(MMU, "Invalidating node addr 0x%llx, len %u",
                          node->addr, node->len);
                if (handler->ops->invalidate(root, node))
-                       __mmu_rb_remove(handler, node, true);
+                       __mmu_rb_remove(handler, node, mm);
        }
        spin_unlock_irqrestore(&handler->lock, flags);
 }
 
 struct mmu_rb_ops {
        bool (*filter)(struct mmu_rb_node *, unsigned long, unsigned long);
        int (*insert)(struct rb_root *, struct mmu_rb_node *);
-       void (*remove)(struct rb_root *, struct mmu_rb_node *, bool);
+       void (*remove)(struct rb_root *, struct mmu_rb_node *,
+                      struct mm_struct *);
        int (*invalidate)(struct rb_root *, struct mmu_rb_node *);
 };
 
 
 static int set_rcvarray_entry(struct file *, unsigned long, u32,
                              struct tid_group *, struct page **, unsigned);
 static int mmu_rb_insert(struct rb_root *, struct mmu_rb_node *);
-static void mmu_rb_remove(struct rb_root *, struct mmu_rb_node *, bool);
+static void mmu_rb_remove(struct rb_root *, struct mmu_rb_node *,
+                         struct mm_struct *);
 static int mmu_rb_invalidate(struct rb_root *, struct mmu_rb_node *);
 static int program_rcvarray(struct file *, unsigned long, struct tid_group *,
                            struct tid_pageset *, unsigned, u16, struct page **,
        if (!node || node->rcventry != (uctxt->expected_base + rcventry))
                return -EBADF;
        if (HFI1_CAP_IS_USET(TID_UNMAP))
-               mmu_rb_remove(&fd->tid_rb_root, &node->mmu, false);
+               mmu_rb_remove(&fd->tid_rb_root, &node->mmu, NULL);
        else
                hfi1_mmu_rb_remove(&fd->tid_rb_root, &node->mmu);
 
                                        continue;
                                if (HFI1_CAP_IS_USET(TID_UNMAP))
                                        mmu_rb_remove(&fd->tid_rb_root,
-                                                     &node->mmu, false);
+                                                     &node->mmu, NULL);
                                else
                                        hfi1_mmu_rb_remove(&fd->tid_rb_root,
                                                           &node->mmu);
 }
 
 static void mmu_rb_remove(struct rb_root *root, struct mmu_rb_node *node,
-                         bool notifier)
+                         struct mm_struct *mm)
 {
        struct hfi1_filedata *fdata =
                container_of(root, struct hfi1_filedata, tid_rb_root);
 
 static void activate_packet_queue(struct iowait *, int);
 static bool sdma_rb_filter(struct mmu_rb_node *, unsigned long, unsigned long);
 static int sdma_rb_insert(struct rb_root *, struct mmu_rb_node *);
-static void sdma_rb_remove(struct rb_root *, struct mmu_rb_node *, bool);
+static void sdma_rb_remove(struct rb_root *, struct mmu_rb_node *,
+                          struct mm_struct *);
 static int sdma_rb_invalidate(struct rb_root *, struct mmu_rb_node *);
 
 static struct mmu_rb_ops sdma_rb_ops = {
        rb_node = hfi1_mmu_rb_search(&pq->sdma_rb_root,
                                     (unsigned long)iovec->iov.iov_base,
                                     iovec->iov.iov_len);
-       if (rb_node)
+       if (rb_node && !IS_ERR(rb_node))
                node = container_of(rb_node, struct sdma_mmu_node, rb);
+       else
+               rb_node = NULL;
 
        if (!node) {
                node = kzalloc(sizeof(*node), GFP_KERNEL);
                                &req->pq->sdma_rb_root,
                                (unsigned long)req->iovs[i].iov.iov_base,
                                req->iovs[i].iov.iov_len);
-                       if (!mnode)
+                       if (!mnode || IS_ERR(mnode))
                                continue;
 
                        node = container_of(mnode, struct sdma_mmu_node, rb);
 }
 
 static void sdma_rb_remove(struct rb_root *root, struct mmu_rb_node *mnode,
-                          bool notifier)
+                          struct mm_struct *mm)
 {
        struct sdma_mmu_node *node =
                container_of(mnode, struct sdma_mmu_node, rb);
        node->pq->n_locked -= node->npages;
        spin_unlock(&node->pq->evict_lock);
 
-       unpin_vector_pages(notifier ? NULL : current->mm, node->pages,
-                          node->npages);
+       /*
+        * If mm is set, we are being called by the MMU notifier and we
+        * should not pass a mm_struct to unpin_vector_page(). This is to
+        * prevent a deadlock when hfi1_release_user_pages() attempts to
+        * take the mmap_sem, which the MMU notifier has already taken.
+        */
+       unpin_vector_pages(mm ? NULL : current->mm, node->pages, node->npages);
        /*
         * If called by the MMU notifier, we have to adjust the pinned
         * page count ourselves.
         */
-       if (notifier)
-               current->mm->pinned_vm -= node->npages;
+       if (mm)
+               mm->pinned_vm -= node->npages;
        kfree(node);
 }