#include <linux/uprobes.h>
 #include <linux/livepatch.h>
 #include <linux/syscalls.h>
+#include <linux/uaccess.h>
 
 #include <asm/desc.h>
 #include <asm/traps.h>
 #include <asm/vdso.h>
-#include <linux/uaccess.h>
 #include <asm/cpufeature.h>
+#include <asm/fpu/api.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/syscalls.h>
        if (unlikely(cached_flags & EXIT_TO_USERMODE_LOOP_FLAGS))
                exit_to_usermode_loop(regs, cached_flags);
 
+       /* Reload ti->flags; we may have rescheduled above. */
+       cached_flags = READ_ONCE(ti->flags);
+
+       fpregs_assert_state_consistent();
+       if (unlikely(cached_flags & _TIF_NEED_FPU_LOAD))
+               switch_fpu_return();
+
 #ifdef CONFIG_COMPAT
        /*
         * Compat syscalls set TS_COMPAT.  Make sure we clear it before
 
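[ Illustrative aside, not part of the patch: the exit-path check above relies on a simple invariant. Either TIF_NEED_FPU_LOAD is set and the in-memory fpu->state is the authoritative copy, or it is clear and the FPU registers on this CPU hold current's state. A minimal sketch of that invariant, using fpregs_state_valid() from fpu/internal.h; example_fpregs_live() is a hypothetical helper that exists only for illustration. ]

static bool example_fpregs_live(void)
{
        /*
         * TIF_NEED_FPU_LOAD clear: the registers on this CPU belong to
         * current.  TIF_NEED_FPU_LOAD set: fpu->state in memory is
         * authoritative and the registers may contain anything.
         */
        return !test_thread_flag(TIF_NEED_FPU_LOAD) &&
               fpregs_state_valid(&current->thread.fpu, smp_processor_id());
}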
 
 #ifndef _ASM_X86_FPU_API_H
 #define _ASM_X86_FPU_API_H
-#include <linux/preempt.h>
+#include <linux/bottom_half.h>
 
 /*
  * Use kernel_fpu_begin/end() if you intend to use FPU in kernel context. It
 extern void kernel_fpu_begin(void);
 extern void kernel_fpu_end(void);
 extern bool irq_fpu_usable(void);
+extern void fpregs_mark_activate(void);
 
+/*
+ * Use fpregs_lock() while editing the CPU's FPU registers or fpu->state.
+ * A context switch will (and a softirq might) save the CPU's FPU registers
+ * to fpu->state and set TIF_NEED_FPU_LOAD, leaving the CPU's FPU registers
+ * in a random state.
+ */
 static inline void fpregs_lock(void)
 {
        preempt_disable();
+       local_bh_disable();
 }
 
 static inline void fpregs_unlock(void)
 {
+       local_bh_enable();
        preempt_enable();
 }
 
+#ifdef CONFIG_X86_DEBUG_FPU
+extern void fpregs_assert_state_consistent(void);
+#else
+static inline void fpregs_assert_state_consistent(void) { }
+#endif
+
+/*
+ * Load the task FPU state before returning to userspace.
+ */
+extern void switch_fpu_return(void);
+
 /*
  * Query the presence of one or more xfeatures. Works on any legacy CPU as well.
  *
 
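[ Illustrative aside, not part of the patch: the fpregs_lock() rule above in use. Code that edits current's fpu->state (or the live FPU registers) brackets the access so that neither a context switch nor a softirq can save the registers into fpu->state underneath it. example_update_xstate() is a hypothetical caller; copy_fpregs_to_fpstate() comes from fpu/internal.h. ]

static void example_update_xstate(void)
{
        struct fpu *fpu = &current->thread.fpu;

        fpregs_lock();
        if (!test_thread_flag(TIF_NEED_FPU_LOAD)) {
                /*
                 * The registers are live: snapshot them and flag the
                 * pending reload.  The FNSAVE return value can be ignored
                 * because the registers are reloaded from memory on the
                 * way back to userspace anyway.
                 */
                copy_fpregs_to_fpstate(fpu);
                set_thread_flag(TIF_NEED_FPU_LOAD);
        }
        /* ... edit fpu->state here; it is loaded again on exit to user ... */
        fpregs_unlock();
}
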
 extern void fpu__save(struct fpu *fpu);
 extern int  fpu__restore_sig(void __user *buf, int ia32_frame);
 extern void fpu__drop(struct fpu *fpu);
-extern int  fpu__copy(struct fpu *dst_fpu, struct fpu *src_fpu);
+extern int  fpu__copy(struct task_struct *dst, struct task_struct *src);
 extern void fpu__clear(struct fpu *fpu);
 extern int  fpu__exception_code(struct fpu *fpu, int trap_nr);
 extern int  dump_fpu(struct pt_regs *ptregs, struct user_i387_struct *fpstate);
 /*
  * Internal helper, do not use directly. Use switch_fpu_return() instead.
  */
-static inline void __fpregs_load_activate(struct fpu *fpu, int cpu)
+static inline void __fpregs_load_activate(void)
 {
+       struct fpu *fpu = &current->thread.fpu;
+       int cpu = smp_processor_id();
+
+       if (WARN_ON_ONCE(current->mm == NULL))
+               return;
+
        if (!fpregs_state_valid(fpu, cpu)) {
-               if (current->mm)
-                       copy_kernel_to_fpregs(&fpu->state);
+               copy_kernel_to_fpregs(&fpu->state);
                fpregs_activate(fpu);
+               fpu->last_cpu = cpu;
        }
+       clear_thread_flag(TIF_NEED_FPU_LOAD);
 }
 
 /*
  *  - switch_fpu_prepare() saves the old state.
  *    This is done within the context of the old process.
  *
- *  - switch_fpu_finish() restores the new state as
- *    necessary.
+ *  - switch_fpu_finish() sets TIF_NEED_FPU_LOAD; the floating point state
+ *    will get loaded on return to userspace, or when the kernel needs it.
  *
  * If TIF_NEED_FPU_LOAD is cleared then the CPU's FPU registers
  * are saved in the current thread's FPU register state.
  */
 
 /*
- * Set up the userspace FPU context for the new task, if the task
- * has used the FPU.
+ * Load PKRU from the FPU context if available. Delay loading of the
+ * complete FPU state until the return to userland.
  */
-static inline void switch_fpu_finish(struct fpu *new_fpu, int cpu)
+static inline void switch_fpu_finish(struct fpu *new_fpu)
 {
        u32 pkru_val = init_pkru_value;
        struct pkru_state *pk;
        if (!static_cpu_has(X86_FEATURE_FPU))
                return;
 
-       __fpregs_load_activate(new_fpu, cpu);
+       set_thread_flag(TIF_NEED_FPU_LOAD);
 
        if (!cpu_feature_enabled(X86_FEATURE_OSPKE))
                return;
 
 
        TP_STRUCT__entry(
                __field(struct fpu *, fpu)
+               __field(bool, load_fpu)
                __field(u64, xfeatures)
                __field(u64, xcomp_bv)
                ),
 
        TP_fast_assign(
                __entry->fpu            = fpu;
+               __entry->load_fpu       = test_thread_flag(TIF_NEED_FPU_LOAD);
                if (boot_cpu_has(X86_FEATURE_OSXSAVE)) {
                        __entry->xfeatures = fpu->state.xsave.header.xfeatures;
                        __entry->xcomp_bv  = fpu->state.xsave.header.xcomp_bv;
                }
        ),
-       TP_printk("x86/fpu: %p xfeatures: %llx xcomp_bv: %llx",
+       TP_printk("x86/fpu: %p load: %d xfeatures: %llx xcomp_bv: %llx",
                        __entry->fpu,
+                       __entry->load_fpu,
                        __entry->xfeatures,
                        __entry->xcomp_bv
        )
        TP_ARGS(fpu)
 );
 
-DEFINE_EVENT(x86_fpu, x86_fpu_activate_state,
-       TP_PROTO(struct fpu *fpu),
-       TP_ARGS(fpu)
-);
-
 DEFINE_EVENT(x86_fpu, x86_fpu_init_state,
        TP_PROTO(struct fpu *fpu),
        TP_ARGS(fpu)
 
        kernel_fpu_disable();
 
        if (current->mm) {
-               /*
-                * Ignore return value -- we don't care if reg state
-                * is clobbered.
-                */
-               copy_fpregs_to_fpstate(fpu);
-       } else {
-               __cpu_invalidate_fpregs_state();
+               if (!test_thread_flag(TIF_NEED_FPU_LOAD)) {
+                       set_thread_flag(TIF_NEED_FPU_LOAD);
+                       /*
+                        * Ignore return value -- we don't care if reg state
+                        * is clobbered.
+                        */
+                       copy_fpregs_to_fpstate(fpu);
+               }
        }
+       __cpu_invalidate_fpregs_state();
 }
 
 static void __kernel_fpu_end(void)
 {
-       struct fpu *fpu = &current->thread.fpu;
-
-       if (current->mm)
-               copy_kernel_to_fpregs(&fpu->state);
-
        kernel_fpu_enable();
 }
 
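[ Illustrative aside, not part of the patch: a typical in-kernel FPU user, to show what the __kernel_fpu_begin()/__kernel_fpu_end() change means for callers. kernel_fpu_end() no longer restores the task's registers; kernel_fpu_begin() saved them into fpu->state and set TIF_NEED_FPU_LOAD, so they are reloaded lazily on the next return to userspace. example_simd_clear() is hypothetical. ]

static void example_simd_clear(void *dst, size_t len)
{
        if (!irq_fpu_usable()) {
                memset(dst, 0, len);    /* plain fallback without the FPU */
                return;
        }

        kernel_fpu_begin();
        /* ... SSE/AVX accelerated clearing of dst ... */
        kernel_fpu_end();
}
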
 {
        WARN_ON_FPU(fpu != &current->thread.fpu);
 
-       preempt_disable();
+       fpregs_lock();
        trace_x86_fpu_before_save(fpu);
 
-       if (!copy_fpregs_to_fpstate(fpu))
-               copy_kernel_to_fpregs(&fpu->state);
+       if (!test_thread_flag(TIF_NEED_FPU_LOAD)) {
+               if (!copy_fpregs_to_fpstate(fpu)) {
+                       copy_kernel_to_fpregs(&fpu->state);
+               }
+       }
 
        trace_x86_fpu_after_save(fpu);
-       preempt_enable();
+       fpregs_unlock();
 }
 EXPORT_SYMBOL_GPL(fpu__save);
 
 }
 EXPORT_SYMBOL_GPL(fpstate_init);
 
-int fpu__copy(struct fpu *dst_fpu, struct fpu *src_fpu)
+int fpu__copy(struct task_struct *dst, struct task_struct *src)
 {
+       struct fpu *dst_fpu = &dst->thread.fpu;
+       struct fpu *src_fpu = &src->thread.fpu;
+
        dst_fpu->last_cpu = -1;
 
        if (!static_cpu_has(X86_FEATURE_FPU))
        memset(&dst_fpu->state.xsave, 0, fpu_kernel_xstate_size);
 
        /*
-        * Save current FPU registers directly into the child
-        * FPU context, without any memory-to-memory copying.
+        * If the FPU registers are not current, just memcpy() the state.
+        * Otherwise save current FPU registers directly into the child's FPU
+        * context, without any memory-to-memory copying.
         *
         * ( The function 'fails' in the FNSAVE case, which destroys
-        *   register contents so we have to copy them back. )
+        *   register contents so we have to load them back. )
         */
-       if (!copy_fpregs_to_fpstate(dst_fpu)) {
-               memcpy(&src_fpu->state, &dst_fpu->state, fpu_kernel_xstate_size);
-               copy_kernel_to_fpregs(&src_fpu->state);
-       }
+       fpregs_lock();
+       if (test_thread_flag(TIF_NEED_FPU_LOAD))
+               memcpy(&dst_fpu->state, &src_fpu->state, fpu_kernel_xstate_size);
+
+       else if (!copy_fpregs_to_fpstate(dst_fpu))
+               copy_kernel_to_fpregs(&dst_fpu->state);
+
+       fpregs_unlock();
+
+       set_tsk_thread_flag(dst, TIF_NEED_FPU_LOAD);
 
        trace_x86_fpu_copy_src(src_fpu);
        trace_x86_fpu_copy_dst(dst_fpu);
 {
        WARN_ON_FPU(fpu != &current->thread.fpu);
 
+       set_thread_flag(TIF_NEED_FPU_LOAD);
        fpstate_init(&fpu->state);
        trace_x86_fpu_init_state(fpu);
-
-       trace_x86_fpu_activate_state(fpu);
 }
 
 /*
  */
 static inline void copy_init_fpstate_to_fpregs(void)
 {
+       fpregs_lock();
+
        if (use_xsave())
                copy_kernel_to_xregs(&init_fpstate.xsave, -1);
        else if (static_cpu_has(X86_FEATURE_FXSR))
 
        if (boot_cpu_has(X86_FEATURE_OSPKE))
                copy_init_pkru_to_fpregs();
+
+       fpregs_mark_activate();
+       fpregs_unlock();
 }
 
 /*
                copy_init_fpstate_to_fpregs();
 }
 
+/*
+ * Load FPU context before returning to userspace.
+ */
+void switch_fpu_return(void)
+{
+       if (!static_cpu_has(X86_FEATURE_FPU))
+               return;
+
+       __fpregs_load_activate();
+}
+EXPORT_SYMBOL_GPL(switch_fpu_return);
+
+#ifdef CONFIG_X86_DEBUG_FPU
+/*
+ * If the current FPU state according to its tracking (the FPU context loaded
+ * on this CPU) is not valid, then TIF_NEED_FPU_LOAD must be set so the
+ * context is loaded on return to userland.
+ */
+void fpregs_assert_state_consistent(void)
+{
+       struct fpu *fpu = &current->thread.fpu;
+
+       if (test_thread_flag(TIF_NEED_FPU_LOAD))
+               return;
+
+       WARN_ON_FPU(!fpregs_state_valid(fpu, smp_processor_id()));
+}
+EXPORT_SYMBOL_GPL(fpregs_assert_state_consistent);
+#endif
+
+void fpregs_mark_activate(void)
+{
+       struct fpu *fpu = &current->thread.fpu;
+
+       fpregs_activate(fpu);
+       fpu->last_cpu = smp_processor_id();
+       clear_thread_flag(TIF_NEED_FPU_LOAD);
+}
+EXPORT_SYMBOL_GPL(fpregs_mark_activate);
+
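[ Illustrative aside, not part of the patch: the pattern that users of fpregs_mark_activate() follow, mirroring copy_init_fpstate_to_fpregs() above and the KVM hunks below. The registers are (re)loaded under fpregs_lock() and then marked live for current on this CPU, which clears TIF_NEED_FPU_LOAD. example_load_saved_state() is hypothetical. ]

static void example_load_saved_state(struct fpu *fpu)
{
        fpregs_lock();
        copy_kernel_to_fpregs(&fpu->state);     /* registers <- memory image */
        fpregs_mark_activate();                 /* registers are live again */
        fpregs_unlock();
}
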
 /*
  * x87 math exception handling:
  */
 
        struct task_struct *tsk = current;
        struct fpu *fpu = &tsk->thread.fpu;
        struct user_i387_ia32_struct env;
-       union fpregs_state *state;
        u64 xfeatures = 0;
        int fx_only = 0;
        int ret = 0;
-       void *tmp;
 
        ia32_fxstate &= (IS_ENABLED(CONFIG_X86_32) ||
                         IS_ENABLED(CONFIG_IA32_EMULATION));
                }
        }
 
-       tmp = kzalloc(sizeof(*state) + fpu_kernel_xstate_size + 64, GFP_KERNEL);
-       if (!tmp)
-               return -ENOMEM;
-       state = PTR_ALIGN(tmp, 64);
+       /*
+        * The current state of the FPU registers does not matter. By setting
+        * TIF_NEED_FPU_LOAD unconditionally it is ensured that our xstate
+        * is not modified on context switch and that the xstate is considered
+        * to be loaded again on return to userland (overriding last_cpu avoids
+        * the optimisation).
+        */
+       set_thread_flag(TIF_NEED_FPU_LOAD);
+       __fpu_invalidate_fpregs_state(fpu);
 
        if ((unsigned long)buf_fx % 64)
                fx_only = 1;
-
        /*
         * For 32-bit frames with fxstate, copy the fxstate so it can be
         * reconstructed later.
                u64 init_bv = xfeatures_mask & ~xfeatures;
 
                if (using_compacted_format()) {
-                       ret = copy_user_to_xstate(&state->xsave, buf_fx);
+                       ret = copy_user_to_xstate(&fpu->state.xsave, buf_fx);
                } else {
-                       ret = __copy_from_user(&state->xsave, buf_fx, state_size);
+                       ret = __copy_from_user(&fpu->state.xsave, buf_fx, state_size);
 
                        if (!ret && state_size > offsetof(struct xregs_state, header))
-                               ret = validate_xstate_header(&state->xsave.header);
+                               ret = validate_xstate_header(&fpu->state.xsave.header);
                }
                if (ret)
                        goto err_out;
 
-               sanitize_restored_xstate(state, envp, xfeatures, fx_only);
+               sanitize_restored_xstate(&fpu->state, envp, xfeatures, fx_only);
 
+               fpregs_lock();
                if (unlikely(init_bv))
                        copy_kernel_to_xregs(&init_fpstate.xsave, init_bv);
-               ret = copy_kernel_to_xregs_err(&state->xsave, xfeatures);
+               ret = copy_kernel_to_xregs_err(&fpu->state.xsave, xfeatures);
 
        } else if (use_fxsr()) {
-               ret = __copy_from_user(&state->fxsave, buf_fx, state_size);
-               if (ret)
+               ret = __copy_from_user(&fpu->state.fxsave, buf_fx, state_size);
+               if (ret) {
+                       ret = -EFAULT;
                        goto err_out;
+               }
+
+               sanitize_restored_xstate(&fpu->state, envp, xfeatures, fx_only);
 
-               sanitize_restored_xstate(state, envp, xfeatures, fx_only);
+               fpregs_lock();
                if (use_xsave()) {
                        u64 init_bv = xfeatures_mask & ~XFEATURE_MASK_FPSSE;
                        copy_kernel_to_xregs(&init_fpstate.xsave, init_bv);
                }
 
-               ret = copy_kernel_to_fxregs_err(&state->fxsave);
+               ret = copy_kernel_to_fxregs_err(&fpu->state.fxsave);
        } else {
-               ret = __copy_from_user(&state->fsave, buf_fx, state_size);
+               ret = __copy_from_user(&fpu->state.fsave, buf_fx, state_size);
                if (ret)
                        goto err_out;
-               ret = copy_kernel_to_fregs_err(&state->fsave);
+
+               fpregs_lock();
+               ret = copy_kernel_to_fregs_err(&fpu->state.fsave);
        }
+       if (!ret)
+               fpregs_mark_activate();
+       fpregs_unlock();
 
 err_out:
-       kfree(tmp);
        if (ret)
                fpu__clear(fpu);
        return ret;
 
        dst->thread.vm86 = NULL;
 #endif
 
-       return fpu__copy(&dst->thread.fpu, &src->thread.fpu);
+       return fpu__copy(dst, src);
 }
 
 /*
 
 
        /* never put a printk in __switch_to... printk() calls wake_up*() indirectly */
 
-       switch_fpu_prepare(prev_fpu, cpu);
+       if (!test_thread_flag(TIF_NEED_FPU_LOAD))
+               switch_fpu_prepare(prev_fpu, cpu);
 
        /*
         * Save away %gs. No need to save %fs, as it was saved on the
 
        this_cpu_write(current_task, next_p);
 
-       switch_fpu_finish(next_fpu, cpu);
+       switch_fpu_finish(next_fpu);
 
        /* Load the Intel cache allocation PQR MSR. */
        resctrl_sched_in();
 
        WARN_ON_ONCE(IS_ENABLED(CONFIG_DEBUG_ENTRY) &&
                     this_cpu_read(irq_count) != -1);
 
-       switch_fpu_prepare(prev_fpu, cpu);
+       if (!test_thread_flag(TIF_NEED_FPU_LOAD))
+               switch_fpu_prepare(prev_fpu, cpu);
 
        /* We must save %fs and %gs before load_TLS() because
         * %fs and %gs may be cleared by load_TLS().
        this_cpu_write(current_task, next_p);
        this_cpu_write(cpu_current_top_of_stack, task_top_of_stack(next_p));
 
-       switch_fpu_finish(next_fpu, cpu);
+       switch_fpu_finish(next_fpu);
 
        /* Reload sp0. */
        update_task_stack(next_p);
 
                wait_lapic_expire(vcpu);
        guest_enter_irqoff();
 
+       fpregs_assert_state_consistent();
+       if (test_thread_flag(TIF_NEED_FPU_LOAD))
+               switch_fpu_return();
+
        if (unlikely(vcpu->arch.switch_db_regs)) {
                set_debugreg(0, 7);
                set_debugreg(vcpu->arch.eff_db[0], 0);
 /* Swap (qemu) user FPU context for the guest FPU context. */
 static void kvm_load_guest_fpu(struct kvm_vcpu *vcpu)
 {
-       preempt_disable();
+       fpregs_lock();
+
        copy_fpregs_to_fpstate(&current->thread.fpu);
        /* PKRU is separately restored in kvm_x86_ops->run.  */
        __copy_kernel_to_fpregs(&vcpu->arch.guest_fpu->state,
                                ~XFEATURE_MASK_PKRU);
-       preempt_enable();
+
+       fpregs_mark_activate();
+       fpregs_unlock();
+
        trace_kvm_fpu(1);
 }
 
 /* When vcpu_run ends, restore user space FPU context. */
 static void kvm_put_guest_fpu(struct kvm_vcpu *vcpu)
 {
-       preempt_disable();
+       fpregs_lock();
+
        copy_fpregs_to_fpstate(vcpu->arch.guest_fpu);
        copy_kernel_to_fpregs(&current->thread.fpu.state);
-       preempt_enable();
+
+       fpregs_mark_activate();
+       fpregs_unlock();
+
        ++vcpu->stat.fpu_reload;
        trace_kvm_fpu(0);
 }