Skip to main content

test_r_core/
tokio.rs

1use crate::args::{Arguments, TimeThreshold};
2use crate::bench::AsyncBencher;
3use crate::execution::{TestExecution, TestSuiteExecution};
4use crate::internal;
5use crate::internal::{
6    generate_tests, get_ensure_time, CapturedOutput, FailureCause, FlakinessControl,
7    RegisteredTest, SuiteResult, TestFunction, TestResult,
8};
9use crate::ipc::{ipc_name, IpcCommand, IpcResponse};
10use crate::output::{test_runner_output, TestRunnerOutput};
11use desert_rust::{deserialize, serialize_to_byte_vec};
12use futures::FutureExt;
13use interprocess::local_socket::tokio::prelude::*;
14use interprocess::local_socket::tokio::{Listener, Stream};
15use interprocess::local_socket::{GenericNamespaced, ListenerOptions};
16use std::any::Any;
17use std::collections::VecDeque;
18use std::future::Future;
19use std::panic::AssertUnwindSafe;
20use std::pin::Pin;
21use std::process::{ExitCode, Stdio};
22use std::sync::Arc;
23use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::process::{Child, Command};
25use tokio::spawn;
26use tokio::sync::Mutex;
27use tokio::task::{spawn_blocking, JoinHandle, JoinSet};
28use tokio::time::Instant;
29use uuid::Uuid;
30
31pub fn test_runner() -> ExitCode {
32    tokio::runtime::Builder::new_multi_thread()
33        .enable_all()
34        .build()
35        .unwrap()
36        .block_on(async_test_runner())
37}
38
39#[allow(clippy::await_holding_lock)]
40async fn async_test_runner() -> ExitCode {
41    crate::panic_hook::install_panic_hook();
42    let mut args = Arguments::from_args();
43    let output = test_runner_output(&args);
44
45    let registered_tests = internal::REGISTERED_TESTS.lock().unwrap();
46    let registered_dependency_constructors =
47        internal::REGISTERED_DEPENDENCY_CONSTRUCTORS.lock().unwrap();
48    let registered_testsuite_props = internal::REGISTERED_TESTSUITE_PROPS.lock().unwrap();
49    let registered_test_generators = internal::REGISTERED_TEST_GENERATORS.lock().unwrap();
50
51    let generated_tests = generate_tests(&registered_test_generators).await;
52
53    let all_tests: Vec<RegisteredTest> = registered_tests
54        .iter()
55        .cloned()
56        .chain(generated_tests)
57        .collect();
58
59    if args.list {
60        output.test_list(&all_tests);
61        ExitCode::SUCCESS
62    } else {
63        let mut remaining_retries = args.flaky_run.unwrap_or(1);
64
65        let mut exit_code = ExitCode::from(101);
66        while remaining_retries > 0 {
67            let (mut execution, filtered_tests) = TestSuiteExecution::construct(
68                &args,
69                registered_dependency_constructors.as_slice(),
70                &all_tests,
71                registered_testsuite_props.as_slice(),
72            );
73            args.finalize_for_execution(&execution, output.clone());
74            if args.spawn_workers {
75                execution.skip_creating_dependencies();
76            }
77
78            // println!("Execution plan: {execution:?}");
79            // println!("Final args: {args:?}");
80            // println!("Has dependencies: {:?}", execution.has_dependencies());
81
82            let count = execution.remaining();
83            let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
84
85            let start = Instant::now();
86            output.start_suite(&filtered_tests);
87
88            let execution = Arc::new(Mutex::new(execution));
89            let mut join_set = JoinSet::new();
90            let threads = args.test_threads().get();
91
92            for _ in 0..threads {
93                let execution_clone = execution.clone();
94                let output_clone = output.clone();
95                let args_clone = args.clone();
96                let results_clone = results.clone();
97                let handle = tokio::runtime::Handle::current();
98                join_set.spawn_blocking(move || {
99                    handle.block_on(test_thread(
100                        args_clone,
101                        execution_clone,
102                        output_clone,
103                        count,
104                        results_clone,
105                    ))
106                });
107            }
108
109            while let Some(res) = join_set.join_next().await {
110                res.expect("Failed to join task");
111            }
112
113            drop(execution);
114
115            let results = results.lock().await;
116            output.finished_suite(&all_tests, &results, start.elapsed());
117            exit_code = SuiteResult::exit_code(&results);
118
119            if exit_code == ExitCode::SUCCESS {
120                break;
121            } else {
122                remaining_retries -= 1;
123            }
124        }
125        exit_code
126    }
127}
128
129async fn test_thread(
130    args: Arguments,
131    execution: Arc<Mutex<TestSuiteExecution>>,
132    output: Arc<dyn TestRunnerOutput>,
133    count: usize,
134    results: Arc<Mutex<Vec<(RegisteredTest, TestResult)>>>,
135) {
136    let mut worker = spawn_worker_if_needed(&args).await;
137    let mut connection = if let Some(ref name) = args.ipc {
138        let name = ipc_name(name.clone());
139        let stream = Stream::connect(name)
140            .await
141            .expect("Failed to connect to IPC socket");
142        Some(stream)
143    } else {
144        None
145    };
146
147    let mut expected_test = None;
148
149    while !is_done(&execution).await {
150        if let Some(connection) = &mut connection {
151            if expected_test.is_none() {
152                let mut command_size: [u8; 2] = [0, 0];
153                connection
154                    .read_exact(&mut command_size)
155                    .await
156                    .expect("Failed to read IPC command size");
157                let mut command = vec![0; u16::from_le_bytes(command_size) as usize];
158                connection
159                    .read_exact(&mut command)
160                    .await
161                    .expect("Failed to read IPC command");
162                let command: IpcCommand =
163                    deserialize(&command).expect("Failed to decode IPC command");
164
165                let IpcCommand::RunTest {
166                    name,
167                    crate_name,
168                    module_path,
169                } = command;
170                expected_test = Some((name, crate_name, module_path));
171            }
172        }
173
174        if let Some(next) = pick_next(&execution).await {
175            let skip = if let Some((name, crate_name, module_path)) = &expected_test {
176                next.test.name != *name
177                    || next.test.crate_name != *crate_name
178                    || next.test.module_path != *module_path
179            } else {
180                false
181            };
182
183            if !skip {
184                expected_test = None;
185
186                let ensure_time = get_ensure_time(&args, &next.test);
187
188                output.start_running_test(&next.test, next.index, count);
189                let result = run_test(
190                    output.clone(),
191                    next.index,
192                    count,
193                    args.nocapture,
194                    args.include_ignored,
195                    ensure_time,
196                    next.deps.clone(),
197                    &next.test,
198                    &mut worker,
199                )
200                .await;
201                output.finished_running_test(&next.test, next.index, count, &result);
202
203                if let Some(connection) = &mut connection {
204                    let finish_marker = Uuid::new_v4().to_string();
205                    let finish_marker_line = format!("{finish_marker}\n");
206                    tokio::io::stdout()
207                        .write_all(finish_marker_line.as_bytes())
208                        .await
209                        .unwrap();
210                    tokio::io::stderr()
211                        .write_all(finish_marker_line.as_bytes())
212                        .await
213                        .unwrap();
214                    tokio::io::stdout().flush().await.unwrap();
215                    tokio::io::stderr().flush().await.unwrap();
216
217                    let response = IpcResponse::TestFinished {
218                        result: (&result).into(),
219                        finish_marker,
220                    };
221                    let msg =
222                        serialize_to_byte_vec(&response).expect("Failed to encode IPC response");
223                    let message_size = (msg.len() as u16).to_le_bytes();
224                    connection
225                        .write_all(&message_size)
226                        .await
227                        .expect("Failed to write IPC response message size");
228                    connection
229                        .write_all(&msg)
230                        .await
231                        .expect("Failed to write response to IPC connection");
232                }
233
234                results.lock().await.push((next.test.clone(), result));
235            }
236        }
237    }
238}
239
240async fn is_done(execution: &Arc<Mutex<TestSuiteExecution>>) -> bool {
241    let execution = execution.lock().await;
242    execution.is_done()
243}
244
245async fn pick_next(execution: &Arc<Mutex<TestSuiteExecution>>) -> Option<TestExecution> {
246    let mut execution = execution.lock().await;
247    execution.pick_next().await
248}
249
250async fn run_with_flakiness_control<F>(
251    output: Arc<dyn TestRunnerOutput>,
252    test_description: &RegisteredTest,
253    idx: usize,
254    count: usize,
255    test: F,
256) -> Result<Result<(), FailureCause>, Box<dyn Any + Send>>
257where
258    F: Fn(
259            Instant,
260        )
261            -> Pin<Box<dyn Future<Output = Result<Result<(), FailureCause>, Box<dyn Any + Send>>>>>
262        + Send
263        + Sync,
264{
265    match &test_description.props.flakiness_control {
266        FlakinessControl::None => {
267            let start = Instant::now();
268            test(start).await
269        }
270        FlakinessControl::ProveNonFlaky(tries) => {
271            for n in 0..*tries {
272                if n > 0 {
273                    output.repeat_running_test(
274                        test_description,
275                        idx,
276                        count,
277                        n + 1,
278                        *tries,
279                        "to ensure test is not flaky",
280                    );
281                }
282                let start = Instant::now();
283                match test(start).await {
284                    Ok(Ok(())) => {}
285                    Ok(Err(e)) => return Ok(Err(e)),
286                    Err(e) => return Err(e),
287                };
288            }
289            Ok(Ok(()))
290        }
291        FlakinessControl::RetryKnownFlaky(max_retries) => {
292            let mut tries = 1;
293            loop {
294                let start = Instant::now();
295                let result = test(start).await;
296
297                if result.is_err() && tries < *max_retries {
298                    tries += 1;
299                    output.repeat_running_test(
300                        test_description,
301                        idx,
302                        count,
303                        tries,
304                        *max_retries,
305                        "because test is known to be flaky",
306                    );
307                } else {
308                    break result;
309                }
310            }
311        }
312    }
313}
314
315#[allow(clippy::too_many_arguments)]
316async fn run_test(
317    output: Arc<dyn TestRunnerOutput>,
318    idx: usize,
319    count: usize,
320    nocapture: bool,
321    include_ignored: bool,
322    ensure_time: Option<TimeThreshold>,
323    dependency_view: Arc<dyn internal::DependencyView + Send + Sync>,
324    test: &RegisteredTest,
325    worker: &mut Option<Worker>,
326) -> TestResult {
327    if test.props.is_ignored && !include_ignored {
328        TestResult::ignored()
329    } else if let Some(worker) = worker.as_mut() {
330        worker.run_test(nocapture, test).await
331    } else {
332        let start = Instant::now();
333        let test = test.clone();
334        match &test.run {
335            TestFunction::Sync(_) => {
336                let handle = spawn_blocking(move || {
337                    let test = test.clone();
338                    crate::sync::run_sync_test_function(
339                        output,
340                        &test,
341                        idx,
342                        count,
343                        ensure_time,
344                        dependency_view,
345                    )
346                });
347                handle.await.unwrap_or_else(|join_error| {
348                    TestResult::failed(
349                        start.elapsed(),
350                        FailureCause::HarnessError(format!(
351                            "Failed joining test task: {join_error}"
352                        )),
353                    )
354                })
355            }
356            TestFunction::Async(test_fn) => {
357                let timeout = test.props.timeout;
358                let test_fn = test_fn.clone();
359                let detached_panic_policy = test.props.detached_panic_policy.clone();
360                let result = run_with_flakiness_control(output, &test, idx, count, |start| {
361                    let dependency_view = dependency_view.clone();
362                    let test_fn = test_fn.clone();
363                    Box::pin(async move {
364                        let test_id = crate::panic_hook::next_test_id();
365                        crate::panic_hook::set_current_test_id(test_id);
366                        crate::panic_hook::create_detached_collector(test_id);
367                        let result = AssertUnwindSafe(Box::pin(async move {
368                            match timeout {
369                                None => test_fn(dependency_view).await,
370                                Some(duration) => {
371                                    let result =
372                                        tokio::time::timeout(duration, test_fn(dependency_view))
373                                            .await;
374                                    match result {
375                                        Ok(result) => result,
376                                        Err(_) => {
377                                            return Err(FailureCause::HarnessError(
378                                                "Test timed out".to_string(),
379                                            ))
380                                        }
381                                    }
382                                }
383                            }
384                            .into_result()?;
385                            if let Some(ensure_time) = ensure_time {
386                                let elapsed = start.elapsed();
387                                if ensure_time.is_critical(&elapsed) {
388                                    return Err(FailureCause::HarnessError(format!(
389                                        "Test run time exceeds critical threshold: {elapsed:?}"
390                                    )));
391                                }
392                            }
393                            Ok(())
394                        }))
395                        .catch_unwind()
396                        .await;
397                        result
398                    })
399                })
400                .await;
401                let mut test_result =
402                    TestResult::from_result(&test.props.should_panic, start.elapsed(), result);
403                if let Some(test_id) = crate::panic_hook::current_test_id() {
404                    if let Some(collector) = crate::panic_hook::take_detached_collector(test_id) {
405                        let panics = match collector.lock() {
406                            Ok(p) => p,
407                            Err(poisoned) => poisoned.into_inner(),
408                        };
409                        if !panics.is_empty()
410                            && detached_panic_policy == internal::DetachedPanicPolicy::FailTest
411                            && test_result.is_passed()
412                        {
413                            let messages: Vec<String> = panics.iter().map(|p| p.render()).collect();
414                            test_result = TestResult::failed(
415                                start.elapsed(),
416                                FailureCause::Panic(internal::PanicCause {
417                                    message: Some(format!(
418                                        "Detached task(s) panicked:\n{}",
419                                        messages.join("\n---\n")
420                                    )),
421                                    location: panics.first().and_then(|p| p.location.clone()),
422                                    backtrace: panics.first().and_then(|p| p.backtrace.clone()),
423                                }),
424                            );
425                        }
426                    }
427                }
428                crate::panic_hook::clear_current_test_id();
429                test_result
430            }
431            TestFunction::SyncBench(_) => {
432                let handle = spawn_blocking(move || {
433                    let test = test.clone();
434                    crate::sync::run_sync_test_function(
435                        output,
436                        &test,
437                        idx,
438                        count,
439                        ensure_time,
440                        dependency_view,
441                    )
442                });
443                handle.await.unwrap_or_else(|join_error| {
444                    TestResult::failed(
445                        start.elapsed(),
446                        FailureCause::HarnessError(format!(
447                            "Failed joining test task: {join_error}"
448                        )),
449                    )
450                })
451            }
452            TestFunction::AsyncBench(bench_fn) => {
453                let mut bencher = AsyncBencher::new();
454                let test_id = crate::panic_hook::next_test_id();
455                crate::panic_hook::set_current_test_id(test_id);
456                let result = AssertUnwindSafe(async move {
457                    bench_fn(&mut bencher, dependency_view).await;
458                    (
459                        bencher
460                            .summary()
461                            .expect("iter() was not called in bench function"),
462                        bencher.bytes,
463                    )
464                })
465                .catch_unwind()
466                .await;
467                let bytes = result.as_ref().map(|(_, bytes)| *bytes).unwrap_or_default();
468                let test_result = TestResult::from_summary(
469                    &test.props.should_panic,
470                    start.elapsed(),
471                    result.map(|(summary, _)| summary),
472                    bytes,
473                );
474                crate::panic_hook::clear_current_test_id();
475                test_result
476            }
477        }
478    }
479}
480
481struct Worker {
482    _listener: Listener,
483    _process: Child,
484    _out_handle: JoinHandle<()>,
485    _err_handle: JoinHandle<()>,
486    out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
487    err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
488    capture_enabled: Arc<Mutex<bool>>,
489    connection: Stream,
490}
491
492impl Worker {
493    pub async fn run_test(&mut self, nocapture: bool, test: &RegisteredTest) -> TestResult {
494        let mut capture_enabled = self.capture_enabled.lock().await;
495        *capture_enabled = test.props.capture_control.requires_capturing(!nocapture);
496        drop(capture_enabled);
497
498        // Send IPC command and wait for IPC response, and in the meantime read from the stdout/stderr channels
499        let cmd = IpcCommand::RunTest {
500            name: test.name.clone(),
501            crate_name: test.crate_name.clone(),
502            module_path: test.module_path.clone(),
503        };
504
505        let dump_on_ipc_failure = self.dump_on_failure();
506
507        let msg = serialize_to_byte_vec(&cmd).expect("Failed to encode IPC command");
508        let message_size = (msg.len() as u16).to_le_bytes();
509        dump_on_ipc_failure
510            .run(self.connection.write_all(&message_size).await)
511            .await;
512        dump_on_ipc_failure
513            .run(self.connection.write_all(&msg).await)
514            .await;
515
516        let mut response_size: [u8; 2] = [0, 0];
517        dump_on_ipc_failure
518            .run(self.connection.read_exact(&mut response_size).await)
519            .await;
520        let mut response = vec![0; u16::from_le_bytes(response_size) as usize];
521        dump_on_ipc_failure
522            .run(self.connection.read_exact(&mut response).await)
523            .await;
524        let response: IpcResponse = dump_on_ipc_failure.run(deserialize(&response)).await;
525
526        let IpcResponse::TestFinished {
527            result,
528            finish_marker,
529        } = response;
530
531        if test.props.capture_control.requires_capturing(!nocapture) {
532            let out_lines: Vec<_> =
533                Self::drain_until(self.out_lines.clone(), finish_marker.clone()).await;
534            let err_lines: Vec<_> =
535                Self::drain_until(self.err_lines.clone(), finish_marker.clone()).await;
536            result.into_test_result(out_lines, err_lines)
537        } else {
538            result.into_test_result(Vec::new(), Vec::new())
539        }
540    }
541
542    fn dump_on_failure(&self) -> DumpOnFailure {
543        DumpOnFailure {
544            out_lines: self.out_lines.clone(),
545            err_lines: self.err_lines.clone(),
546        }
547    }
548
549    async fn drain_until(
550        source: Arc<Mutex<VecDeque<CapturedOutput>>>,
551        finish_marker: String,
552    ) -> Vec<CapturedOutput> {
553        let mut result = Vec::new();
554        loop {
555            let mut source = source.lock().await;
556            while let Some(line) = source.pop_front() {
557                if line.line() == finish_marker {
558                    return result;
559                } else {
560                    result.push(line.clone());
561                }
562            }
563            drop(source);
564
565            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
566        }
567    }
568}
569
570struct DumpOnFailure {
571    out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
572    err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
573}
574
575impl DumpOnFailure {
576    pub async fn run<T, E>(&self, result: Result<T, E>) -> T {
577        match result {
578            Ok(value) => value,
579            Err(_error) => {
580                let out_lines: Vec<_> = self.out_lines.lock().await.drain(..).collect();
581                let err_lines: Vec<_> = self.err_lines.lock().await.drain(..).collect();
582                let mut all_lines = [out_lines, err_lines].concat();
583                all_lines.sort();
584
585                for line in all_lines {
586                    eprintln!("{}", line.line());
587                }
588
589                std::process::exit(1);
590            }
591        }
592    }
593}
594
595async fn spawn_worker_if_needed(args: &Arguments) -> Option<Worker> {
596    if args.spawn_workers {
597        let id = Uuid::new_v4();
598        let name_str = format!("{id}.sock");
599        let name = name_str
600            .clone()
601            .to_ns_name::<GenericNamespaced>()
602            .expect("Invalid local socket name");
603        let opts = ListenerOptions::new().name(name.clone());
604        let listener = opts
605            .create_tokio()
606            .expect("Failed to create local socket listener");
607
608        let exe = std::env::current_exe().expect("Failed to get current executable path");
609
610        let mut args = args.clone();
611        args.ipc = Some(name_str);
612        args.spawn_workers = false;
613        args.logfile = None;
614        let args = args.to_args();
615
616        let mut process = Command::new(exe)
617            .args(args)
618            .stdin(Stdio::inherit())
619            .stderr(Stdio::piped())
620            .stdout(Stdio::piped())
621            .spawn()
622            .expect("Failed to spawn worker process");
623
624        let stdout = process.stdout.take().unwrap();
625        let stderr = process.stderr.take().unwrap();
626
627        let out_lines = Arc::new(Mutex::new(VecDeque::new()));
628        let err_lines = Arc::new(Mutex::new(VecDeque::new()));
629        let capture_enabled = Arc::new(Mutex::new(true));
630
631        let out_lines_clone = out_lines.clone();
632        let capture_enabled_clone = capture_enabled.clone();
633        let out_handle = spawn(async move {
634            let reader = BufReader::new(stdout);
635            let mut lines = reader.lines();
636            while let Some(line) = lines
637                .next_line()
638                .await
639                .expect("Failed to read from worker stdout")
640            {
641                if *capture_enabled_clone.lock().await {
642                    out_lines_clone
643                        .lock()
644                        .await
645                        .push_back(CapturedOutput::stdout(line));
646                } else {
647                    println!("{line}");
648                }
649            }
650        });
651
652        let err_lines_clone = err_lines.clone();
653        let capture_enabled_clone = capture_enabled.clone();
654        let err_handle = spawn(async move {
655            let reader = BufReader::new(stderr);
656            let mut lines = reader.lines();
657            while let Some(line) = lines
658                .next_line()
659                .await
660                .expect("Failed to read from worker stderr")
661            {
662                if *capture_enabled_clone.lock().await {
663                    err_lines_clone
664                        .lock()
665                        .await
666                        .push_back(CapturedOutput::stderr(line));
667                } else {
668                    eprintln!("{line}");
669                }
670            }
671        });
672
673        let connection = listener
674            .accept()
675            .await
676            .expect("Failed to accept connection");
677
678        Some(Worker {
679            _listener: listener,
680            _process: process,
681            _out_handle: out_handle,
682            _err_handle: err_handle,
683            out_lines,
684            err_lines,
685            connection,
686            capture_enabled,
687        })
688    } else {
689        None
690    }
691}