* updates.  However, note that the handling of PERF_COUNT_SW_PAGE_FAULTS should
  * still be in per-arch page fault handlers at the entry of page fault.
  */
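For reference, the per-arch handling that this comment points at is typically a single perf_sw_event() call near the top of the architecture's fault handler. A minimal sketch, modeled on the x86 handler (exact placement and surroundings vary by architecture):

	/* At the entry of the arch page fault handler: */
	perf_sw_event(PERF_COUNT_SW_PAGE_FAULTS, 1, regs, address);
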
-static inline void mm_account_fault(struct pt_regs *regs,
+static inline void mm_account_fault(struct mm_struct *mm, struct pt_regs *regs,
                                    unsigned long address, unsigned int flags,
                                    vm_fault_t ret)
 {
        bool major;
 
+       /* Incomplete faults will be accounted upon completion. */
+       if (ret & VM_FAULT_RETRY)
+               return;
+
        /*
-        * We don't do accounting for some specific faults:
-        *
-        * - Unsuccessful faults (e.g. when the address wasn't valid).  That
-        *   includes arch_vma_access_permitted() failing before reaching here.
-        *   So this is not a "this many hardware page faults" counter.  We
-        *   should use the hw profiling for that.
-        *
-        * - Incomplete faults (VM_FAULT_RETRY).  They will only be counted
-        *   once they're completed.
+        * To preserve the behavior of older kernels, PGFAULT counters record
+        * both successful and failed faults, as opposed to perf counters,
+        * which ignore failed cases.
         */
-       if (ret & (VM_FAULT_ERROR | VM_FAULT_RETRY))
+       count_vm_event(PGFAULT);
+       count_memcg_event_mm(mm, PGFAULT);
+
+       /*
+        * Do not account for unsuccessful faults (e.g. when the address wasn't
+        * valid).  That includes arch_vma_access_permitted() failing before
+        * reaching here.  So this is not a "this many hardware page faults"
+        * counter.  We should use the hw profiling for that.
+        */
+       if (ret & VM_FAULT_ERROR)
                return;
 
        /*
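
Condensed, the order of checks in mm_account_fault() after this change is: bail out on retry, count PGFAULT unconditionally, then skip the remaining accounting for failed faults. A sketch of the resulting flow, with the major/minor and perf accounting that follows elided:

	static inline void mm_account_fault(struct mm_struct *mm, struct pt_regs *regs,
					    unsigned long address, unsigned int flags,
					    vm_fault_t ret)
	{
		/* Retried faults are accounted once, upon completion. */
		if (ret & VM_FAULT_RETRY)
			return;

		/* PGFAULT counts successes and failures alike. */
		count_vm_event(PGFAULT);
		count_memcg_event_mm(mm, PGFAULT);

		/* Failed faults get no major/minor or perf accounting. */
		if (ret & VM_FAULT_ERROR)
			return;

		/* ... major/minor bookkeeping and perf events follow ... */
	}
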
 vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
                           unsigned int flags, struct pt_regs *regs)
 {
+       /* If the fault handler drops the mmap_lock, vma may be freed */
+       struct mm_struct *mm = vma->vm_mm;
        vm_fault_t ret;
 
        __set_current_state(TASK_RUNNING);
 
-       count_vm_event(PGFAULT);
-       count_memcg_event_mm(vma->vm_mm, PGFAULT);
-
        ret = sanitize_fault_flags(vma, &flags);
        if (ret)
-               return ret;
+               goto out;
 
        if (!arch_vma_access_permitted(vma, flags & FAULT_FLAG_WRITE,
                                            flags & FAULT_FLAG_INSTRUCTION,
-                                           flags & FAULT_FLAG_REMOTE))
-               return VM_FAULT_SIGSEGV;
+                                           flags & FAULT_FLAG_REMOTE)) {
+               ret = VM_FAULT_SIGSEGV;
+               goto out;
+       }
 
        /*
         * Enable the memcg OOM handling for faults triggered in user
                if (task_in_memcg_oom(current) && !(ret & VM_FAULT_OOM))
                        mem_cgroup_oom_synchronize(false);
        }
-
-       mm_account_fault(regs, address, flags, ret);
+out:
+       mm_account_fault(mm, regs, address, flags, ret);
 
        return ret;
 }
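
Seen from the caller's side, the split of responsibilities after this change looks roughly as follows. This is a hypothetical, heavily simplified arch handler, not code from the tree; the calls around handle_mm_fault() are real kernel APIs, the handler itself is illustrative:

	/* Hypothetical, simplified arch fault handler (not from the tree). */
	static void example_do_page_fault(struct pt_regs *regs, unsigned long address)
	{
		struct mm_struct *mm = current->mm;
		struct vm_area_struct *vma;
		unsigned int flags = FAULT_FLAG_DEFAULT;
		vm_fault_t fault;

		/* PERF_COUNT_SW_PAGE_FAULTS stays at the arch entry point. */
		perf_sw_event(PERF_COUNT_SW_PAGE_FAULTS, 1, regs, address);

	retry:
		mmap_read_lock(mm);
		vma = find_vma(mm, address);
		if (!vma) {
			mmap_read_unlock(mm);
			return;	/* deliver SIGSEGV; arch-specific */
		}

		/*
		 * PGFAULT, major/minor stats and the perf MAJ/MIN events are
		 * all handled inside handle_mm_fault() now, even when the
		 * fault fails early; nothing to count here.
		 */
		fault = handle_mm_fault(vma, address, flags, regs);

		if (fault & VM_FAULT_RETRY) {
			/*
			 * mmap_lock was dropped by the handler, so vma must
			 * not be touched again.  The completed retry is what
			 * gets accounted, so the fault is counted only once.
			 */
			flags |= FAULT_FLAG_TRIED;
			goto retry;
		}
		mmap_read_unlock(mm);
	}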