#include <assert.h>
 #include <errno.h>
 #include <fcntl.h>
+#include <stdbool.h>
 #include <stddef.h>
 #include <stdio.h>
 #include <stdlib.h>
        int max_vl;
 };
 
+#define VEC_SVE 0
+#define VEC_SME 1
 
 static struct vec_data vec_data[] = {
-       {
+       [VEC_SVE] = {
                .name = "SVE",
                .hwcap_type = AT_HWCAP,
                .hwcap = HWCAP_SVE,
                .prctl_set = PR_SVE_SET_VL,
                .default_vl_file = "/proc/sys/abi/sve_default_vector_length",
        },
-       {
+       [VEC_SME] = {
                .name = "SME",
                .hwcap_type = AT_HWCAP2,
                .hwcap = HWCAP2_SME,
        prctl_set_all_vqs,
 };
 
+static inline void smstart(void)
+{
+       asm volatile("msr S0_3_C4_C7_3, xzr");
+}
+
+static inline void smstart_sm(void)
+{
+       asm volatile("msr S0_3_C4_C3_3, xzr");
+}
+
+static inline void smstop(void)
+{
+       asm volatile("msr S0_3_C4_C6_3, xzr");
+}
+
+
+/*
+ * Verify we can change the SVE vector length while SME is active and
+ * continue to use SME afterwards.
+ */
+static void change_sve_with_za(void)
+{
+       struct vec_data *sve_data = &vec_data[VEC_SVE];
+       bool pass = true;
+       int ret, i;
+
+       if (sve_data->min_vl == sve_data->max_vl) {
+               ksft_print_msg("Only one SVE VL supported, can't change\n");
+               ksft_test_result_skip("change_sve_while_sme\n");
+               return;
+       }
+
+       /* Ensure we will trigger a change when we set the maximum */
+       ret = prctl(sve_data->prctl_set, sve_data->min_vl);
+       if (ret != sve_data->min_vl) {
+               ksft_print_msg("Failed to set SVE VL %d: %d\n",
+                              sve_data->min_vl, ret);
+               pass = false;
+       }
+
+       /* Enable SM and ZA */
+       smstart();
+
+       /* Trigger another VL change */
+       ret = prctl(sve_data->prctl_set, sve_data->max_vl);
+       if (ret != sve_data->max_vl) {
+               ksft_print_msg("Failed to set SVE VL %d: %d\n",
+                              sve_data->max_vl, ret);
+               pass = false;
+       }
+
+       /*
+        * Spin for a bit with SM enabled to try to trigger another
+        * save/restore.  We can't use syscalls without exiting
+        * streaming mode.
+        */
+       for (i = 0; i < 100000000; i++)
+               smstart_sm();
+
+       /*
+        * TODO: Verify that ZA was preserved over the VL change and
+        * spin.
+        */
+
+       /* Clean up after ourselves */
+       smstop();
+       ret = prctl(sve_data->prctl_set, sve_data->default_vl);
+       if (ret != sve_data->default_vl) {
+               ksft_print_msg("Failed to restore SVE VL %d: %d\n",
+                              sve_data->default_vl, ret);
+               pass = false;
+       }
+
+       ksft_test_result(pass, "change_sve_with_za\n");
+}
+
+typedef void (*test_all_type)(void);
+
+static const struct {
+       const char *name;
+       test_all_type test;
+}  all_types_tests[] = {
+       { "change_sve_with_za", change_sve_with_za },
+};
+
 int main(void)
 {
+       bool all_supported = true;
        int i, j;
 
        ksft_print_header();
-       ksft_set_plan(ARRAY_SIZE(tests) * ARRAY_SIZE(vec_data));
+       ksft_set_plan(ARRAY_SIZE(tests) * ARRAY_SIZE(vec_data) +
+                     ARRAY_SIZE(all_types_tests));
 
        for (i = 0; i < ARRAY_SIZE(vec_data); i++) {
                struct vec_data *data = &vec_data[i];
                unsigned long supported;
 
                supported = getauxval(data->hwcap_type) & data->hwcap;
+               if (!supported)
+                       all_supported = false;
 
                for (j = 0; j < ARRAY_SIZE(tests); j++) {
                        if (supported)
                }
        }
 
+       for (i = 0; i < ARRAY_SIZE(all_types_tests); i++) {
+               if (all_supported)
+                       all_types_tests[i].test();
+               else
+                       ksft_test_result_skip("%s\n", all_types_tests[i].name);
+       }
+
        ksft_exit_pass();
 }