# Author: Brendan Higgins <brendanhiggins@google.com>
 
 import argparse
-import sys
 import os
+import re
+import sys
 import time
 
 assert sys.version_info >= (3, 7), "Python version is too old"
 
 from collections import namedtuple
 from enum import Enum, auto
-from typing import Iterable, Sequence
+from typing import Iterable, Sequence, List
 
 import kunit_json
 import kunit_kernel
                               ['jobs', 'build_dir', 'alltests',
                                'make_options'])
 KunitExecRequest = namedtuple('KunitExecRequest',
-                              ['timeout', 'build_dir', 'alltests',
-                               'filter_glob', 'kernel_args'])
+                             ['timeout', 'build_dir', 'alltests',
+                              'filter_glob', 'kernel_args', 'run_isolated'])
 KunitParseRequest = namedtuple('KunitParseRequest',
                               ['raw_output', 'build_dir', 'json'])
 KunitRequest = namedtuple('KunitRequest', ['raw_output','timeout', 'jobs',
                                           'build_dir', 'alltests', 'filter_glob',
-                                          'kernel_args', 'json', 'make_options'])
+                                          'kernel_args', 'run_isolated', 'json', 'make_options'])
 
 KernelDirectoryPath = sys.argv[0].split('tools/testing/kunit/')[0]
 
                           'built kernel successfully',
                           build_end - build_start)
 
+def _list_tests(linux: kunit_kernel.LinuxSourceTree, request: KunitExecRequest) -> List[str]:
+       args = ['kunit.action=list']
+       if request.kernel_args:
+               args.extend(request.kernel_args)
+
+       output = linux.run_kernel(args=args,
+                          timeout=None if request.alltests else request.timeout,
+                          filter_glob=request.filter_glob,
+                          build_dir=request.build_dir)
+       lines = kunit_parser.extract_tap_lines(output)
+       # Hack! Drop the dummy TAP version header that the executor prints out.
+       lines.pop()
+
+       # Filter out any extraneous non-test output that might have gotten mixed in.
+       return [l for l in lines if re.match('^[^\s.]+\.[^\s.]+$', l)]
+
+def _suites_from_test_list(tests: List[str]) -> List[str]:
+       """Extracts all the suites from an ordered list of tests."""
+       suites = []  # type: List[str]
+       for t in tests:
+               parts = t.split('.', maxsplit=2)
+               if len(parts) != 2:
+                       raise ValueError(f'internal KUnit error, test name should be of the form "<suite>.<test>", got "{t}"')
+               suite, case = parts
+               if not suites or suites[-1] != suite:
+                       suites.append(suite)
+       return suites
+
+
+
 def exec_tests(linux: kunit_kernel.LinuxSourceTree, request: KunitExecRequest,
               parse_request: KunitParseRequest) -> KunitResult:
-       kunit_parser.print_with_timestamp('Starting KUnit Kernel ...')
-       test_start = time.time()
-       run_result = linux.run_kernel(
-               args=request.kernel_args,
-               timeout=None if request.alltests else request.timeout,
-               filter_glob=request.filter_glob,
-               build_dir=request.build_dir)
-
-       result = parse_tests(parse_request, run_result)
-
-       # run_kernel() doesn't block on the kernel exiting.
-       # That only happens after we get the last line of output from `run_result`.
-       # So exec_time here actually contains parsing + execution time, which is fine.
-       test_end = time.time()
-       exec_time = test_end - test_start
+       filter_globs = [request.filter_glob]
+       if request.run_isolated:
+               tests = _list_tests(linux, request)
+               if request.run_isolated == 'test':
+                       filter_globs = tests
+               if request.run_isolated == 'suite':
+                       filter_globs = _suites_from_test_list(tests)
+                       # Apply the test-part of the user's glob, if present.
+                       if '.' in request.filter_glob:
+                               test_glob = request.filter_glob.split('.', maxsplit=2)[1]
+                               filter_globs = [g + '.'+ test_glob for g in filter_globs]
+
+       overall_status = kunit_parser.TestStatus.SUCCESS
+       exec_time = 0.0
+       for i, filter_glob in enumerate(filter_globs):
+               kunit_parser.print_with_timestamp('Starting KUnit Kernel ({}/{})...'.format(i+1, len(filter_globs)))
+
+               test_start = time.time()
+               run_result = linux.run_kernel(
+                       args=request.kernel_args,
+                       timeout=None if request.alltests else request.timeout,
+                       filter_glob=filter_glob,
+                       build_dir=request.build_dir)
+
+               result = parse_tests(parse_request, run_result)
+               # run_kernel() doesn't block on the kernel exiting.
+               # That only happens after we get the last line of output from `run_result`.
+               # So exec_time here actually contains parsing + execution time, which is fine.
+               test_end = time.time()
+               exec_time += test_end - test_start
+
+               overall_status = kunit_parser.max_status(overall_status, result.status)
 
        return KunitResult(status=result.status, result=result.result, elapsed_time=exec_time)
 
 
        exec_request = KunitExecRequest(request.timeout, request.build_dir,
                                 request.alltests, request.filter_glob,
-                                request.kernel_args)
+                                request.kernel_args, request.run_isolated)
        parse_request = KunitParseRequest(request.raw_output,
                                          request.build_dir,
                                          request.json)
        parser.add_argument('--kernel_args',
                            help='Kernel command-line parameters. Maybe be repeated',
                             action='append')
+       parser.add_argument('--run_isolated', help='If set, boot the kernel for each '
+                           'individual suite/test. This is can be useful for debugging '
+                           'a non-hermetic test, one that might pass/fail based on '
+                           'what ran before it.',
+                           type=str,
+                           choices=['suite', 'test']),
 
 def add_parse_opts(parser) -> None:
        parser.add_argument('--raw_output', help='If set don\'t format output from kernel. '
                                       cli_args.alltests,
                                       cli_args.filter_glob,
                                       cli_args.kernel_args,
+                                      cli_args.run_isolated,
                                       cli_args.json,
                                       cli_args.make_options)
                result = run_tests(linux, request)
                                                cli_args.build_dir,
                                                cli_args.alltests,
                                                cli_args.filter_glob,
-                                               cli_args.kernel_args)
+                                               cli_args.kernel_args,
+                                               cli_args.run_isolated)
                parse_request = KunitParseRequest(cli_args.raw_output,
                                                  cli_args.build_dir,
                                                  cli_args.json)
 
                      args=['a=1','b=2'], build_dir='.kunit', filter_glob='', timeout=300)
                self.print_mock.assert_any_call(StrContains('Testing complete.'))
 
+       def test_list_tests(self):
+               want = ['suite.test1', 'suite.test2', 'suite2.test1']
+               self.linux_source_mock.run_kernel.return_value = ['TAP version 14', 'init: random output'] + want
+
+               got = kunit._list_tests(self.linux_source_mock,
+                                    kunit.KunitExecRequest(300, '.kunit', False, 'suite*', None, 'suite'))
+
+               self.assertEqual(got, want)
+               # Should respect the user's filter glob when listing tests.
+               self.linux_source_mock.run_kernel.assert_called_once_with(
+                       args=['kunit.action=list'], build_dir='.kunit', filter_glob='suite*', timeout=300)
+
+
+       @mock.patch.object(kunit, '_list_tests')
+       def test_run_isolated_by_suite(self, mock_tests):
+               mock_tests.return_value = ['suite.test1', 'suite.test2', 'suite2.test1']
+               kunit.main(['exec', '--run_isolated=suite', 'suite*.test*'], self.linux_source_mock)
+
+               # Should respect the user's filter glob when listing tests.
+               mock_tests.assert_called_once_with(mock.ANY,
+                                    kunit.KunitExecRequest(300, '.kunit', False, 'suite*.test*', None, 'suite'))
+               self.linux_source_mock.run_kernel.assert_has_calls([
+                       mock.call(args=None, build_dir='.kunit', filter_glob='suite.test*', timeout=300),
+                       mock.call(args=None, build_dir='.kunit', filter_glob='suite2.test*', timeout=300),
+               ])
+
+       @mock.patch.object(kunit, '_list_tests')
+       def test_run_isolated_by_test(self, mock_tests):
+               mock_tests.return_value = ['suite.test1', 'suite.test2', 'suite2.test1']
+               kunit.main(['exec', '--run_isolated=test', 'suite*'], self.linux_source_mock)
+
+               # Should respect the user's filter glob when listing tests.
+               mock_tests.assert_called_once_with(mock.ANY,
+                                    kunit.KunitExecRequest(300, '.kunit', False, 'suite*', None, 'test'))
+               self.linux_source_mock.run_kernel.assert_has_calls([
+                       mock.call(args=None, build_dir='.kunit', filter_glob='suite.test1', timeout=300),
+                       mock.call(args=None, build_dir='.kunit', filter_glob='suite.test2', timeout=300),
+                       mock.call(args=None, build_dir='.kunit', filter_glob='suite2.test1', timeout=300),
+               ])
+
 
 if __name__ == '__main__':
        unittest.main()