test_r_core/
tokio.rs

1use crate::args::{Arguments, TimeThreshold};
2use crate::bench::AsyncBencher;
3use crate::execution::{TestExecution, TestSuiteExecution};
4use crate::internal;
5use crate::internal::{
6    generate_tests, get_ensure_time, CapturedOutput, FlakinessControl, RegisteredTest, SuiteResult,
7    TestFunction, TestResult,
8};
9use crate::ipc::{ipc_name, IpcCommand, IpcResponse};
10use crate::output::{test_runner_output, TestRunnerOutput};
11use bincode::{decode_from_slice, encode_to_vec};
12use futures::FutureExt;
13use interprocess::local_socket::tokio::prelude::*;
14use interprocess::local_socket::tokio::{Listener, Stream};
15use interprocess::local_socket::{GenericNamespaced, ListenerOptions};
16use std::collections::VecDeque;
17use std::future::Future;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::process::{ExitCode, Stdio};
21use std::sync::Arc;
22use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
23use tokio::process::{Child, Command};
24use tokio::spawn;
25use tokio::sync::Mutex;
26use tokio::task::{spawn_blocking, JoinHandle, JoinSet};
27use tokio::time::Instant;
28use uuid::Uuid;
29
30pub fn test_runner() -> ExitCode {
31    tokio::runtime::Builder::new_multi_thread()
32        .enable_all()
33        .build()
34        .unwrap()
35        .block_on(async_test_runner())
36}
37
38#[allow(clippy::await_holding_lock)]
39async fn async_test_runner() -> ExitCode {
40    let mut args = Arguments::from_args();
41    let output = test_runner_output(&args);
42
43    let registered_tests = internal::REGISTERED_TESTS.lock().unwrap();
44    let registered_dependency_constructors =
45        internal::REGISTERED_DEPENDENCY_CONSTRUCTORS.lock().unwrap();
46    let registered_testsuite_props = internal::REGISTERED_TESTSUITE_PROPS.lock().unwrap();
47    let registered_test_generators = internal::REGISTERED_TEST_GENERATORS.lock().unwrap();
48
49    let generated_tests = generate_tests(&registered_test_generators).await;
50
51    let all_tests: Vec<RegisteredTest> = registered_tests
52        .iter()
53        .cloned()
54        .chain(generated_tests)
55        .collect();
56
57    if args.list {
58        output.test_list(&all_tests);
59        ExitCode::SUCCESS
60    } else {
61        let (mut execution, filtered_tests) = TestSuiteExecution::construct(
62            &args,
63            registered_dependency_constructors.as_slice(),
64            &all_tests,
65            registered_testsuite_props.as_slice(),
66        );
67        args.finalize_for_execution(&execution, output.clone());
68        if args.spawn_workers {
69            execution.skip_creating_dependencies();
70        }
71
72        // println!("Execution plan: {execution:?}");
73        // println!("Final args: {args:?}");
74        // println!("Has dependencies: {:?}", execution.has_dependencies());
75
76        let count = execution.remaining();
77        let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
78
79        let start = Instant::now();
80        output.start_suite(&filtered_tests);
81
82        let execution = Arc::new(Mutex::new(execution));
83        let mut join_set = JoinSet::new();
84        let threads = args.test_threads().get();
85
86        for _ in 0..threads {
87            let execution_clone = execution.clone();
88            let output_clone = output.clone();
89            let args_clone = args.clone();
90            let results_clone = results.clone();
91            let handle = tokio::runtime::Handle::current();
92            join_set.spawn_blocking(move || {
93                handle.block_on(test_thread(
94                    args_clone,
95                    execution_clone,
96                    output_clone,
97                    count,
98                    results_clone,
99                ))
100            });
101        }
102
103        while let Some(res) = join_set.join_next().await {
104            res.expect("Failed to join task");
105        }
106
107        let results = results.lock().await;
108        output.finished_suite(&all_tests, &results, start.elapsed());
109        SuiteResult::exit_code(&results)
110    }
111}
112
113async fn test_thread(
114    args: Arguments,
115    execution: Arc<Mutex<TestSuiteExecution>>,
116    output: Arc<dyn TestRunnerOutput>,
117    count: usize,
118    results: Arc<Mutex<Vec<(RegisteredTest, TestResult)>>>,
119) {
120    let mut worker = spawn_worker_if_needed(&args).await;
121    let mut connection = if let Some(ref name) = args.ipc {
122        let name = ipc_name(name.clone());
123        let stream = Stream::connect(name)
124            .await
125            .expect("Failed to connect to IPC socket");
126        Some(stream)
127    } else {
128        None
129    };
130
131    let mut expected_test = None;
132
133    while !is_done(&execution).await {
134        if let Some(connection) = &mut connection {
135            if expected_test.is_none() {
136                let mut command_size: [u8; 2] = [0, 0];
137                connection
138                    .read_exact(&mut command_size)
139                    .await
140                    .expect("Failed to read IPC command size");
141                let mut command = vec![0; u16::from_le_bytes(command_size) as usize];
142                connection
143                    .read_exact(&mut command)
144                    .await
145                    .expect("Failed to read IPC command");
146                let (command, _): (IpcCommand, usize) =
147                    decode_from_slice(&command, bincode::config::standard())
148                        .expect("Failed to decode IPC command");
149
150                let IpcCommand::RunTest {
151                    name,
152                    crate_name,
153                    module_path,
154                } = command;
155                expected_test = Some((name, crate_name, module_path));
156            }
157        }
158
159        if let Some(next) = pick_next(&execution).await {
160            let skip = if let Some((name, crate_name, module_path)) = &expected_test {
161                next.test.name != *name
162                    || next.test.crate_name != *crate_name
163                    || next.test.module_path != *module_path
164            } else {
165                false
166            };
167
168            if !skip {
169                expected_test = None;
170
171                let ensure_time = get_ensure_time(&args, &next.test);
172
173                output.start_running_test(&next.test, next.index, count);
174                let result = run_test(
175                    output.clone(),
176                    next.index,
177                    count,
178                    args.nocapture,
179                    args.include_ignored,
180                    ensure_time,
181                    next.deps.clone(),
182                    &next.test,
183                    &mut worker,
184                )
185                .await;
186                output.finished_running_test(&next.test, next.index, count, &result);
187
188                if let Some(connection) = &mut connection {
189                    let finish_marker = Uuid::new_v4().to_string();
190                    let finish_marker_line = format!("{finish_marker}\n");
191                    tokio::io::stdout()
192                        .write_all(finish_marker_line.as_bytes())
193                        .await
194                        .unwrap();
195                    tokio::io::stderr()
196                        .write_all(finish_marker_line.as_bytes())
197                        .await
198                        .unwrap();
199                    tokio::io::stdout().flush().await.unwrap();
200                    tokio::io::stderr().flush().await.unwrap();
201
202                    let response = IpcResponse::TestFinished {
203                        result: (&result).into(),
204                        finish_marker,
205                    };
206                    let msg = encode_to_vec(&response, bincode::config::standard())
207                        .expect("Failed to encode IPC response");
208                    let message_size = (msg.len() as u16).to_le_bytes();
209                    connection
210                        .write_all(&message_size)
211                        .await
212                        .expect("Failed to write IPC response message size");
213                    connection
214                        .write_all(&msg)
215                        .await
216                        .expect("Failed to write response to IPC connection");
217                }
218
219                results.lock().await.push((next.test.clone(), result));
220            }
221        }
222    }
223}
224
225async fn is_done(execution: &Arc<Mutex<TestSuiteExecution>>) -> bool {
226    let execution = execution.lock().await;
227    execution.is_done()
228}
229
230async fn pick_next(execution: &Arc<Mutex<TestSuiteExecution>>) -> Option<TestExecution> {
231    let mut execution = execution.lock().await;
232    execution.pick_next().await
233}
234
235async fn run_with_flakiness_control<F, R>(
236    output: Arc<dyn TestRunnerOutput>,
237    test_description: &RegisteredTest,
238    idx: usize,
239    count: usize,
240    test: F,
241) -> Result<(), R>
242where
243    F: Fn(Instant) -> Pin<Box<dyn Future<Output = Result<(), R>>>> + Send + Sync,
244{
245    match &test_description.props.flakiness_control {
246        FlakinessControl::None => {
247            let start = Instant::now();
248            test(start).await
249        }
250        FlakinessControl::ProveNonFlaky(tries) => {
251            for n in 0..*tries {
252                if n > 0 {
253                    output.repeat_running_test(
254                        test_description,
255                        idx,
256                        count,
257                        n + 1,
258                        *tries,
259                        "to ensure test is not flaky",
260                    );
261                }
262                let start = Instant::now();
263                test(start).await?;
264            }
265            Ok(())
266        }
267        FlakinessControl::RetryKnownFlaky(max_retries) => {
268            let mut tries = 1;
269            loop {
270                let start = Instant::now();
271                let result = test(start).await;
272
273                if result.is_err() && tries < *max_retries {
274                    tries += 1;
275                    output.repeat_running_test(
276                        test_description,
277                        idx,
278                        count,
279                        tries,
280                        *max_retries,
281                        "because test is known to be flaky",
282                    );
283                } else {
284                    break result;
285                }
286            }
287        }
288    }
289}
290
291#[allow(clippy::too_many_arguments)]
292async fn run_test(
293    output: Arc<dyn TestRunnerOutput>,
294    idx: usize,
295    count: usize,
296    nocapture: bool,
297    include_ignored: bool,
298    ensure_time: Option<TimeThreshold>,
299    dependency_view: Arc<dyn internal::DependencyView + Send + Sync>,
300    test: &RegisteredTest,
301    worker: &mut Option<Worker>,
302) -> TestResult {
303    if test.is_ignored && !include_ignored {
304        TestResult::ignored()
305    } else if let Some(worker) = worker.as_mut() {
306        worker.run_test(nocapture, test).await
307    } else {
308        let start = Instant::now();
309        let test = test.clone();
310        match &test.run {
311            TestFunction::Sync(_) => {
312                let handle = spawn_blocking(move || {
313                    let test = test.clone();
314                    crate::sync::run_sync_test_function(
315                        output,
316                        &test,
317                        idx,
318                        count,
319                        ensure_time,
320                        dependency_view,
321                    )
322                });
323                handle.await.unwrap_or_else(|join_error| {
324                    TestResult::failed(start.elapsed(), Box::new(join_error))
325                })
326            }
327            TestFunction::Async(test_fn) => {
328                let timeout = test.props.timeout;
329                let test_fn = test_fn.clone();
330                let result = run_with_flakiness_control(output, &test, idx, count, |start| {
331                    let dependency_view = dependency_view.clone();
332                    let test_fn = test_fn.clone();
333                    Box::pin(async move {
334                        AssertUnwindSafe(Box::pin(async move {
335                            let result = match timeout {
336                                None => test_fn(dependency_view).await,
337                                Some(duration) => {
338                                    let result =
339                                        tokio::time::timeout(duration, test_fn(dependency_view))
340                                            .await;
341                                    match result {
342                                        Ok(result) => result,
343                                        Err(_) => panic!("Test timed out"),
344                                    }
345                                }
346                            };
347                            match result.as_result() {
348                                Ok(_) => (),
349                                Err(message) => panic!("{message}"),
350                            };
351                            if let Some(ensure_time) = ensure_time {
352                                let elapsed = start.elapsed();
353                                if ensure_time.is_critical(&elapsed) {
354                                    panic!(
355                                        "Test run time exceeds critical threshold: {:?}",
356                                        elapsed
357                                    );
358                                }
359                            }
360                        }))
361                        .catch_unwind()
362                        .await
363                    })
364                })
365                .await;
366                TestResult::from_result(&test.props.should_panic, start.elapsed(), result)
367            }
368            TestFunction::SyncBench(_) => {
369                let handle = spawn_blocking(move || {
370                    let test = test.clone();
371                    crate::sync::run_sync_test_function(
372                        output,
373                        &test,
374                        idx,
375                        count,
376                        ensure_time,
377                        dependency_view,
378                    )
379                });
380                handle.await.unwrap_or_else(|join_error| {
381                    TestResult::failed(start.elapsed(), Box::new(join_error))
382                })
383            }
384            TestFunction::AsyncBench(bench_fn) => {
385                let mut bencher = AsyncBencher::new();
386                let result = AssertUnwindSafe(async move {
387                    bench_fn(&mut bencher, dependency_view).await;
388                    (
389                        bencher
390                            .summary()
391                            .expect("iter() was not called in bench function"),
392                        bencher.bytes,
393                    )
394                })
395                .catch_unwind()
396                .await;
397                let bytes = result.as_ref().map(|(_, bytes)| *bytes).unwrap_or_default();
398                TestResult::from_summary(
399                    &test.props.should_panic,
400                    start.elapsed(),
401                    result.map(|(summary, _)| summary),
402                    bytes,
403                )
404            }
405        }
406    }
407}
408
409struct Worker {
410    _listener: Listener,
411    _process: Child,
412    _out_handle: JoinHandle<()>,
413    _err_handle: JoinHandle<()>,
414    out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
415    err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
416    capture_enabled: Arc<Mutex<bool>>,
417    connection: Stream,
418}
419
420impl Worker {
421    pub async fn run_test(&mut self, nocapture: bool, test: &RegisteredTest) -> TestResult {
422        let mut capture_enabled = self.capture_enabled.lock().await;
423        *capture_enabled = test.props.capture_control.requires_capturing(!nocapture);
424        drop(capture_enabled);
425
426        // Send IPC command and wait for IPC response, and in the meantime read from the stdout/stderr channels
427        let cmd = IpcCommand::RunTest {
428            name: test.name.clone(),
429            crate_name: test.crate_name.clone(),
430            module_path: test.module_path.clone(),
431        };
432
433        let dump_on_ipc_failure = self.dump_on_failure();
434
435        let msg =
436            encode_to_vec(&cmd, bincode::config::standard()).expect("Failed to encode IPC command");
437        let message_size = (msg.len() as u16).to_le_bytes();
438        dump_on_ipc_failure
439            .run(self.connection.write_all(&message_size).await)
440            .await;
441        dump_on_ipc_failure
442            .run(self.connection.write_all(&msg).await)
443            .await;
444
445        let mut response_size: [u8; 2] = [0, 0];
446        dump_on_ipc_failure
447            .run(self.connection.read_exact(&mut response_size).await)
448            .await;
449        let mut response = vec![0; u16::from_le_bytes(response_size) as usize];
450        dump_on_ipc_failure
451            .run(self.connection.read_exact(&mut response).await)
452            .await;
453        let (response, _): (IpcResponse, usize) = dump_on_ipc_failure
454            .run(decode_from_slice(&response, bincode::config::standard()))
455            .await;
456
457        let IpcResponse::TestFinished {
458            result,
459            finish_marker,
460        } = response;
461
462        if test.props.capture_control.requires_capturing(!nocapture) {
463            let out_lines: Vec<_> =
464                Self::drain_until(self.out_lines.clone(), finish_marker.clone()).await;
465            let err_lines: Vec<_> =
466                Self::drain_until(self.err_lines.clone(), finish_marker.clone()).await;
467            result.into_test_result(out_lines, err_lines)
468        } else {
469            result.into_test_result(Vec::new(), Vec::new())
470        }
471    }
472
473    fn dump_on_failure(&self) -> DumpOnFailure {
474        DumpOnFailure {
475            out_lines: self.out_lines.clone(),
476            err_lines: self.err_lines.clone(),
477        }
478    }
479
480    async fn drain_until(
481        source: Arc<Mutex<VecDeque<CapturedOutput>>>,
482        finish_marker: String,
483    ) -> Vec<CapturedOutput> {
484        let mut result = Vec::new();
485        loop {
486            let mut source = source.lock().await;
487            while let Some(line) = source.pop_front() {
488                if line.line() == finish_marker {
489                    return result;
490                } else {
491                    result.push(line.clone());
492                }
493            }
494            drop(source);
495
496            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
497        }
498    }
499}
500
501struct DumpOnFailure {
502    out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
503    err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
504}
505
506impl DumpOnFailure {
507    pub async fn run<T, E>(&self, result: Result<T, E>) -> T {
508        match result {
509            Ok(value) => value,
510            Err(_error) => {
511                let out_lines: Vec<_> = self.out_lines.lock().await.drain(..).collect();
512                let err_lines: Vec<_> = self.err_lines.lock().await.drain(..).collect();
513                let mut all_lines = [out_lines, err_lines].concat();
514                all_lines.sort();
515
516                for line in all_lines {
517                    eprintln!("{}", line.line());
518                }
519
520                std::process::exit(1);
521            }
522        }
523    }
524}
525
526async fn spawn_worker_if_needed(args: &Arguments) -> Option<Worker> {
527    if args.spawn_workers {
528        let id = Uuid::new_v4();
529        let name_str = format!("{id}.sock");
530        let name = name_str
531            .clone()
532            .to_ns_name::<GenericNamespaced>()
533            .expect("Invalid local socket name");
534        let opts = ListenerOptions::new().name(name.clone());
535        let listener = opts
536            .create_tokio()
537            .expect("Failed to create local socket listener");
538
539        let exe = std::env::current_exe().expect("Failed to get current executable path");
540
541        let mut args = args.clone();
542        args.ipc = Some(name_str);
543        args.spawn_workers = false;
544        args.logfile = None;
545        let args = args.to_args();
546
547        let mut process = Command::new(exe)
548            .args(args)
549            .stdin(Stdio::inherit())
550            .stderr(Stdio::piped())
551            .stdout(Stdio::piped())
552            .spawn()
553            .expect("Failed to spawn worker process");
554
555        let stdout = process.stdout.take().unwrap();
556        let stderr = process.stderr.take().unwrap();
557
558        let out_lines = Arc::new(Mutex::new(VecDeque::new()));
559        let err_lines = Arc::new(Mutex::new(VecDeque::new()));
560        let capture_enabled = Arc::new(Mutex::new(true));
561
562        let out_lines_clone = out_lines.clone();
563        let capture_enabled_clone = capture_enabled.clone();
564        let out_handle = spawn(async move {
565            let reader = BufReader::new(stdout);
566            let mut lines = reader.lines();
567            while let Some(line) = lines
568                .next_line()
569                .await
570                .expect("Failed to read from worker stdout")
571            {
572                if *capture_enabled_clone.lock().await {
573                    out_lines_clone
574                        .lock()
575                        .await
576                        .push_back(CapturedOutput::stdout(line));
577                } else {
578                    println!("{line}");
579                }
580            }
581        });
582
583        let err_lines_clone = err_lines.clone();
584        let capture_enabled_clone = capture_enabled.clone();
585        let err_handle = spawn(async move {
586            let reader = BufReader::new(stderr);
587            let mut lines = reader.lines();
588            while let Some(line) = lines
589                .next_line()
590                .await
591                .expect("Failed to read from worker stderr")
592            {
593                if *capture_enabled_clone.lock().await {
594                    err_lines_clone
595                        .lock()
596                        .await
597                        .push_back(CapturedOutput::stderr(line));
598                } else {
599                    eprintln!("{line}");
600                }
601            }
602        });
603
604        let connection = listener
605            .accept()
606            .await
607            .expect("Failed to accept connection");
608
609        Some(Worker {
610            _listener: listener,
611            _process: process,
612            _out_handle: out_handle,
613            _err_handle: err_handle,
614            out_lines,
615            err_lines,
616            connection,
617            capture_enabled,
618        })
619    } else {
620        None
621    }
622}