 #include <linux/workqueue.h>
 #include <linux/srcu.h>
 #include <linux/oom.h>          /* check_stable_address_space */
+#include <linux/pagewalk.h>
 
 #include <linux/uprobes.h>
 
        return ((loff_t)vma->vm_pgoff << PAGE_SHIFT) + (vaddr - vma->vm_start);
 }
 
-/**
- * __replace_page - replace page in vma by new page.
- * based on replace_page in mm/ksm.c
- *
- * @vma:      vma that holds the pte pointing to page
- * @addr:     address the old @page is mapped at
- * @old_page: the page we are replacing by new_page
- * @new_page: the modified page we replace page by
- *
- * If @new_page is NULL, only unmap @old_page.
- *
- * Returns 0 on success, negative error code otherwise.
- */
-static int __replace_page(struct vm_area_struct *vma, unsigned long addr,
-                               struct page *old_page, struct page *new_page)
-{
-       struct folio *old_folio = page_folio(old_page);
-       struct folio *new_folio;
-       struct mm_struct *mm = vma->vm_mm;
-       DEFINE_FOLIO_VMA_WALK(pvmw, old_folio, vma, addr, 0);
-       int err;
-       struct mmu_notifier_range range;
-       pte_t pte;
-
-       mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, mm, addr,
-                               addr + PAGE_SIZE);
-
-       if (new_page) {
-               new_folio = page_folio(new_page);
-               err = mem_cgroup_charge(new_folio, vma->vm_mm, GFP_KERNEL);
-               if (err)
-                       return err;
-       }
-
-       /* For folio_free_swap() below */
-       folio_lock(old_folio);
-
-       mmu_notifier_invalidate_range_start(&range);
-       err = -EAGAIN;
-       if (!page_vma_mapped_walk(&pvmw))
-               goto unlock;
-       VM_BUG_ON_PAGE(addr != pvmw.address, old_page);
-       pte = ptep_get(pvmw.pte);
-
-       /*
-        * Handle PFN swap PTES, such as device-exclusive ones, that actually
-        * map pages: simply trigger GUP again to fix it up.
-        */
-       if (unlikely(!pte_present(pte))) {
-               page_vma_mapped_walk_done(&pvmw);
-               goto unlock;
-       }
-
-       if (new_page) {
-               folio_get(new_folio);
-               folio_add_new_anon_rmap(new_folio, vma, addr, RMAP_EXCLUSIVE);
-               folio_add_lru_vma(new_folio, vma);
-       } else
-               /* no new page, just dec_mm_counter for old_page */
-               dec_mm_counter(mm, MM_ANONPAGES);
-
-       if (!folio_test_anon(old_folio)) {
-               dec_mm_counter(mm, mm_counter_file(old_folio));
-               inc_mm_counter(mm, MM_ANONPAGES);
-       }
-
-       flush_cache_page(vma, addr, pte_pfn(pte));
-       ptep_clear_flush(vma, addr, pvmw.pte);
-       if (new_page)
-               set_pte_at(mm, addr, pvmw.pte,
-                          mk_pte(new_page, vma->vm_page_prot));
-
-       folio_remove_rmap_pte(old_folio, old_page, vma);
-       if (!folio_mapped(old_folio))
-               folio_free_swap(old_folio);
-       page_vma_mapped_walk_done(&pvmw);
-       folio_put(old_folio);
-
-       err = 0;
- unlock:
-       mmu_notifier_invalidate_range_end(&range);
-       folio_unlock(old_folio);
-       return err;
-}
-
 /**
  * is_swbp_insn - check if instruction is breakpoint instruction.
  * @insn: instruction to be checked.
        return ret;
 }
 
+static bool orig_page_is_identical(struct vm_area_struct *vma,
+               unsigned long vaddr, struct page *page, bool *pmd_mappable)
+{
+       const pgoff_t index = vaddr_to_offset(vma, vaddr) >> PAGE_SHIFT;
+       struct folio *orig_folio = filemap_get_folio(vma->vm_file->f_mapping,
+                                                   index);
+       struct page *orig_page;
+       bool identical;
+
+       if (IS_ERR(orig_folio))
+               return false;
+       orig_page = folio_file_page(orig_folio, index);
+
+       *pmd_mappable = folio_test_pmd_mappable(orig_folio);
+       identical = folio_test_uptodate(orig_folio) &&
+                   pages_identical(page, orig_page);
+       folio_put(orig_folio);
+       return identical;
+}
+
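+/*
+ * Write the opcode into the mapped page; when unregistering, try to zap the
+ * anonymous page if it then matches the original file page.
+ *
+ * Returns a negative errno on failure, 0 on success, and a positive value if
+ * the PTE was zapped and the original file folio is PMD-mappable, in which
+ * case the caller may try to collapse the PTE-mapped THP afterwards.
+ */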
+static int __uprobe_write_opcode(struct vm_area_struct *vma,
+               struct folio_walk *fw, struct folio *folio,
+               unsigned long opcode_vaddr, uprobe_opcode_t opcode)
+{
+       const unsigned long vaddr = opcode_vaddr & PAGE_MASK;
+       const bool is_register = !!is_swbp_insn(&opcode);
+       bool pmd_mappable;
+
+       /* For now, we'll only handle PTE-mapped folios. */
+       if (fw->level != FW_LEVEL_PTE)
+               return -EFAULT;
+
+       /*
+        * See can_follow_write_pte(): we'd actually prefer a writable PTE here,
+        * but the VMA might not be writable.
+        */
+       if (!pte_write(fw->pte)) {
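+               /*
+                * If the anon page is not exclusive to us, a write fault has
+                * to break COW first; return -EFAULT so the caller retries
+                * with FOLL_WRITE.
+                */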
+               if (!PageAnonExclusive(fw->page))
+                       return -EFAULT;
+               if (unlikely(userfaultfd_pte_wp(vma, fw->pte)))
+                       return -EFAULT;
+               /* SOFTDIRTY is handled via pte_mkdirty() below. */
+       }
+
+       /*
+        * We'll temporarily unmap the page and flush the TLB, such that we can
+        * modify the page atomically.
+        */
+       flush_cache_page(vma, vaddr, pte_pfn(fw->pte));
+       fw->pte = ptep_clear_flush(vma, vaddr, fw->ptep);
+       copy_to_page(fw->page, opcode_vaddr, &opcode, UPROBE_SWBP_INSN_SIZE);
+
+       /*
+        * When unregistering, we may only zap a PTE if uffd is disabled and
+        * there are no unexpected folio references (i.e., only references from
+        * the mappings, the swapcache, and the one we hold from GUP) ...
+        */
+       if (is_register || userfaultfd_missing(vma) ||
+           (folio_ref_count(folio) != folio_mapcount(folio) + 1 +
+            folio_test_swapcache(folio) * folio_nr_pages(folio)))
+               goto remap;
+
+       /*
+        * ... and the mapped page is identical to the original page that
+        * would get faulted in on next access.
+        */
+       if (!orig_page_is_identical(vma, vaddr, fw->page, &pmd_mappable))
+               goto remap;
+
+       dec_mm_counter(vma->vm_mm, MM_ANONPAGES);
+       folio_remove_rmap_pte(folio, fw->page, vma);
+       if (!folio_mapped(folio) && folio_test_swapcache(folio) &&
+            folio_trylock(folio)) {
+               folio_free_swap(folio);
+               folio_unlock(folio);
+       }
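+       /* Drop the reference that was held by the just-zapped PTE mapping. */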
+       folio_put(folio);
+
+       return pmd_mappable;
+remap:
+       /*
+        * Make sure that our copy_to_page() changes become visible before the
+        * set_pte_at() write.
+        */
+       smp_wmb();
+       /* We modified the page. Make sure to mark the PTE dirty. */
+       set_pte_at(vma->vm_mm, vaddr, fw->ptep, pte_mkdirty(fw->pte));
+       return 0;
+}
+
 /*
  * NOTE:
  * Expect the breakpoint instruction to be the smallest size instruction for
  * the architecture. If an arch has variable length instruction and the
  * breakpoint instruction is not of the smallest length instruction
  * supported by that architecture then we need to modify is_trap_at_addr and
  * uprobe_write_opcode accordingly. This would never be a problem for archs
  * that have fixed length instructions.
  *
  * uprobe_write_opcode - write the opcode at a given virtual address.
  * @auprobe: arch specific probepoint information.
  * @vma: the probed virtual memory area.
- * @vaddr: the virtual address to store the opcode.
- * @opcode: opcode to be written at @vaddr.
+ * @opcode_vaddr: the virtual address to store the opcode.
+ * @opcode: opcode to be written at @opcode_vaddr.
  *
  * Called with mm->mmap_lock held for read or write.
  * Return 0 (success) or a negative errno.
  */
 int uprobe_write_opcode(struct arch_uprobe *auprobe, struct vm_area_struct *vma,
-               unsigned long vaddr, uprobe_opcode_t opcode)
+               const unsigned long opcode_vaddr, uprobe_opcode_t opcode)
 {
+       const unsigned long vaddr = opcode_vaddr & PAGE_MASK;
        struct mm_struct *mm = vma->vm_mm;
        struct uprobe *uprobe;
-       struct page *old_page, *new_page;
        int ret, is_register, ref_ctr_updated = 0;
-       bool orig_page_huge = false;
        unsigned int gup_flags = FOLL_FORCE;
+       struct mmu_notifier_range range;
+       struct folio_walk fw;
+       struct folio *folio;
+       struct page *page;
 
        is_register = is_swbp_insn(&opcode);
        uprobe = container_of(auprobe, struct uprobe, arch);
 
-retry:
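+       /* uprobes only work on private (CoW) file mappings; see valid_vma(). */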
+       if (WARN_ON_ONCE(!is_cow_mapping(vma->vm_flags)))
+               return -EINVAL;
+
+       /*
+        * When registering, we have to break COW to get an exclusive anonymous
+        * page that we can safely modify. Use FOLL_WRITE to trigger a write
+        * fault if required. When unregistering, we might be lucky and the
+        * anon page is already gone. So defer write faults until really
+        * required. Use FOLL_SPLIT_PMD, because __uprobe_write_opcode()
+        * cannot deal with PMDs yet.
+        */
        if (is_register)
-               gup_flags |= FOLL_SPLIT_PMD;
-       /* Read the page with vaddr into memory */
-       ret = get_user_pages_remote(mm, vaddr, 1, gup_flags, &old_page, NULL);
-       if (ret != 1)
-               return ret;
+               gup_flags |= FOLL_WRITE | FOLL_SPLIT_PMD;
 
-       ret = verify_opcode(old_page, vaddr, &opcode);
+retry:
+       ret = get_user_pages_remote(mm, vaddr, 1, gup_flags, &page, NULL);
        if (ret <= 0)
-               goto put_old;
-
-       if (is_zero_page(old_page)) {
-               ret = -EINVAL;
-               goto put_old;
-       }
+               goto out;
+       folio = page_folio(page);
 
-       if (WARN(!is_register && PageCompound(old_page),
-                "uprobe unregister should never work on compound page\n")) {
-               ret = -EINVAL;
-               goto put_old;
+       ret = verify_opcode(page, opcode_vaddr, &opcode);
+       if (ret <= 0) {
+               folio_put(folio);
+               goto out;
        }
 
        /* We are going to replace instruction, update ref_ctr. */
        if (!ref_ctr_updated && uprobe->ref_ctr_offset) {
                ret = update_ref_ctr(uprobe, mm, is_register ? 1 : -1);
-               if (ret)
-                       goto put_old;
+               if (ret) {
+                       folio_put(folio);
+                       goto out;
+               }
 
                ref_ctr_updated = 1;
        }
 
        ret = 0;
-       if (!is_register && !PageAnon(old_page))
-               goto put_old;
-
-       ret = anon_vma_prepare(vma);
-       if (ret)
-               goto put_old;
-
-       ret = -ENOMEM;
-       new_page = alloc_page_vma(GFP_HIGHUSER_MOVABLE, vma, vaddr);
-       if (!new_page)
-               goto put_old;
-
-       __SetPageUptodate(new_page);
-       copy_highpage(new_page, old_page);
-       copy_to_page(new_page, vaddr, &opcode, UPROBE_SWBP_INSN_SIZE);
+       if (unlikely(!folio_test_anon(folio))) {
+               VM_WARN_ON_ONCE(is_register);
+               folio_put(folio);
+               goto out;
+       }
 
        if (!is_register) {
-               struct page *orig_page;
-               pgoff_t index;
-
-               VM_BUG_ON_PAGE(!PageAnon(old_page), old_page);
-
-               index = vaddr_to_offset(vma, vaddr & PAGE_MASK) >> PAGE_SHIFT;
-               orig_page = find_get_page(vma->vm_file->f_inode->i_mapping,
-                                         index);
-
-               if (orig_page) {
-                       if (PageUptodate(orig_page) &&
-                           pages_identical(new_page, orig_page)) {
-                               /* let go new_page */
-                               put_page(new_page);
-                               new_page = NULL;
-
-                               if (PageCompound(orig_page))
-                                       orig_page_huge = true;
-                       }
-                       put_page(orig_page);
-               }
+               /*
+                * In the common case, we'll be able to zap the page when
+                * unregistering. So trigger MMU notifiers now, as we won't
+                * be able to do it under PTL.
+                */
+               mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, mm,
+                                       vaddr, vaddr + PAGE_SIZE);
+               mmu_notifier_invalidate_range_start(&range);
+       }
+
+       ret = -EAGAIN;
+       /* Walk the page tables again, to perform the actual update. */
+       if (folio_walk_start(&fw, vma, vaddr, 0)) {
+               if (fw.page == page)
+                       ret = __uprobe_write_opcode(vma, &fw, folio, opcode_vaddr, opcode);
+               folio_walk_end(&fw, vma);
        }
 
-       ret = __replace_page(vma, vaddr & PAGE_MASK, old_page, new_page);
-       if (new_page)
-               put_page(new_page);
-put_old:
-       put_page(old_page);
+       if (!is_register)
+               mmu_notifier_invalidate_range_end(&range);
 
-       if (unlikely(ret == -EAGAIN))
+       folio_put(folio);
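+       /*
+        * -EFAULT: a write fault is needed first (e.g., to break COW or to
+        * split a PMD mapping); -EAGAIN: the page tables changed underneath
+        * us, simply retry.
+        */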
+       switch (ret) {
+       case -EFAULT:
+               gup_flags |= FOLL_WRITE | FOLL_SPLIT_PMD;
+               fallthrough;
+       case -EAGAIN:
                goto retry;
+       default:
+               break;
+       }
 
+out:
        /* Revert back reference counter if instruction update failed. */
-       if (ret && is_register && ref_ctr_updated)
+       if (ret < 0 && is_register && ref_ctr_updated)
                update_ref_ctr(uprobe, mm, -1);
 
        /* try collapse pmd for compound page */
-       if (!ret && orig_page_huge)
+       if (ret > 0)
                collapse_pte_mapped_thp(mm, vaddr, false);
 
-       return ret;
+       return ret < 0 ? ret : 0;
 }
 
 /**