#include <linux/hw_breakpoint.h>
 #include <linux/uaccess.h>
 #include <linux/elf-randomize.h>
+#include <linux/pkeys.h>
 
 #include <asm/pgtable.h>
 #include <asm/io.h>
                t->tar = mfspr(SPRN_TAR);
        }
 #endif
+
+       thread_pkey_regs_save(t);
 }
 
 static inline void restore_sprs(struct thread_struct *old_thread,
            old_thread->tidr != new_thread->tidr)
                mtspr(SPRN_TIDR, new_thread->tidr);
 #endif
+
+       thread_pkey_regs_restore(new_thread, old_thread);
 }
 
 #ifdef CONFIG_PPC_BOOK3S_64
        current->thread.tm_tfiar = 0;
        current->thread.load_tm = 0;
 #endif /* CONFIG_PPC_TRANSACTIONAL_MEM */
+
+       thread_pkey_regs_init(¤t->thread);
 }
 EXPORT_SYMBOL(start_thread);
 
 
 bool pkey_execute_disable_supported;
 int  pkeys_total;              /* Total pkeys as per device tree */
 u32  initial_allocation_mask;  /* Bits set for reserved keys */
+u64  pkey_amr_uamor_mask;      /* Bits in AMR/UMOR not to be touched */
+u64  pkey_iamr_mask;           /* Bits in AMR not to be touched */
 
 #define AMR_BITS_PER_PKEY 2
 #define AMR_RD_BIT 0x1UL
         * programming note.
         */
        initial_allocation_mask = ~0x0;
-       for (i = 2; i < (pkeys_total - os_reserved); i++)
+
+       /* register mask is in BE format */
+       pkey_amr_uamor_mask = ~0x0ul;
+       pkey_iamr_mask = ~0x0ul;
+
+       for (i = 2; i < (pkeys_total - os_reserved); i++) {
                initial_allocation_mask &= ~(0x1 << i);
+               pkey_amr_uamor_mask &= ~(0x3ul << pkeyshift(i));
+               pkey_iamr_mask &= ~(0x1ul << pkeyshift(i));
+       }
        return 0;
 }
 
        init_amr(pkey, new_amr_bits);
        return 0;
 }
+
+void thread_pkey_regs_save(struct thread_struct *thread)
+{
+       if (static_branch_likely(&pkey_disabled))
+               return;
+
+       /*
+        * TODO: Skip saving registers if @thread hasn't used any keys yet.
+        */
+       thread->amr = read_amr();
+       thread->iamr = read_iamr();
+       thread->uamor = read_uamor();
+}
+
+void thread_pkey_regs_restore(struct thread_struct *new_thread,
+                             struct thread_struct *old_thread)
+{
+       if (static_branch_likely(&pkey_disabled))
+               return;
+
+       /*
+        * TODO: Just set UAMOR to zero if @new_thread hasn't used any keys yet.
+        */
+       if (old_thread->amr != new_thread->amr)
+               write_amr(new_thread->amr);
+       if (old_thread->iamr != new_thread->iamr)
+               write_iamr(new_thread->iamr);
+       if (old_thread->uamor != new_thread->uamor)
+               write_uamor(new_thread->uamor);
+}
+
+void thread_pkey_regs_init(struct thread_struct *thread)
+{
+       if (static_branch_likely(&pkey_disabled))
+               return;
+
+       write_amr(read_amr() & pkey_amr_uamor_mask);
+       write_iamr(read_iamr() & pkey_iamr_mask);
+       write_uamor(read_uamor() & pkey_amr_uamor_mask);
+}