if (!(vma->vm_flags & VM_WRITE))
                goto out_unlock_mmap;
 
-       ret = follow_pte(vma->vm_mm, mmio_addr, &ptep, &ptl);
+       ret = follow_pte(vma, mmio_addr, &ptep, &ptl);
        if (ret)
                goto out_unlock_mmap;
 
        if (!(vma->vm_flags & VM_WRITE))
                goto out_unlock_mmap;
 
-       ret = follow_pte(vma->vm_mm, mmio_addr, &ptep, &ptl);
+       ret = follow_pte(vma, mmio_addr, &ptep, &ptl);
        if (ret)
                goto out_unlock_mmap;
 
 
        pte_t *ptep, pte;
        spinlock_t *ptl;
 
-       if (!(vma->vm_flags & (VM_IO | VM_PFNMAP)))
-               return -EINVAL;
-
-       if (follow_pte(vma->vm_mm, vma->vm_start, &ptep, &ptl))
+       if (follow_pte(vma, vma->vm_start, &ptep, &ptl))
                return -EINVAL;
 
        pte = ptep_get(ptep);
 
        spinlock_t *ptl;
        int ret;
 
-       ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
+       ret = follow_pte(vma, vaddr, &ptep, &ptl);
        if (ret) {
                bool unlocked = false;
 
                if (ret)
                        return ret;
 
-               ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
+               ret = follow_pte(vma, vaddr, &ptep, &ptl);
                if (ret)
                        return ret;
        }
 
                }
 
                for (i = 0; i < nr_pages; i++) {
-                       ret = follow_pte(vma->vm_mm,
-                                        memmap->vma_base + i * PAGE_SIZE,
+                       ret = follow_pte(vma, memmap->vma_base + i * PAGE_SIZE,
                                         &ptep, &ptl);
                        if (ret)
                                break;
 
                unsigned long end, unsigned long floor, unsigned long ceiling);
 int
 copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma);
-int follow_pte(struct mm_struct *mm, unsigned long address,
+int follow_pte(struct vm_area_struct *vma, unsigned long address,
               pte_t **ptepp, spinlock_t **ptlp);
 int generic_access_phys(struct vm_area_struct *vma, unsigned long addr,
                        void *buf, int len, int write);
 
 
 /**
  * follow_pte - look up PTE at a user virtual address
- * @mm: the mm_struct of the target address space
+ * @vma: the memory mapping
  * @address: user virtual address
  * @ptepp: location to store found PTE
  * @ptlp: location to store the lock for the PTE
  *
  * Return: zero on success, -ve otherwise.
  */
-int follow_pte(struct mm_struct *mm, unsigned long address,
+int follow_pte(struct vm_area_struct *vma, unsigned long address,
               pte_t **ptepp, spinlock_t **ptlp)
 {
+       struct mm_struct *mm = vma->vm_mm;
        pgd_t *pgd;
        p4d_t *p4d;
        pud_t *pud;
        pmd_t *pmd;
        pte_t *ptep;
 
+       if (!(vma->vm_flags & (VM_IO | VM_PFNMAP)))
+               goto out;
+
        pgd = pgd_offset(mm, address);
        if (pgd_none(*pgd) || unlikely(pgd_bad(*pgd)))
                goto out;
        int offset = offset_in_page(addr);
        int ret = -EINVAL;
 
-       if (!(vma->vm_flags & (VM_IO | VM_PFNMAP)))
-               return -EINVAL;
-
 retry:
-       if (follow_pte(vma->vm_mm, addr, &ptep, &ptl))
+       if (follow_pte(vma, addr, &ptep, &ptl))
                return -EINVAL;
        pte = ptep_get(ptep);
        pte_unmap_unlock(ptep, ptl);
        if (!maddr)
                return -ENOMEM;
 
-       if (follow_pte(vma->vm_mm, addr, &ptep, &ptl))
+       if (follow_pte(vma, addr, &ptep, &ptl))
                goto out_unmap;
 
        if (!pte_same(pte, ptep_get(ptep))) {
 
        spinlock_t *ptl;
        int r;
 
-       r = follow_pte(vma->vm_mm, addr, &ptep, &ptl);
+       r = follow_pte(vma, addr, &ptep, &ptl);
        if (r) {
                /*
                 * get_user_pages fails for VM_IO and VM_PFNMAP vmas and does
                if (r)
                        return r;
 
-               r = follow_pte(vma->vm_mm, addr, &ptep, &ptl);
+               r = follow_pte(vma, addr, &ptep, &ptl);
                if (r)
                        return r;
        }