From: Matthew Wilcox (Oracle) <willy@infradead.org>
Date: Tue, 11 Aug 2020 20:14:43 +0000 (-0400)
Subject: mshare: Basic page table sharing support
X-Git-Url: https://www.infradead.org/git/?a=commitdiff_plain;h=12b7669e02d0717bd3d269c32f58f25a64d26646;p=users%2Fwilly%2Flinux.git

mshare: Basic page table sharing support

There are many bugs with this; in particular, the kernel will hit
a VM_BUG_ON_PAGE if a page table is shared, because its refcount will
be decremented to 0.  Also, we don't currently reparent VMAs to the
newly created MM, and the refcount on the MM isn't maintained.

Signed-off-by: Matthew Wilcox (Oracle) <willy@infradead.org>
---

diff --git a/include/linux/mm.h b/include/linux/mm.h
index dc7b87310c10..582a6040736c 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -289,12 +289,18 @@ extern unsigned int kobjsize(const void *objp);
 #define VM_NOHUGEPAGE	0x40000000	/* MADV_NOHUGEPAGE marked this vma */
 #define VM_MERGEABLE	0x80000000	/* KSM may merge identical pages */
 
+#ifdef CONFIG_64BIT
+#define VM_SHARED_PT	(1UL << 32)	/* vma's page tables are shared (mshare) */
+#else
+#define VM_SHARED_PT	0		/* bit 32 unavailable: flag is never set */
+#endif
+
 #ifdef CONFIG_ARCH_USES_HIGH_VMA_FLAGS
-#define VM_HIGH_ARCH_BIT_0	32	/* bit only usable on 64-bit architectures */
-#define VM_HIGH_ARCH_BIT_1	33	/* bit only usable on 64-bit architectures */
-#define VM_HIGH_ARCH_BIT_2	34	/* bit only usable on 64-bit architectures */
-#define VM_HIGH_ARCH_BIT_3	35	/* bit only usable on 64-bit architectures */
-#define VM_HIGH_ARCH_BIT_4	36	/* bit only usable on 64-bit architectures */
+#define VM_HIGH_ARCH_BIT_0	33	/* bit only usable on 64-bit architectures */
+#define VM_HIGH_ARCH_BIT_1	34	/* bit only usable on 64-bit architectures */
+#define VM_HIGH_ARCH_BIT_2	35	/* bit only usable on 64-bit architectures */
+#define VM_HIGH_ARCH_BIT_3	36	/* bit only usable on 64-bit architectures */
+#define VM_HIGH_ARCH_BIT_4	37	/* bit only usable on 64-bit architectures */
 #define VM_HIGH_ARCH_0	BIT(VM_HIGH_ARCH_BIT_0)
 #define VM_HIGH_ARCH_1	BIT(VM_HIGH_ARCH_BIT_1)
 #define VM_HIGH_ARCH_2	BIT(VM_HIGH_ARCH_BIT_2)
diff --git a/include/trace/events/mmflags.h b/include/trace/events/mmflags.h
index 5fb752034386..9cc30b79f9bd 100644
--- a/include/trace/events/mmflags.h
+++ b/include/trace/events/mmflags.h
@@ -162,7 +162,8 @@ IF_HAVE_VM_SOFTDIRTY(VM_SOFTDIRTY,	"softdirty"	)		\
 	{VM_MIXEDMAP,			"mixedmap"	},		\
 	{VM_HUGEPAGE,			"hugepage"	},		\
 	{VM_NOHUGEPAGE,			"nohugepage"	},		\
-	{VM_MERGEABLE,			"mergeable"	}		\
+	{VM_MERGEABLE,			"mergeable"	},		\
+	{VM_SHARED_PT,			"sharedpt"	}
 
 #define show_vma_flags(flags)						\
 	(flags) ? __print_flags(flags, "|",				\
diff --git a/mm/internal.h b/mm/internal.h
index 9886db20d94f..85611d81183f 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -613,4 +613,11 @@ static inline bool is_migrate_highatomic_page(struct page *page)
 
 void setup_zone_pageset(struct zone *zone);
 extern struct page *alloc_new_node_page(struct page *page, unsigned long node);
+
+/* Helpers for vmas whose page tables are shared (VM_SHARED_PT). */
+vm_fault_t find_shared_vma(struct vm_area_struct **vmap, unsigned long *addrp);
+static inline bool vma_is_shared(const struct vm_area_struct *vma)
+{
+	return (vma->vm_flags & VM_SHARED_PT) != 0;
+}
 #endif	/* __MM_INTERNAL_H */
diff --git a/mm/memory.c b/mm/memory.c
index 3ecad55103ad..e6c5af86bb0d 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4367,6 +4367,7 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 		unsigned int flags)
 {
 	vm_fault_t ret;
+	bool shared = false;
 
 	__set_current_state(TASK_RUNNING);
 
@@ -4376,6 +4377,15 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 	/* do counter updates before entering really critical section. */
 	check_sync_rss_stat(current);
 
+	if (unlikely(vma_is_shared(vma))) {
+		ret = find_shared_vma(&vma, &address);
+		if (ret)
+			return ret;
+		if (!vma)
+			return VM_FAULT_SIGSEGV;
+		shared = true;
+	}
+
 	if (!arch_vma_access_permitted(vma, flags & FAULT_FLAG_WRITE,
 					    flags & FAULT_FLAG_INSTRUCTION,
 					    flags & FAULT_FLAG_REMOTE))
@@ -4393,6 +4403,9 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 	else
 		ret = __handle_mm_fault(vma, address, flags);
 
+	if (shared)
+		mmap_read_unlock(vma->vm_mm);
+
 	if (flags & FAULT_FLAG_USER) {
 		mem_cgroup_exit_user_fault();
 		/*
diff --git a/mm/mshare.c b/mm/mshare.c
index 75eb0796584f..f79e6519c94a 100644
--- a/mm/mshare.c
+++ b/mm/mshare.c
@@ -2,6 +2,53 @@
 #include <linux/fs.h>
 #include <linux/sched/mm.h>
 #include <linux/syscalls.h>
+#include "internal.h"
+
+/**
+ * find_shared_vma() - Redirect a fault on a VM_SHARED_PT vma to the guest mm.
+ * @vmap: In: the faulting (host) vma.  Out: the guest vma covering the
+ *        fault, or NULL if the guest mm maps nothing there.
+ * @addrp: In: the faulting address in current->mm.  Out: that address as
+ *         an offset from the start of the host vma (the guest address).
+ *
+ * If the host PGD entry differs from the guest's, the guest entry is copied
+ * into the host page table and the function returns early; no lock is taken
+ * and *vmap and *addrp are left untouched.
+ *
+ * Otherwise the guest mm is locked for read and searched.  The lock is
+ * still held on return only when *vmap is non-NULL, and the caller must
+ * then release it; when no guest vma is found it is dropped here.
+ *
+ * Return: VM_FAULT_NOPAGE if the host PGD entry was updated, otherwise 0.
+ */
+vm_fault_t find_shared_vma(struct vm_area_struct **vmap, unsigned long *addrp)
+{
+	struct vm_area_struct *vma, *host = *vmap;
+	struct mm_struct *mm = host->vm_private_data;
+	unsigned long guest_addr = *addrp - host->vm_start;
+	pgd_t pgd = *pgd_offset(mm, guest_addr);
+	pgd_t *host_pgd = pgd_offset(current->mm, *addrp);
+
+	if (!pgd_same(*host_pgd, pgd)) {
+		set_pgd(host_pgd, pgd);
+		return VM_FAULT_NOPAGE;
+	}
+
+	mmap_read_lock(mm);
+	vma = find_vma(mm, guest_addr);
+
+	/* XXX: expand stack? */
+	if (vma && vma->vm_start > guest_addr)
+		vma = NULL;
+
+	/* Nothing mapped at guest_addr: don't leak the read lock. */
+	if (!vma)
+		mmap_read_unlock(mm);
+
+	*addrp = guest_addr;
+	*vmap = vma;
+	return 0;
+}
 
 static ssize_t mshare_read(struct kiocb *iocb, struct iov_iter *iov)
 {
@@ -17,6 +44,18 @@ static ssize_t mshare_read(struct kiocb *iocb, struct iov_iter *iov)
 	return ret;
 }
 
+/* Tag a PGD-aligned mapping with VM_SHARED_PT and the file's mm. */
+static int mshare_mmap(struct file *file, struct vm_area_struct *vma)
+{
+	if ((vma->vm_start & (PGDIR_SIZE - 1)) ||
+	    (vma->vm_end & (PGDIR_SIZE - 1)))
+		return -EINVAL;
+
+	vma->vm_private_data = file->private_data;
+	vma->vm_flags |= VM_SHARED_PT;
+	return 0;
+}
+
 static int mshare_release(struct inode *inode, struct file *file)
 {
 	struct mm_struct *mm = file->private_data;
@@ -28,6 +67,7 @@ static int mshare_release(struct inode *inode, struct file *file)
 
 static const struct file_operations mshare_fops = {
 	.read_iter = mshare_read,
+	.mmap = mshare_mmap,
 	.release = mshare_release,
 };
 
@@ -35,7 +75,9 @@ SYSCALL_DEFINE3(mshare, unsigned long, addr, unsigned long, len,
 		unsigned long, flags)
 {
 	struct mm_struct *mm;
+	struct vm_area_struct *vma;
 	int fd;
+	int i = 0;
 
 	if ((addr | len) & (PGDIR_SIZE - 1))
 		return -EINVAL;
@@ -50,7 +92,30 @@ SYSCALL_DEFINE3(mshare, unsigned long, addr, unsigned long, len,
 	if (!mm->task_size)
 		mm->task_size--;
 
-	fd = anon_inode_getfd("mshare", &mshare_fops, mm, O_RDWR);
+	mmap_write_lock(current->mm);
+
+	vma = find_vma(current->mm, addr + len);
+	if (vma && vma->vm_start < addr + len)
+		goto unlock;
+	vma = find_vma(current->mm, addr);
+	if (vma && vma->vm_start < addr)
+		goto unlock;
+
+	while (addr < mm->task_size) {
+		mm->pgd[i++] = *pgd_offset(current->mm, addr);
+		addr += PGDIR_SIZE;
+	}
+	mmap_write_unlock(current->mm);
 
+	fd = anon_inode_getfd("mshare", &mshare_fops, mm, O_RDWR);
+	if (fd < 0)
+		goto nofd;
+out:
 	return fd;
+unlock:
+	mmap_write_unlock(current->mm);
+	fd = -EINVAL;
+nofd:
+	mmput(mm);
+	goto out;
 }