#ifndef __ASSEMBLY__
 #include <asm/x86_init.h>
+#include <asm/fpu/xstate.h>
+#include <asm/fpu/api.h>
 
 extern pgd_t early_top_pgt[PTRS_PER_PGD];
 int __init __early_make_pgtable(unsigned long address, pmdval_t pmd);
 
 static inline void write_pkru(u32 pkru)
 {
-       if (boot_cpu_has(X86_FEATURE_OSPKE))
-               __write_pkru(pkru);
+       struct pkru_state *pk;
+
+       if (!boot_cpu_has(X86_FEATURE_OSPKE))
+               return;
+
+       pk = get_xsave_addr(¤t->thread.fpu.state.xsave, XFEATURE_PKRU);
+
+       /*
+        * The PKRU value in xstate needs to be in sync with the value that is
+        * written to the CPU. The FPU restore on return to userland would
+        * otherwise load the previous value again.
+        */
+       fpregs_lock();
+       if (pk)
+               pk->pkru = pkru;
+       __write_pkru(pkru);
+       fpregs_unlock();
 }
 
 static inline int pte_young(pte_t pte)