WARN_ON(!system_supports_fpsimd());
        WARN_ON(!have_cpu_fpsimd_context());
 
-       /* Check if we should restore SVE first */
-       if (IS_ENABLED(CONFIG_ARM64_SVE) && test_thread_flag(TIF_SVE)) {
-               sve_set_vq(sve_vq_from_vl(task_get_sve_vl(current)) - 1);
-               restore_sve_regs = true;
-               restore_ffr = true;
+       if (system_supports_sve()) {
+               switch (current->thread.fp_type) {
+               case FP_STATE_FPSIMD:
+                       /* Stop tracking SVE for this task until next use. */
+                       if (test_and_clear_thread_flag(TIF_SVE))
+                               sve_user_disable();
+                       break;
+               case FP_STATE_SVE:
+                       if (!thread_sm_enabled(¤t->thread) &&
+                           !WARN_ON_ONCE(!test_and_set_thread_flag(TIF_SVE)))
+                               sve_user_enable();
+
+                       if (test_thread_flag(TIF_SVE))
+                               sve_set_vq(sve_vq_from_vl(task_get_sve_vl(current)) - 1);
+
+                       restore_sve_regs = true;
+                       restore_ffr = true;
+                       break;
+               default:
+                       /*
+                        * This indicates either a bug in
+                        * fpsimd_save() or memory corruption, we
+                        * should always record an explicit format
+                        * when we save. We always at least have the
+                        * memory allocated for FPSMID registers so
+                        * try that and hope for the best.
+                        */
+                       WARN_ON_ONCE(1);
+                       clear_thread_flag(TIF_SVE);
+                       break;
+               }
        }
 
        /* Restore SME, override SVE register configuration if needed */
                if (thread_za_enabled(¤t->thread))
                        za_load_state(current->thread.za_state);
 
-               if (thread_sm_enabled(¤t->thread)) {
-                       restore_sve_regs = true;
+               if (thread_sm_enabled(¤t->thread))
                        restore_ffr = system_supports_fa64();
-               }
        }
 
        if (restore_sve_regs) {