1use crate::args::{Arguments, TimeThreshold};
2use crate::bench::AsyncBencher;
3use crate::execution::{TestExecution, TestSuiteExecution};
4use crate::internal;
5use crate::internal::{
6    generate_tests, get_ensure_time, CapturedOutput, FlakinessControl, RegisteredTest, SuiteResult,
7    TestFunction, TestResult,
8};
9use crate::ipc::{ipc_name, IpcCommand, IpcResponse};
10use crate::output::{test_runner_output, TestRunnerOutput};
11use bincode::{decode_from_slice, encode_to_vec};
12use futures::FutureExt;
13use interprocess::local_socket::tokio::prelude::*;
14use interprocess::local_socket::tokio::{Listener, Stream};
15use interprocess::local_socket::{GenericNamespaced, ListenerOptions};
16use std::collections::VecDeque;
17use std::future::Future;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::process::{ExitCode, Stdio};
21use std::sync::Arc;
22use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
23use tokio::process::{Child, Command};
24use tokio::spawn;
25use tokio::sync::Mutex;
26use tokio::task::{spawn_blocking, JoinHandle, JoinSet};
27use tokio::time::Instant;
28use uuid::Uuid;
29
30pub fn test_runner() -> ExitCode {
31    tokio::runtime::Builder::new_multi_thread()
32        .enable_all()
33        .build()
34        .unwrap()
35        .block_on(async_test_runner())
36}
37
38#[allow(clippy::await_holding_lock)]
39async fn async_test_runner() -> ExitCode {
40    let mut args = Arguments::from_args();
41    let output = test_runner_output(&args);
42
43    let registered_tests = internal::REGISTERED_TESTS.lock().unwrap();
44    let registered_dependency_constructors =
45        internal::REGISTERED_DEPENDENCY_CONSTRUCTORS.lock().unwrap();
46    let registered_testsuite_props = internal::REGISTERED_TESTSUITE_PROPS.lock().unwrap();
47    let registered_test_generators = internal::REGISTERED_TEST_GENERATORS.lock().unwrap();
48
49    let generated_tests = generate_tests(®istered_test_generators).await;
50
51    let all_tests: Vec<RegisteredTest> = registered_tests
52        .iter()
53        .cloned()
54        .chain(generated_tests)
55        .collect();
56
57    if args.list {
58        output.test_list(&all_tests);
59        ExitCode::SUCCESS
60    } else {
61        let mut remaining_retries = args.flaky_run.unwrap_or(1);
62
63        let mut exit_code = ExitCode::from(101);
64        while remaining_retries > 0 {
65            let (mut execution, filtered_tests) = TestSuiteExecution::construct(
66                &args,
67                registered_dependency_constructors.as_slice(),
68                &all_tests,
69                registered_testsuite_props.as_slice(),
70            );
71            args.finalize_for_execution(&execution, output.clone());
72            if args.spawn_workers {
73                execution.skip_creating_dependencies();
74            }
75
76            let count = execution.remaining();
81            let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
82
83            let start = Instant::now();
84            output.start_suite(&filtered_tests);
85
86            let execution = Arc::new(Mutex::new(execution));
87            let mut join_set = JoinSet::new();
88            let threads = args.test_threads().get();
89
90            for _ in 0..threads {
91                let execution_clone = execution.clone();
92                let output_clone = output.clone();
93                let args_clone = args.clone();
94                let results_clone = results.clone();
95                let handle = tokio::runtime::Handle::current();
96                join_set.spawn_blocking(move || {
97                    handle.block_on(test_thread(
98                        args_clone,
99                        execution_clone,
100                        output_clone,
101                        count,
102                        results_clone,
103                    ))
104                });
105            }
106
107            while let Some(res) = join_set.join_next().await {
108                res.expect("Failed to join task");
109            }
110
111            drop(execution);
112
113            let results = results.lock().await;
114            output.finished_suite(&all_tests, &results, start.elapsed());
115            exit_code = SuiteResult::exit_code(&results);
116
117            if exit_code == ExitCode::SUCCESS {
118                break;
119            } else {
120                remaining_retries -= 1;
121            }
122        }
123        exit_code
124    }
125}
126
127async fn test_thread(
128    args: Arguments,
129    execution: Arc<Mutex<TestSuiteExecution>>,
130    output: Arc<dyn TestRunnerOutput>,
131    count: usize,
132    results: Arc<Mutex<Vec<(RegisteredTest, TestResult)>>>,
133) {
134    let mut worker = spawn_worker_if_needed(&args).await;
135    let mut connection = if let Some(ref name) = args.ipc {
136        let name = ipc_name(name.clone());
137        let stream = Stream::connect(name)
138            .await
139            .expect("Failed to connect to IPC socket");
140        Some(stream)
141    } else {
142        None
143    };
144
145    let mut expected_test = None;
146
147    while !is_done(&execution).await {
148        if let Some(connection) = &mut connection {
149            if expected_test.is_none() {
150                let mut command_size: [u8; 2] = [0, 0];
151                connection
152                    .read_exact(&mut command_size)
153                    .await
154                    .expect("Failed to read IPC command size");
155                let mut command = vec![0; u16::from_le_bytes(command_size) as usize];
156                connection
157                    .read_exact(&mut command)
158                    .await
159                    .expect("Failed to read IPC command");
160                let (command, _): (IpcCommand, usize) =
161                    decode_from_slice(&command, bincode::config::standard())
162                        .expect("Failed to decode IPC command");
163
164                let IpcCommand::RunTest {
165                    name,
166                    crate_name,
167                    module_path,
168                } = command;
169                expected_test = Some((name, crate_name, module_path));
170            }
171        }
172
173        if let Some(next) = pick_next(&execution).await {
174            let skip = if let Some((name, crate_name, module_path)) = &expected_test {
175                next.test.name != *name
176                    || next.test.crate_name != *crate_name
177                    || next.test.module_path != *module_path
178            } else {
179                false
180            };
181
182            if !skip {
183                expected_test = None;
184
185                let ensure_time = get_ensure_time(&args, &next.test);
186
187                output.start_running_test(&next.test, next.index, count);
188                let result = run_test(
189                    output.clone(),
190                    next.index,
191                    count,
192                    args.nocapture,
193                    args.include_ignored,
194                    ensure_time,
195                    next.deps.clone(),
196                    &next.test,
197                    &mut worker,
198                )
199                .await;
200                output.finished_running_test(&next.test, next.index, count, &result);
201
202                if let Some(connection) = &mut connection {
203                    let finish_marker = Uuid::new_v4().to_string();
204                    let finish_marker_line = format!("{finish_marker}\n");
205                    tokio::io::stdout()
206                        .write_all(finish_marker_line.as_bytes())
207                        .await
208                        .unwrap();
209                    tokio::io::stderr()
210                        .write_all(finish_marker_line.as_bytes())
211                        .await
212                        .unwrap();
213                    tokio::io::stdout().flush().await.unwrap();
214                    tokio::io::stderr().flush().await.unwrap();
215
216                    let response = IpcResponse::TestFinished {
217                        result: (&result).into(),
218                        finish_marker,
219                    };
220                    let msg = encode_to_vec(&response, bincode::config::standard())
221                        .expect("Failed to encode IPC response");
222                    let message_size = (msg.len() as u16).to_le_bytes();
223                    connection
224                        .write_all(&message_size)
225                        .await
226                        .expect("Failed to write IPC response message size");
227                    connection
228                        .write_all(&msg)
229                        .await
230                        .expect("Failed to write response to IPC connection");
231                }
232
233                results.lock().await.push((next.test.clone(), result));
234            }
235        }
236    }
237}
238
239async fn is_done(execution: &Arc<Mutex<TestSuiteExecution>>) -> bool {
240    let execution = execution.lock().await;
241    execution.is_done()
242}
243
244async fn pick_next(execution: &Arc<Mutex<TestSuiteExecution>>) -> Option<TestExecution> {
245    let mut execution = execution.lock().await;
246    execution.pick_next().await
247}
248
249async fn run_with_flakiness_control<F, R>(
250    output: Arc<dyn TestRunnerOutput>,
251    test_description: &RegisteredTest,
252    idx: usize,
253    count: usize,
254    test: F,
255) -> Result<(), R>
256where
257    F: Fn(Instant) -> Pin<Box<dyn Future<Output = Result<(), R>>>> + Send + Sync,
258{
259    match &test_description.props.flakiness_control {
260        FlakinessControl::None => {
261            let start = Instant::now();
262            test(start).await
263        }
264        FlakinessControl::ProveNonFlaky(tries) => {
265            for n in 0..*tries {
266                if n > 0 {
267                    output.repeat_running_test(
268                        test_description,
269                        idx,
270                        count,
271                        n + 1,
272                        *tries,
273                        "to ensure test is not flaky",
274                    );
275                }
276                let start = Instant::now();
277                test(start).await?;
278            }
279            Ok(())
280        }
281        FlakinessControl::RetryKnownFlaky(max_retries) => {
282            let mut tries = 1;
283            loop {
284                let start = Instant::now();
285                let result = test(start).await;
286
287                if result.is_err() && tries < *max_retries {
288                    tries += 1;
289                    output.repeat_running_test(
290                        test_description,
291                        idx,
292                        count,
293                        tries,
294                        *max_retries,
295                        "because test is known to be flaky",
296                    );
297                } else {
298                    break result;
299                }
300            }
301        }
302    }
303}
304
305#[allow(clippy::too_many_arguments)]
306async fn run_test(
307    output: Arc<dyn TestRunnerOutput>,
308    idx: usize,
309    count: usize,
310    nocapture: bool,
311    include_ignored: bool,
312    ensure_time: Option<TimeThreshold>,
313    dependency_view: Arc<dyn internal::DependencyView + Send + Sync>,
314    test: &RegisteredTest,
315    worker: &mut Option<Worker>,
316) -> TestResult {
317    if test.props.is_ignored && !include_ignored {
318        TestResult::ignored()
319    } else if let Some(worker) = worker.as_mut() {
320        worker.run_test(nocapture, test).await
321    } else {
322        let start = Instant::now();
323        let test = test.clone();
324        match &test.run {
325            TestFunction::Sync(_) => {
326                let handle = spawn_blocking(move || {
327                    let test = test.clone();
328                    crate::sync::run_sync_test_function(
329                        output,
330                        &test,
331                        idx,
332                        count,
333                        ensure_time,
334                        dependency_view,
335                    )
336                });
337                handle.await.unwrap_or_else(|join_error| {
338                    TestResult::failed(start.elapsed(), Box::new(join_error))
339                })
340            }
341            TestFunction::Async(test_fn) => {
342                let timeout = test.props.timeout;
343                let test_fn = test_fn.clone();
344                let result = run_with_flakiness_control(output, &test, idx, count, |start| {
345                    let dependency_view = dependency_view.clone();
346                    let test_fn = test_fn.clone();
347                    Box::pin(async move {
348                        AssertUnwindSafe(Box::pin(async move {
349                            let result = match timeout {
350                                None => test_fn(dependency_view).await,
351                                Some(duration) => {
352                                    let result =
353                                        tokio::time::timeout(duration, test_fn(dependency_view))
354                                            .await;
355                                    match result {
356                                        Ok(result) => result,
357                                        Err(_) => panic!("Test timed out"),
358                                    }
359                                }
360                            };
361                            match result.as_result() {
362                                Ok(_) => (),
363                                Err(message) => panic!("{message}"),
364                            };
365                            if let Some(ensure_time) = ensure_time {
366                                let elapsed = start.elapsed();
367                                if ensure_time.is_critical(&elapsed) {
368                                    panic!("Test run time exceeds critical threshold: {elapsed:?}");
369                                }
370                            }
371                        }))
372                        .catch_unwind()
373                        .await
374                    })
375                })
376                .await;
377                TestResult::from_result(&test.props.should_panic, start.elapsed(), result)
378            }
379            TestFunction::SyncBench(_) => {
380                let handle = spawn_blocking(move || {
381                    let test = test.clone();
382                    crate::sync::run_sync_test_function(
383                        output,
384                        &test,
385                        idx,
386                        count,
387                        ensure_time,
388                        dependency_view,
389                    )
390                });
391                handle.await.unwrap_or_else(|join_error| {
392                    TestResult::failed(start.elapsed(), Box::new(join_error))
393                })
394            }
395            TestFunction::AsyncBench(bench_fn) => {
396                let mut bencher = AsyncBencher::new();
397                let result = AssertUnwindSafe(async move {
398                    bench_fn(&mut bencher, dependency_view).await;
399                    (
400                        bencher
401                            .summary()
402                            .expect("iter() was not called in bench function"),
403                        bencher.bytes,
404                    )
405                })
406                .catch_unwind()
407                .await;
408                let bytes = result.as_ref().map(|(_, bytes)| *bytes).unwrap_or_default();
409                TestResult::from_summary(
410                    &test.props.should_panic,
411                    start.elapsed(),
412                    result.map(|(summary, _)| summary),
413                    bytes,
414                )
415            }
416        }
417    }
418}
419
/// Handle to a worker process created by `spawn_worker_if_needed`.
///
/// The worker runs tests on behalf of one test thread of this process. Field order matters
/// for drop order, so fields are dropped top to bottom.
struct Worker {
    // Not used after `accept`; stored so the listener lives as long as the worker handle.
    _listener: Listener,
    // The spawned worker process. NOTE(review): the `Command` that creates it does not set
    // `kill_on_drop`, so dropping this handle does not by itself terminate the worker.
    _process: Child,
    // Background tasks reading the worker's stdout and stderr line by line.
    _out_handle: JoinHandle<()>,
    _err_handle: JoinHandle<()>,
    // Captured stdout lines not yet consumed by `Worker::drain_until`.
    out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
    // Captured stderr lines not yet consumed by `Worker::drain_until`.
    err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
    // `true`: the reader tasks buffer lines into `out_lines`/`err_lines`.
    // `false`: they print lines directly to this process's stdout/stderr.
    capture_enabled: Arc<Mutex<bool>>,
    // IPC connection accepted from the worker; carries `u16`-length-prefixed bincode messages.
    connection: Stream,
}
430
431impl Worker {
432    pub async fn run_test(&mut self, nocapture: bool, test: &RegisteredTest) -> TestResult {
433        let mut capture_enabled = self.capture_enabled.lock().await;
434        *capture_enabled = test.props.capture_control.requires_capturing(!nocapture);
435        drop(capture_enabled);
436
437        let cmd = IpcCommand::RunTest {
439            name: test.name.clone(),
440            crate_name: test.crate_name.clone(),
441            module_path: test.module_path.clone(),
442        };
443
444        let dump_on_ipc_failure = self.dump_on_failure();
445
446        let msg =
447            encode_to_vec(&cmd, bincode::config::standard()).expect("Failed to encode IPC command");
448        let message_size = (msg.len() as u16).to_le_bytes();
449        dump_on_ipc_failure
450            .run(self.connection.write_all(&message_size).await)
451            .await;
452        dump_on_ipc_failure
453            .run(self.connection.write_all(&msg).await)
454            .await;
455
456        let mut response_size: [u8; 2] = [0, 0];
457        dump_on_ipc_failure
458            .run(self.connection.read_exact(&mut response_size).await)
459            .await;
460        let mut response = vec![0; u16::from_le_bytes(response_size) as usize];
461        dump_on_ipc_failure
462            .run(self.connection.read_exact(&mut response).await)
463            .await;
464        let (response, _): (IpcResponse, usize) = dump_on_ipc_failure
465            .run(decode_from_slice(&response, bincode::config::standard()))
466            .await;
467
468        let IpcResponse::TestFinished {
469            result,
470            finish_marker,
471        } = response;
472
473        if test.props.capture_control.requires_capturing(!nocapture) {
474            let out_lines: Vec<_> =
475                Self::drain_until(self.out_lines.clone(), finish_marker.clone()).await;
476            let err_lines: Vec<_> =
477                Self::drain_until(self.err_lines.clone(), finish_marker.clone()).await;
478            result.into_test_result(out_lines, err_lines)
479        } else {
480            result.into_test_result(Vec::new(), Vec::new())
481        }
482    }
483
484    fn dump_on_failure(&self) -> DumpOnFailure {
485        DumpOnFailure {
486            out_lines: self.out_lines.clone(),
487            err_lines: self.err_lines.clone(),
488        }
489    }
490
491    async fn drain_until(
492        source: Arc<Mutex<VecDeque<CapturedOutput>>>,
493        finish_marker: String,
494    ) -> Vec<CapturedOutput> {
495        let mut result = Vec::new();
496        loop {
497            let mut source = source.lock().await;
498            while let Some(line) = source.pop_front() {
499                if line.line() == finish_marker {
500                    return result;
501                } else {
502                    result.push(line.clone());
503                }
504            }
505            drop(source);
506
507            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
508        }
509    }
510}
511
512struct DumpOnFailure {
513    out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
514    err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
515}
516
517impl DumpOnFailure {
518    pub async fn run<T, E>(&self, result: Result<T, E>) -> T {
519        match result {
520            Ok(value) => value,
521            Err(_error) => {
522                let out_lines: Vec<_> = self.out_lines.lock().await.drain(..).collect();
523                let err_lines: Vec<_> = self.err_lines.lock().await.drain(..).collect();
524                let mut all_lines = [out_lines, err_lines].concat();
525                all_lines.sort();
526
527                for line in all_lines {
528                    eprintln!("{}", line.line());
529                }
530
531                std::process::exit(1);
532            }
533        }
534    }
535}
536
537async fn spawn_worker_if_needed(args: &Arguments) -> Option<Worker> {
538    if args.spawn_workers {
539        let id = Uuid::new_v4();
540        let name_str = format!("{id}.sock");
541        let name = name_str
542            .clone()
543            .to_ns_name::<GenericNamespaced>()
544            .expect("Invalid local socket name");
545        let opts = ListenerOptions::new().name(name.clone());
546        let listener = opts
547            .create_tokio()
548            .expect("Failed to create local socket listener");
549
550        let exe = std::env::current_exe().expect("Failed to get current executable path");
551
552        let mut args = args.clone();
553        args.ipc = Some(name_str);
554        args.spawn_workers = false;
555        args.logfile = None;
556        let args = args.to_args();
557
558        let mut process = Command::new(exe)
559            .args(args)
560            .stdin(Stdio::inherit())
561            .stderr(Stdio::piped())
562            .stdout(Stdio::piped())
563            .spawn()
564            .expect("Failed to spawn worker process");
565
566        let stdout = process.stdout.take().unwrap();
567        let stderr = process.stderr.take().unwrap();
568
569        let out_lines = Arc::new(Mutex::new(VecDeque::new()));
570        let err_lines = Arc::new(Mutex::new(VecDeque::new()));
571        let capture_enabled = Arc::new(Mutex::new(true));
572
573        let out_lines_clone = out_lines.clone();
574        let capture_enabled_clone = capture_enabled.clone();
575        let out_handle = spawn(async move {
576            let reader = BufReader::new(stdout);
577            let mut lines = reader.lines();
578            while let Some(line) = lines
579                .next_line()
580                .await
581                .expect("Failed to read from worker stdout")
582            {
583                if *capture_enabled_clone.lock().await {
584                    out_lines_clone
585                        .lock()
586                        .await
587                        .push_back(CapturedOutput::stdout(line));
588                } else {
589                    println!("{line}");
590                }
591            }
592        });
593
594        let err_lines_clone = err_lines.clone();
595        let capture_enabled_clone = capture_enabled.clone();
596        let err_handle = spawn(async move {
597            let reader = BufReader::new(stderr);
598            let mut lines = reader.lines();
599            while let Some(line) = lines
600                .next_line()
601                .await
602                .expect("Failed to read from worker stderr")
603            {
604                if *capture_enabled_clone.lock().await {
605                    err_lines_clone
606                        .lock()
607                        .await
608                        .push_back(CapturedOutput::stderr(line));
609                } else {
610                    eprintln!("{line}");
611                }
612            }
613        });
614
615        let connection = listener
616            .accept()
617            .await
618            .expect("Failed to accept connection");
619
620        Some(Worker {
621            _listener: listener,
622            _process: process,
623            _out_handle: out_handle,
624            _err_handle: err_handle,
625            out_lines,
626            err_lines,
627            connection,
628            capture_enabled,
629        })
630    } else {
631        None
632    }
633}