#include <linux/sched/debug.h>
 #include <linux/kernel.h>
 #include <linux/mm.h>
+#include <linux/pkeys.h>
 #include <linux/stddef.h>
 #include <linux/unistd.h>
 #include <linux/ptrace.h>
        info->si_addr = (void __user *)regs->nip;
 }
 
-void _exception(int signr, struct pt_regs *regs, int code, unsigned long addr)
+
+void _exception_pkey(int signr, struct pt_regs *regs, int code,
+               unsigned long addr, int key)
 {
        siginfo_t info;
        const char fmt32[] = KERN_INFO "%s[%d]: unhandled signal %d " \
        info.si_signo = signr;
        info.si_code = code;
        info.si_addr = (void __user *) addr;
+       info.si_pkey = key;
+
        force_sig_info(signr, &info, current);
 }
 
+void _exception(int signr, struct pt_regs *regs, int code, unsigned long addr)
+{
+       _exception_pkey(signr, regs, code, addr, 0);
+}
+
 void system_reset_exception(struct pt_regs *regs)
 {
        /*
 
  */
 
 static int
-__bad_area_nosemaphore(struct pt_regs *regs, unsigned long address, int si_code)
+__bad_area_nosemaphore(struct pt_regs *regs, unsigned long address, int si_code,
+               int pkey)
 {
        /*
         * If we are in kernel mode, bail out with a SEGV, this will
        if (!user_mode(regs))
                return SIGSEGV;
 
-       _exception(SIGSEGV, regs, si_code, address);
+       _exception_pkey(SIGSEGV, regs, si_code, address, pkey);
 
        return 0;
 }
 
 static noinline int bad_area_nosemaphore(struct pt_regs *regs, unsigned long address)
 {
-       return __bad_area_nosemaphore(regs, address, SEGV_MAPERR);
+       return __bad_area_nosemaphore(regs, address, SEGV_MAPERR, 0);
 }
 
-static int __bad_area(struct pt_regs *regs, unsigned long address, int si_code)
+static int __bad_area(struct pt_regs *regs, unsigned long address, int si_code,
+                       int pkey)
 {
        struct mm_struct *mm = current->mm;
 
         */
        up_read(&mm->mmap_sem);
 
-       return __bad_area_nosemaphore(regs, address, si_code);
+       return __bad_area_nosemaphore(regs, address, si_code, pkey);
 }
 
 static noinline int bad_area(struct pt_regs *regs, unsigned long address)
 {
-       return __bad_area(regs, address, SEGV_MAPERR);
+       return __bad_area(regs, address, SEGV_MAPERR, 0);
+}
+
+static int bad_key_fault_exception(struct pt_regs *regs, unsigned long address,
+                                   int pkey)
+{
+       return __bad_area_nosemaphore(regs, address, SEGV_PKUERR, pkey);
 }
 
 static int do_sigbus(struct pt_regs *regs, unsigned long address,
 
        perf_sw_event(PERF_COUNT_SW_PAGE_FAULTS, 1, regs, address);
 
-       if (error_code & DSISR_KEYFAULT) {
-               _exception(SIGSEGV, regs, SEGV_PKUERR, address);
-               return 0;
-       }
+       if (error_code & DSISR_KEYFAULT)
+               return bad_key_fault_exception(regs, address,
+                                              get_mm_addr_key(mm, address));
 
        /*
         * We want to do this outside mmap_sem, because reading code around nip
        if (unlikely(fault & VM_FAULT_SIGSEGV) &&
                !arch_vma_access_permitted(vma, flags & FAULT_FLAG_WRITE,
                        is_exec, 0)) {
+               /*
+                * The PGD-PDT...PMD-PTE tree may not have been fully setup.
+                * Hence we cannot walk the tree to locate the PTE, to locate
+                * the key. Hence let's use vma_pkey() to get the key; instead
+                * of get_mm_addr_key().
+                */
                int pkey = vma_pkey(vma);
 
-               if (likely(pkey))
-                       return __bad_area(regs, address, SEGV_PKUERR);
+               if (likely(pkey)) {
+                       up_read(&mm->mmap_sem);
+                       return bad_key_fault_exception(regs, address, pkey);
+               }
        }
 #endif /* CONFIG_PPC_MEM_KEYS */