}
 fs_initcall(init_dax_wait_table);
 
-static wait_queue_head_t *dax_entry_waitqueue(struct address_space *mapping,
-                                             pgoff_t index)
-{
-       unsigned long hash = hash_long((unsigned long)mapping ^ index,
-                                      DAX_WAIT_TABLE_BITS);
-       return wait_table + hash;
-}
-
 static long dax_map_atomic(struct block_device *bdev, struct blk_dax_ctl *dax)
 {
        struct request_queue *q = bdev->bd_queue;
  */
 struct exceptional_entry_key {
        struct address_space *mapping;
-       unsigned long index;
+       pgoff_t entry_start;
 };
 
 struct wait_exceptional_entry_queue {
        wait_queue_t wait;
        struct exceptional_entry_key key;
 };
 
+static wait_queue_head_t *dax_entry_waitqueue(struct address_space *mapping,
+               pgoff_t index, void *entry, struct exceptional_entry_key *key)
+{
+       unsigned long hash;
+
+       /*
+        * If 'entry' is a PMD, align the 'index' that we use for the wait
+        * queue to the start of that PMD.  This ensures that all offsets in
+        * the range covered by the PMD map to the same bit lock.
+        */
+       if (RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD)
+               index &= ~((1UL << (PMD_SHIFT - PAGE_SHIFT)) - 1);
+
+       key->mapping = mapping;
+       key->entry_start = index;
+
+       hash = hash_long((unsigned long)mapping ^ index, DAX_WAIT_TABLE_BITS);
+       return wait_table + hash;
+}
+
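
As a worked example (not part of the patch): with the usual x86_64 constants
PAGE_SHIFT = 12 and PMD_SHIFT = 21, one PMD covers 512 pages, so the mask
above is ~0x1ff and every index inside a 2 MiB extent collapses onto the same
entry_start. A minimal userspace sketch, assuming those constants:

#include <stdio.h>

#define PAGE_SHIFT	12	/* assumed: 4 KiB pages */
#define PMD_SHIFT	21	/* assumed: 2 MiB PMDs */

int main(void)
{
	/* Indices 510..513 straddle the PMD boundary at 512. */
	for (unsigned long index = 510; index <= 513; index++)
		printf("index %lu -> entry_start %lu\n", index,
		       index & ~((1UL << (PMD_SHIFT - PAGE_SHIFT)) - 1));
	return 0;	/* prints 0, 0, 512, 512 */
}
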
 static int wake_exceptional_entry_func(wait_queue_t *wait, unsigned int mode,
                                       int sync, void *keyp)
 {
        struct exceptional_entry_key *key = keyp;
        struct wait_exceptional_entry_queue *ewait =
                container_of(wait, struct wait_exceptional_entry_queue, wait);
 
        if (key->mapping != ewait->key.mapping ||
-           key->index != ewait->key.index)
+           key->entry_start != ewait->key.entry_start)
                return 0;
        return autoremove_wake_function(wait, mode, sync, NULL);
 }
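
The key comparison above is what makes a shared, hashed wait table workable:
waiters for unrelated entries can collide in one bucket, and only those whose
(mapping, entry_start) pair matches are woken. A toy userspace model of that
filtering (wake_bucket, struct waiter and the stand-in mappings are invented
for illustration, not kernel API):

#include <stdbool.h>
#include <stdio.h>

struct key { const void *mapping; unsigned long entry_start; };
struct waiter { struct key key; bool woken; };

/* Mirrors wake_exceptional_entry_func(): skip colliding waiters whose key
 * does not match; wake one exclusive waiter unless wake_all is set. */
static void wake_bucket(struct waiter *w, int n, struct key *k, bool wake_all)
{
	for (int i = 0; i < n; i++) {
		if (w[i].key.mapping != k->mapping ||
		    w[i].key.entry_start != k->entry_start)
			continue;
		w[i].woken = true;
		if (!wake_all)
			break;
	}
}

int main(void)
{
	int a, b;			/* stand-ins for two address_spaces */
	struct waiter w[] = {
		{ { &a, 0 }, false },	/* waiter on the entry being woken */
		{ { &b, 0 }, false },	/* bucket collision, other mapping */
		{ { &a, 0 }, false },	/* second waiter on the same entry */
	};
	struct key k = { &a, 0 };

	wake_bucket(w, 3, &k, false);	/* wake one exclusive waiter */
	printf("%d %d %d\n", w[0].woken, w[1].woken, w[2].woken); /* 1 0 0 */
	return 0;
}

The wake_all flag plays the same role as the nr_exclusive argument that
dax_wake_mapping_entry_waiter() passes to __wake_up() below: unlocking an
entry wakes one exclusive waiter, deleting the entry wakes them all.
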
 static void *get_unlocked_mapping_entry(struct address_space *mapping,
                                         pgoff_t index, void ***slotp)
 {
        void *entry, **slot;
        struct wait_exceptional_entry_queue ewait;
-       wait_queue_head_t *wq = dax_entry_waitqueue(mapping, index);
+       wait_queue_head_t *wq;
 
        init_wait(&ewait.wait);
        ewait.wait.func = wake_exceptional_entry_func;
-       ewait.key.mapping = mapping;
-       ewait.key.index = index;
 
        for (;;) {
                entry = __radix_tree_lookup(&mapping->page_tree, index, NULL,
                                          &slot);
                if (!entry || !radix_tree_exceptional_entry(entry) ||
                    !slot_locked(mapping, slot)) {
                        if (slotp)
                                *slotp = slot;
                        return entry;
                }
+
+               wq = dax_entry_waitqueue(mapping, index, entry, &ewait.key);
                prepare_to_wait_exclusive(wq, &ewait.wait,
                                          TASK_UNINTERRUPTIBLE);
                spin_unlock_irq(&mapping->tree_lock);
        return entry;
 }
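
The loop in get_unlocked_mapping_entry() above follows the classic
lost-wakeup-free discipline: the entry is checked and the task queued while
mapping->tree_lock is held, and only then is the lock dropped to sleep, so an
unlocker (who must also take tree_lock) cannot issue its wakeup before the
sleeper is on the queue. A rough pthread analogue of the same guarantee
(names are illustrative only; pthread_cond_wait() drops the mutex atomically,
much like prepare_to_wait_exclusive() followed by spin_unlock_irq()):

#include <pthread.h>
#include <stdio.h>

static pthread_mutex_t tree_lock = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t waitq = PTHREAD_COND_INITIALIZER;
static int entry_locked = 1;	/* stands in for the slot's lock bit */

static void *waiter(void *arg)
{
	(void)arg;
	pthread_mutex_lock(&tree_lock);
	while (entry_locked)	/* re-check the entry after every wakeup */
		pthread_cond_wait(&waitq, &tree_lock);
	pthread_mutex_unlock(&tree_lock);
	puts("entry unlocked, proceeding");
	return NULL;
}

int main(void)
{
	pthread_t t;

	pthread_create(&t, NULL, waiter, NULL);
	pthread_mutex_lock(&tree_lock);
	entry_locked = 0;		/* cf. unlock_slot() */
	pthread_cond_signal(&waitq);	/* cf. dax_wake_mapping_entry_waiter() */
	pthread_mutex_unlock(&tree_lock);
	pthread_join(t, NULL);
	return 0;
}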
 
+/*
+ * We do not necessarily hold the mapping->tree_lock when we call this
+ * function so it is possible that 'entry' is no longer a valid item in the
+ * radix tree.  This is okay, though, because all we really need to do is to
+ * find the correct waitqueue where tasks might be sleeping waiting for that
+ * old 'entry' and wake them.
+ */
 void dax_wake_mapping_entry_waiter(struct address_space *mapping,
-                                  pgoff_t index, bool wake_all)
+               pgoff_t index, void *entry, bool wake_all)
 {
-       wait_queue_head_t *wq = dax_entry_waitqueue(mapping, index);
+       struct exceptional_entry_key key;
+       wait_queue_head_t *wq;
+
+       wq = dax_entry_waitqueue(mapping, index, entry, &key);
 
        /*
         * Checking for locked entry and prepare_to_wait_exclusive() happens
         * under mapping->tree_lock, ditto for entry handling in our callers.
         * So at this point all tasks that could have seen our entry locked
         * must be in the waitqueue and the following check will see them.
         */
-       if (waitqueue_active(wq)) {
-               struct exceptional_entry_key key;
-
-               key.mapping = mapping;
-               key.index = index;
+       if (waitqueue_active(wq))
                __wake_up(wq, TASK_NORMAL, wake_all ? 0 : 1, &key);
-       }
 }
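
Because the sleep side (get_unlocked_mapping_entry()) and the wake side above
now both derive the queue from the PMD-aligned entry_start, a task sleeping on
index 5 of a PMD entry and a waker passing index 0 reach the same bucket. A
userspace sketch of the bucket selection, assuming a 64-bit build,
DAX_WAIT_TABLE_BITS of 12, and the 64-bit hash_long() multiplier from
include/linux/hash.h:

#include <stdio.h>

#define DAX_WAIT_TABLE_BITS	12		/* assumed: 4096 queues */
#define GOLDEN_RATIO_64		0x61C8864680B583EBull

/* Simplified 64-bit hash_long(), per include/linux/hash.h. */
static unsigned long hash_long(unsigned long val, unsigned int bits)
{
	return (val * GOLDEN_RATIO_64) >> (64 - bits);
}

static unsigned long bucket(unsigned long mapping, unsigned long index)
{
	index &= ~511UL;	/* PMD alignment, as in dax_entry_waitqueue() */
	return hash_long(mapping ^ index, DAX_WAIT_TABLE_BITS);
}

int main(void)
{
	unsigned long mapping = 0xdeadbeefUL;	/* made-up mapping pointer */

	/* A waiter at index 5 and a waker at index 0 share one wait queue. */
	printf("waiter %lu, waker %lu\n", bucket(mapping, 5), bucket(mapping, 0));
	return 0;
}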
 
 void dax_unlock_mapping_entry(struct address_space *mapping, pgoff_t index)
 {
        void *entry, **slot;

        spin_lock_irq(&mapping->tree_lock);
        entry = __radix_tree_lookup(&mapping->page_tree, index, NULL, &slot);
        if (WARN_ON_ONCE(!entry || !radix_tree_exceptional_entry(entry) ||
                         !slot_locked(mapping, slot))) {
                spin_unlock_irq(&mapping->tree_lock);
                return;
        }
        unlock_slot(mapping, slot);
        spin_unlock_irq(&mapping->tree_lock);
-       dax_wake_mapping_entry_waiter(mapping, index, false);
+       dax_wake_mapping_entry_waiter(mapping, index, entry, false);
 }
 
 static void put_unlocked_mapping_entry(struct address_space *mapping,
                                        pgoff_t index, void *entry)
 {
        if (!radix_tree_exceptional_entry(entry))
                return;
 
        /* We have to wake up next waiter for the radix tree entry lock */
-       dax_wake_mapping_entry_waiter(mapping, index, false);
+       dax_wake_mapping_entry_waiter(mapping, index, entry, false);
 }
 
 int dax_delete_mapping_entry(struct address_space *mapping, pgoff_t index)
        radix_tree_delete(&mapping->page_tree, index);
        mapping->nrexceptional--;
        spin_unlock_irq(&mapping->tree_lock);
-       dax_wake_mapping_entry_waiter(mapping, index, true);
+       dax_wake_mapping_entry_waiter(mapping, index, entry, true);
 
        return 1;
 }