return __bad_area(regs, address, SEGV_MAPERR);
 }
 
-static int bad_key_fault_exception(struct pt_regs *regs, unsigned long address,
-                                   int pkey)
+#ifdef CONFIG_PPC_MEM_KEYS
+static noinline int bad_access_pkey(struct pt_regs *regs, unsigned long address,
+                                   struct vm_area_struct *vma)
 {
+       struct mm_struct *mm = current->mm;
+       int pkey;
+
+       /*
+        * We don't try to fetch the pkey from page table because reading
+        * page table without locking doesn't guarantee stable pte value.
+        * Hence the pkey value that we return to userspace can be different
+        * from the pkey that actually caused access error.
+        *
+        * It does *not* guarantee that the VMA we find here
+        * was the one that we faulted on.
+        *
+        * 1. T1   : mprotect_key(foo, PAGE_SIZE, pkey=4);
+        * 2. T1   : set AMR to deny access to pkey=4, touches, page
+        * 3. T1   : faults...
+        * 4.    T2: mprotect_key(foo, PAGE_SIZE, pkey=5);
+        * 5. T1   : enters fault handler, takes mmap_sem, etc...
+        * 6. T1   : reaches here, sees vma_pkey(vma)=5, when we really
+        *           faulted on a pte with its pkey=4.
+        */
+       pkey = vma_pkey(vma);
+
+       up_read(&mm->mmap_sem);
+
        /*
         * If we are in kernel mode, bail out with a SEGV, this will
         * be caught by the assembly which will restore the non-volatile
 
        return 0;
 }
+#endif
 
 static noinline int bad_access(struct pt_regs *regs, unsigned long address)
 {
        return false;
 }
 
-static bool access_error(bool is_write, bool is_exec,
-                        struct vm_area_struct *vma)
+#ifdef CONFIG_PPC_MEM_KEYS
+static bool access_pkey_error(bool is_write, bool is_exec, bool is_pkey,
+                             struct vm_area_struct *vma)
+{
+       /*
+        * Read or write was blocked by protection keys.  This is
+        * always an unconditional error and can never result in
+        * a follow-up action to resolve the fault, like a COW.
+        */
+       if (is_pkey)
+               return true;
+
+       /*
+        * Make sure to check the VMA so that we do not perform
+        * faults just to hit a pkey fault as soon as we fill in a
+        * page. Only called for current mm, hence foreign == 0
+        */
+       if (!arch_vma_access_permitted(vma, is_write, is_exec, 0))
+               return true;
+
+       return false;
+}
+#endif
+
+static bool access_error(bool is_write, bool is_exec, struct vm_area_struct *vma)
 {
        /*
         * Allow execution from readable areas if the MMU does not
 
        perf_sw_event(PERF_COUNT_SW_PAGE_FAULTS, 1, regs, address);
 
-       if (error_code & DSISR_KEYFAULT)
-               return bad_key_fault_exception(regs, address,
-                                              get_mm_addr_key(mm, address));
-
        /*
         * We want to do this outside mmap_sem, because reading code around nip
         * can result in fault, which will cause a deadlock when called with
                return bad_area(regs, address);
 
 good_area:
+
+#ifdef CONFIG_PPC_MEM_KEYS
+       if (unlikely(access_pkey_error(is_write, is_exec,
+                                      (error_code & DSISR_KEYFAULT), vma)))
+               return bad_access_pkey(regs, address, vma);
+#endif /* CONFIG_PPC_MEM_KEYS */
+
        if (unlikely(access_error(is_write, is_exec, vma)))
                return bad_access(regs, address);
 
         */
        fault = handle_mm_fault(vma, address, flags);
 
-#ifdef CONFIG_PPC_MEM_KEYS
-       /*
-        * we skipped checking for access error due to key earlier.
-        * Check that using handle_mm_fault error return.
-        */
-       if (unlikely(fault & VM_FAULT_SIGSEGV) &&
-               !arch_vma_access_permitted(vma, is_write, is_exec, 0)) {
-
-               int pkey = vma_pkey(vma);
-
-               up_read(&mm->mmap_sem);
-               return bad_key_fault_exception(regs, address, pkey);
-       }
-#endif /* CONFIG_PPC_MEM_KEYS */
-
        major |= fault & VM_FAULT_MAJOR;
 
        if (fault_signal_pending(fault, regs))