#include <asm/i387.h>
 #include <asm/insn.h>
 #include <asm/mman.h>
+#include <asm/mmu_context.h>
 #include <asm/mpx.h>
 #include <asm/processor.h>
 #include <asm/fpu-internal.h>
        .name = mpx_mapping_name,
 };
 
+static int is_mpx_vma(struct vm_area_struct *vma)
+{
+       return (vma->vm_ops == &mpx_vma_ops);
+}
+
 /*
  * This is really a simplified "vm_mmap". it only handles MPX
  * bounds tables (the bounds directory is user-allocated).
        }
        return 0;
 }
+
+/*
+ * A thin wrapper around get_user_pages().  Returns 0 if the
+ * fault was resolved or -errno if not.
+ */
+static int mpx_resolve_fault(long __user *addr, int write)
+{
+       long gup_ret;
+       int nr_pages = 1;
+       int force = 0;
+
+       gup_ret = get_user_pages(current, current->mm, (unsigned long)addr,
+                                nr_pages, write, force, NULL, NULL);
+       /*
+        * get_user_pages() returns number of pages gotten.
+        * 0 means we failed to fault in and get anything,
+        * probably because 'addr' is bad.
+        */
+       if (!gup_ret)
+               return -EFAULT;
+       /* Other error, return it */
+       if (gup_ret < 0)
+               return gup_ret;
+       /* must have gup'd a page and gup_ret>0, success */
+       return 0;
+}
+
+/*
+ * Get the base of bounds tables pointed by specific bounds
+ * directory entry.
+ */
+static int get_bt_addr(struct mm_struct *mm,
+                       long __user *bd_entry, unsigned long *bt_addr)
+{
+       int ret;
+       int valid_bit;
+
+       if (!access_ok(VERIFY_READ, (bd_entry), sizeof(*bd_entry)))
+               return -EFAULT;
+
+       while (1) {
+               int need_write = 0;
+
+               pagefault_disable();
+               ret = get_user(*bt_addr, bd_entry);
+               pagefault_enable();
+               if (!ret)
+                       break;
+               if (ret == -EFAULT)
+                       ret = mpx_resolve_fault(bd_entry, need_write);
+               /*
+                * If we could not resolve the fault, consider it
+                * userspace's fault and error out.
+                */
+               if (ret)
+                       return ret;
+       }
+
+       valid_bit = *bt_addr & MPX_BD_ENTRY_VALID_FLAG;
+       *bt_addr &= MPX_BT_ADDR_MASK;
+
+       /*
+        * When the kernel is managing bounds tables, a bounds directory
+        * entry will either have a valid address (plus the valid bit)
+        * *OR* be completely empty. If we see a !valid entry *and* some
+        * data in the address field, we know something is wrong. This
+        * -EINVAL return will cause a SIGSEGV.
+        */
+       if (!valid_bit && *bt_addr)
+               return -EINVAL;
+       /*
+        * Do we have an completely zeroed bt entry?  That is OK.  It
+        * just means there was no bounds table for this memory.  Make
+        * sure to distinguish this from -EINVAL, which will cause
+        * a SEGV.
+        */
+       if (!valid_bit)
+               return -ENOENT;
+
+       return 0;
+}
+
+/*
+ * Free the backing physical pages of bounds table 'bt_addr'.
+ * Assume start...end is within that bounds table.
+ */
+static int zap_bt_entries(struct mm_struct *mm,
+               unsigned long bt_addr,
+               unsigned long start, unsigned long end)
+{
+       struct vm_area_struct *vma;
+       unsigned long addr, len;
+
+       /*
+        * Find the first overlapping vma. If vma->vm_start > start, there
+        * will be a hole in the bounds table. This -EINVAL return will
+        * cause a SIGSEGV.
+        */
+       vma = find_vma(mm, start);
+       if (!vma || vma->vm_start > start)
+               return -EINVAL;
+
+       /*
+        * A NUMA policy on a VM_MPX VMA could cause this bouds table to
+        * be split. So we need to look across the entire 'start -> end'
+        * range of this bounds table, find all of the VM_MPX VMAs, and
+        * zap only those.
+        */
+       addr = start;
+       while (vma && vma->vm_start < end) {
+               /*
+                * We followed a bounds directory entry down
+                * here.  If we find a non-MPX VMA, that's bad,
+                * so stop immediately and return an error.  This
+                * probably results in a SIGSEGV.
+                */
+               if (!is_mpx_vma(vma))
+                       return -EINVAL;
+
+               len = min(vma->vm_end, end) - addr;
+               zap_page_range(vma, addr, len, NULL);
+
+               vma = vma->vm_next;
+               addr = vma->vm_start;
+       }
+
+       return 0;
+}
+
+static int unmap_single_bt(struct mm_struct *mm,
+               long __user *bd_entry, unsigned long bt_addr)
+{
+       unsigned long expected_old_val = bt_addr | MPX_BD_ENTRY_VALID_FLAG;
+       unsigned long actual_old_val = 0;
+       int ret;
+
+       while (1) {
+               int need_write = 1;
+
+               pagefault_disable();
+               ret = user_atomic_cmpxchg_inatomic(&actual_old_val, bd_entry,
+                                                  expected_old_val, 0);
+               pagefault_enable();
+               if (!ret)
+                       break;
+               if (ret == -EFAULT)
+                       ret = mpx_resolve_fault(bd_entry, need_write);
+               /*
+                * If we could not resolve the fault, consider it
+                * userspace's fault and error out.
+                */
+               if (ret)
+                       return ret;
+       }
+       /*
+        * The cmpxchg was performed, check the results.
+        */
+       if (actual_old_val != expected_old_val) {
+               /*
+                * Someone else raced with us to unmap the table.
+                * There was no bounds table pointed to by the
+                * directory, so declare success.  Somebody freed
+                * it.
+                */
+               if (!actual_old_val)
+                       return 0;
+               /*
+                * Something messed with the bounds directory
+                * entry.  We hold mmap_sem for read or write
+                * here, so it could not be a _new_ bounds table
+                * that someone just allocated.  Something is
+                * wrong, so pass up the error and SIGSEGV.
+                */
+               return -EINVAL;
+       }
+
+       /*
+        * Note, we are likely being called under do_munmap() already. To
+        * avoid recursion, do_munmap() will check whether it comes
+        * from one bounds table through VM_MPX flag.
+        */
+       return do_munmap(mm, bt_addr, MPX_BT_SIZE_BYTES);
+}
+
+/*
+ * If the bounds table pointed by bounds directory 'bd_entry' is
+ * not shared, unmap this whole bounds table. Otherwise, only free
+ * those backing physical pages of bounds table entries covered
+ * in this virtual address region start...end.
+ */
+static int unmap_shared_bt(struct mm_struct *mm,
+               long __user *bd_entry, unsigned long start,
+               unsigned long end, bool prev_shared, bool next_shared)
+{
+       unsigned long bt_addr;
+       int ret;
+
+       ret = get_bt_addr(mm, bd_entry, &bt_addr);
+       /*
+        * We could see an "error" ret for not-present bounds
+        * tables (not really an error), or actual errors, but
+        * stop unmapping either way.
+        */
+       if (ret)
+               return ret;
+
+       if (prev_shared && next_shared)
+               ret = zap_bt_entries(mm, bt_addr,
+                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(start),
+                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(end));
+       else if (prev_shared)
+               ret = zap_bt_entries(mm, bt_addr,
+                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(start),
+                               bt_addr+MPX_BT_SIZE_BYTES);
+       else if (next_shared)
+               ret = zap_bt_entries(mm, bt_addr, bt_addr,
+                               bt_addr+MPX_GET_BT_ENTRY_OFFSET(end));
+       else
+               ret = unmap_single_bt(mm, bd_entry, bt_addr);
+
+       return ret;
+}
+
+/*
+ * A virtual address region being munmap()ed might share bounds table
+ * with adjacent VMAs. We only need to free the backing physical
+ * memory of these shared bounds tables entries covered in this virtual
+ * address region.
+ */
+static int unmap_edge_bts(struct mm_struct *mm,
+               unsigned long start, unsigned long end)
+{
+       int ret;
+       long __user *bde_start, *bde_end;
+       struct vm_area_struct *prev, *next;
+       bool prev_shared = false, next_shared = false;
+
+       bde_start = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(start);
+       bde_end = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(end-1);
+
+       /*
+        * Check whether bde_start and bde_end are shared with adjacent
+        * VMAs.
+        *
+        * We already unliked the VMAs from the mm's rbtree so 'start'
+        * is guaranteed to be in a hole. This gets us the first VMA
+        * before the hole in to 'prev' and the next VMA after the hole
+        * in to 'next'.
+        */
+       next = find_vma_prev(mm, start, &prev);
+       if (prev && (mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(prev->vm_end-1))
+                       == bde_start)
+               prev_shared = true;
+       if (next && (mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(next->vm_start))
+                       == bde_end)
+               next_shared = true;
+
+       /*
+        * This virtual address region being munmap()ed is only
+        * covered by one bounds table.
+        *
+        * In this case, if this table is also shared with adjacent
+        * VMAs, only part of the backing physical memory of the bounds
+        * table need be freeed. Otherwise the whole bounds table need
+        * be unmapped.
+        */
+       if (bde_start == bde_end) {
+               return unmap_shared_bt(mm, bde_start, start, end,
+                               prev_shared, next_shared);
+       }
+
+       /*
+        * If more than one bounds tables are covered in this virtual
+        * address region being munmap()ed, we need to separately check
+        * whether bde_start and bde_end are shared with adjacent VMAs.
+        */
+       ret = unmap_shared_bt(mm, bde_start, start, end, prev_shared, false);
+       if (ret)
+               return ret;
+       ret = unmap_shared_bt(mm, bde_end, start, end, false, next_shared);
+       if (ret)
+               return ret;
+
+       return 0;
+}
+
+static int mpx_unmap_tables(struct mm_struct *mm,
+               unsigned long start, unsigned long end)
+{
+       int ret;
+       long __user *bd_entry, *bde_start, *bde_end;
+       unsigned long bt_addr;
+
+       /*
+        * "Edge" bounds tables are those which are being used by the region
+        * (start -> end), but that may be shared with adjacent areas.  If they
+        * turn out to be completely unshared, they will be freed.  If they are
+        * shared, we will free the backing store (like an MADV_DONTNEED) for
+        * areas used by this region.
+        */
+       ret = unmap_edge_bts(mm, start, end);
+       switch (ret) {
+               /* non-present tables are OK */
+               case 0:
+               case -ENOENT:
+                       /* Success, or no tables to unmap */
+                       break;
+               case -EINVAL:
+               case -EFAULT:
+               default:
+                       return ret;
+       }
+
+       /*
+        * Only unmap the bounds table that are
+        *   1. fully covered
+        *   2. not at the edges of the mapping, even if full aligned
+        */
+       bde_start = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(start);
+       bde_end = mm->bd_addr + MPX_GET_BD_ENTRY_OFFSET(end-1);
+       for (bd_entry = bde_start + 1; bd_entry < bde_end; bd_entry++) {
+               ret = get_bt_addr(mm, bd_entry, &bt_addr);
+               switch (ret) {
+                       case 0:
+                               break;
+                       case -ENOENT:
+                               /* No table here, try the next one */
+                               continue;
+                       case -EINVAL:
+                       case -EFAULT:
+                       default:
+                               /*
+                                * Note: we are being strict here.
+                                * Any time we run in to an issue
+                                * unmapping tables, we stop and
+                                * SIGSEGV.
+                                */
+                               return ret;
+               }
+
+               ret = unmap_single_bt(mm, bd_entry, bt_addr);
+               if (ret)
+                       return ret;
+       }
+
+       return 0;
+}
+
+/*
+ * Free unused bounds tables covered in a virtual address region being
+ * munmap()ed. Assume end > start.
+ *
+ * This function will be called by do_munmap(), and the VMAs covering
+ * the virtual address region start...end have already been split if
+ * necessary, and the 'vma' is the first vma in this range (start -> end).
+ */
+void mpx_notify_unmap(struct mm_struct *mm, struct vm_area_struct *vma,
+               unsigned long start, unsigned long end)
+{
+       int ret;
+
+       /*
+        * Refuse to do anything unless userspace has asked
+        * the kernel to help manage the bounds tables,
+        */
+       if (!kernel_managing_mpx_tables(current->mm))
+               return;
+       /*
+        * This will look across the entire 'start -> end' range,
+        * and find all of the non-VM_MPX VMAs.
+        *
+        * To avoid recursion, if a VM_MPX vma is found in the range
+        * (start->end), we will not continue follow-up work. This
+        * recursion represents having bounds tables for bounds tables,
+        * which should not occur normally. Being strict about it here
+        * helps ensure that we do not have an exploitable stack overflow.
+        */
+       do {
+               if (vma->vm_flags & VM_MPX)
+                       return;
+               vma = vma->vm_next;
+       } while (vma && vma->vm_start < end);
+
+       ret = mpx_unmap_tables(mm, start, end);
+       if (ret)
+               force_sig(SIGSEGV, current);
+}