 #include <linux/mm.h>
 #include <linux/dax.h>
+#include <linux/fs.h>
 
 struct xfs_failure_info {
        xfs_agblock_t           startblock;
        struct xfs_mount                *mp = cur->bc_mp;
        struct xfs_inode                *ip;
        struct xfs_failure_info         *notify = data;
+       struct address_space            *mapping;
+       pgoff_t                         pgoff;
+       unsigned long                   pgcnt;
        int                             error = 0;
 
        if (XFS_RMAP_NON_INODE_OWNER(rec->rm_owner) ||
            (rec->rm_flags & (XFS_RMAP_ATTR_FORK | XFS_RMAP_BMBT_BLOCK))) {
+               /* Continue the query because this isn't a failure. */
+               if (notify->mf_flags & MF_MEM_PRE_REMOVE)
+                       return 0;
                notify->want_shutdown = true;
                return 0;
        }
 
-       error = mf_dax_kill_procs(VFS_I(ip)->i_mapping,
-                                 xfs_failure_pgoff(mp, rec, notify),
-                                 xfs_failure_pgcnt(mp, rec, notify),
-                                 notify->mf_flags);
+       mapping = VFS_I(ip)->i_mapping;
+       pgoff = xfs_failure_pgoff(mp, rec, notify);
+       pgcnt = xfs_failure_pgcnt(mp, rec, notify);
+
+       /* Continue the rmap query if the inode isn't a dax file. */
+       if (dax_mapping(mapping))
+               error = mf_dax_kill_procs(mapping, pgoff, pgcnt,
+                                         notify->mf_flags);
+
+       /* Invalidate the cache in dax pages. */
+       if (notify->mf_flags & MF_MEM_PRE_REMOVE)
+               invalidate_inode_pages2_range(mapping, pgoff,
+                                             pgoff + pgcnt - 1);
+
        xfs_irele(ip);
        return error;
 }
 
+static int
+xfs_dax_notify_failure_freeze(
+       struct xfs_mount        *mp)
+{
+       struct super_block      *sb = mp->m_super;
+       int                     error;
+
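+       /* freeze_super() errors out (e.g. -EBUSY) if the fs is already frozen. */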
+       error = freeze_super(sb, FREEZE_HOLDER_KERNEL);
+       if (error)
+               xfs_emerg(mp, "already frozen by kernel, err=%d", error);
+
+       return error;
+}
+
+static void
+xfs_dax_notify_failure_thaw(
+       struct xfs_mount        *mp,
+       bool                    kernel_frozen)
+{
+       struct super_block      *sb = mp->m_super;
+       int                     error;
+
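+       /* Thaw the kernel freeze only if this failure path acquired it. */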
+       if (kernel_frozen) {
+               error = thaw_super(sb, FREEZE_HOLDER_KERNEL);
+               if (error)
+                       xfs_emerg(mp, "still frozen after notify failure, err=%d",
+                               error);
+       }
+
+       /*
+        * Also thaw the userspace freeze unconditionally, because the device
+        * is about to be removed immediately.
+        */
+       thaw_super(sb, FREEZE_HOLDER_USERSPACE);
+}
+
 static int
 xfs_dax_notify_ddev_failure(
        struct xfs_mount        *mp,
        struct xfs_btree_cur    *cur = NULL;
        struct xfs_buf          *agf_bp = NULL;
        int                     error = 0;
+       bool                    kernel_frozen = false;
        xfs_fsblock_t           fsbno = XFS_DADDR_TO_FSB(mp, daddr);
        xfs_agnumber_t          agno = XFS_FSB_TO_AGNO(mp, fsbno);
        xfs_fsblock_t           end_fsbno = XFS_DADDR_TO_FSB(mp,
                                                             daddr + bblen - 1);
        xfs_agnumber_t          end_agno = XFS_FSB_TO_AGNO(mp, end_fsbno);
 
+       if (mf_flags & MF_MEM_PRE_REMOVE) {
+               xfs_info(mp, "Device is about to be removed!");
+               /*
+                * Freeze the fs to prevent new mappings from being created.
+                * - Keep going if someone else already holds the kernel freeze.
+                * - Keep going on other errors too, because this device is
+                *   starting to fail.
+                * - If the kernel freeze is taken successfully here, thaw it
+                *   at the end as well.
+                */
+               kernel_frozen = xfs_dax_notify_failure_freeze(mp) == 0;
+       }
+
        error = xfs_trans_alloc_empty(mp, &tp);
        if (error)
-               return error;
+               goto out;
 
        for (; agno <= end_agno; agno++) {
                struct xfs_rmap_irec    ri_low = { };
        }
 
        xfs_trans_cancel(tp);
-       if (error || notify.want_shutdown) {
+
+       /*
+        * Shut down the filesystem with a force unmount in the pre-remove
+        * case, which cannot fail, so errors can be ignored.  Otherwise,
+        * shut down the filesystem with the CORRUPT flag if an error
+        * occurred or notify.want_shutdown was set during the RMAP query.
+        */
+       if (mf_flags & MF_MEM_PRE_REMOVE)
+               xfs_force_shutdown(mp, SHUTDOWN_FORCE_UMOUNT);
+       else if (error || notify.want_shutdown) {
                xfs_force_shutdown(mp, SHUTDOWN_CORRUPT_ONDISK);
                if (!error)
                        error = -EFSCORRUPTED;
        }
+
+out:
+       /* Thaw the fs if it was frozen above. */
+       if (mf_flags & MF_MEM_PRE_REMOVE)
+               xfs_dax_notify_failure_thaw(mp, kernel_frozen);
+
        return error;
 }
 
 
        if (mp->m_logdev_targp && mp->m_logdev_targp->bt_daxdev == dax_dev &&
            mp->m_logdev_targp != mp->m_ddev_targp) {
+               /*
+                * In the pre-remove case the failure notification is attempting
+                * to trigger a force unmount.  The expectation is that the
+                * device is still present, but its removal is in progress and
+                * cannot be cancelled; proceed with accessing the log device.
+                */
+               if (mf_flags & MF_MEM_PRE_REMOVE)
+                       return 0;
                xfs_err(mp, "ondisk log corrupt, shutting down fs!");
                xfs_force_shutdown(mp, SHUTDOWN_CORRUPT_ONDISK);
                return -EFSCORRUPTED;
        ddev_start = mp->m_ddev_targp->bt_dax_part_off;
        ddev_end = ddev_start + bdev_nr_bytes(mp->m_ddev_targp->bt_bdev) - 1;
 
+       /* Notify failure on the whole device. */
+       if (offset == 0 && len == U64_MAX) {
+               offset = ddev_start;
+               len = bdev_nr_bytes(mp->m_ddev_targp->bt_bdev);
+       }
+
        /* Ignore the range out of filesystem area */
        if (offset + len - 1 < ddev_start)
                return -ENXIO;
 
  */
 static void collect_procs_fsdax(struct page *page,
                struct address_space *mapping, pgoff_t pgoff,
-               struct list_head *to_kill)
+               struct list_head *to_kill, bool pre_remove)
 {
        struct vm_area_struct *vma;
        struct task_struct *tsk;
        i_mmap_lock_read(mapping);
        rcu_read_lock();
        for_each_process(tsk) {
-               struct task_struct *t = task_early_kill(tsk, true);
+               struct task_struct *t = tsk;
 
+               /*
+                * Search for all tasks while MF_MEM_PRE_REMOVE is set, because
+                * the current process may not be the one accessing the fsdax
+                * page.  Otherwise, search only for the current task.
+                */
+               if (!pre_remove)
+                       t = task_early_kill(tsk, true);
                if (!t)
                        continue;
                vma_interval_tree_foreach(vma, &mapping->i_mmap, pgoff, pgoff) {
        dax_entry_t cookie;
        struct page *page;
        size_t end = index + count;
+       bool pre_remove = mf_flags & MF_MEM_PRE_REMOVE;
 
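+       /* Handle as an action-required failure and kill tasks mapping the range. */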
        mf_flags |= MF_ACTION_REQUIRED | MF_MUST_KILL;
 
                if (!page)
                        goto unlock;
 
-               SetPageHWPoison(page);
+               if (!pre_remove)
+                       SetPageHWPoison(page);
 
-               collect_procs_fsdax(page, mapping, index, &to_kill);
+               /*
+                * The pre_remove case is revoking access; the memory is still
+                * good and could theoretically be put back into service.
+                */
+               collect_procs_fsdax(page, mapping, index, &to_kill, pre_remove);
                unmap_and_kill(&to_kill, page_to_pfn(page), mapping,
                                index, mf_flags);
 unlock: