1use crate::args::{Arguments, TimeThreshold};
2use crate::bench::AsyncBencher;
3use crate::execution::{TestExecution, TestSuiteExecution};
4use crate::internal;
5use crate::internal::{
6 generate_tests, get_ensure_time, CapturedOutput, FailureCause, FlakinessControl,
7 RegisteredTest, SuiteResult, TestFunction, TestResult,
8};
9use crate::ipc::{ipc_name, IpcCommand, IpcResponse};
10use crate::output::{test_runner_output, TestRunnerOutput};
11use desert_rust::{deserialize, serialize_to_byte_vec};
12use futures::FutureExt;
13use interprocess::local_socket::tokio::prelude::*;
14use interprocess::local_socket::tokio::{Listener, Stream};
15use interprocess::local_socket::{GenericNamespaced, ListenerOptions};
16use std::any::Any;
17use std::collections::VecDeque;
18use std::future::Future;
19use std::panic::AssertUnwindSafe;
20use std::pin::Pin;
21use std::process::{ExitCode, Stdio};
22use std::sync::Arc;
23use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
24use tokio::process::{Child, Command};
25use tokio::spawn;
26use tokio::sync::Mutex;
27use tokio::task::{spawn_blocking, JoinHandle, JoinSet};
28use tokio::time::Instant;
29use uuid::Uuid;
30
31pub fn test_runner() -> ExitCode {
32 tokio::runtime::Builder::new_multi_thread()
33 .enable_all()
34 .build()
35 .unwrap()
36 .block_on(async_test_runner())
37}
38
39#[allow(clippy::await_holding_lock)]
40async fn async_test_runner() -> ExitCode {
41 crate::panic_hook::install_panic_hook();
42 let mut args = Arguments::from_args();
43 let output = test_runner_output(&args);
44
45 let registered_tests = internal::REGISTERED_TESTS.lock().unwrap();
46 let registered_dependency_constructors =
47 internal::REGISTERED_DEPENDENCY_CONSTRUCTORS.lock().unwrap();
48 let registered_testsuite_props = internal::REGISTERED_TESTSUITE_PROPS.lock().unwrap();
49 let registered_test_generators = internal::REGISTERED_TEST_GENERATORS.lock().unwrap();
50
51 let generated_tests = generate_tests(®istered_test_generators).await;
52
53 let all_tests: Vec<RegisteredTest> = registered_tests
54 .iter()
55 .cloned()
56 .chain(generated_tests)
57 .collect();
58
59 if args.list {
60 output.test_list(&all_tests);
61 ExitCode::SUCCESS
62 } else {
63 let mut remaining_retries = args.flaky_run.unwrap_or(1);
64
65 let mut exit_code = ExitCode::from(101);
66 while remaining_retries > 0 {
67 let (mut execution, filtered_tests) = TestSuiteExecution::construct(
68 &args,
69 registered_dependency_constructors.as_slice(),
70 &all_tests,
71 registered_testsuite_props.as_slice(),
72 );
73 args.finalize_for_execution(&execution, output.clone());
74 if args.spawn_workers {
75 execution.skip_creating_dependencies();
76 }
77
78 let count = execution.remaining();
83 let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
84
85 let start = Instant::now();
86 output.start_suite(&filtered_tests);
87
88 let execution = Arc::new(Mutex::new(execution));
89 let mut join_set = JoinSet::new();
90 let threads = args.test_threads().get();
91
92 for _ in 0..threads {
93 let execution_clone = execution.clone();
94 let output_clone = output.clone();
95 let args_clone = args.clone();
96 let results_clone = results.clone();
97 let handle = tokio::runtime::Handle::current();
98 join_set.spawn_blocking(move || {
99 handle.block_on(test_thread(
100 args_clone,
101 execution_clone,
102 output_clone,
103 count,
104 results_clone,
105 ))
106 });
107 }
108
109 while let Some(res) = join_set.join_next().await {
110 res.expect("Failed to join task");
111 }
112
113 drop(execution);
114
115 let results = results.lock().await;
116 output.finished_suite(&all_tests, &results, start.elapsed());
117 exit_code = SuiteResult::exit_code(&results);
118
119 if exit_code == ExitCode::SUCCESS {
120 break;
121 } else {
122 remaining_retries -= 1;
123 }
124 }
125 exit_code
126 }
127}
128
129async fn test_thread(
130 args: Arguments,
131 execution: Arc<Mutex<TestSuiteExecution>>,
132 output: Arc<dyn TestRunnerOutput>,
133 count: usize,
134 results: Arc<Mutex<Vec<(RegisteredTest, TestResult)>>>,
135) {
136 let mut worker = spawn_worker_if_needed(&args).await;
137 let mut connection = if let Some(ref name) = args.ipc {
138 let name = ipc_name(name.clone());
139 let stream = Stream::connect(name)
140 .await
141 .expect("Failed to connect to IPC socket");
142 Some(stream)
143 } else {
144 None
145 };
146
147 let mut expected_test = None;
148
149 while !is_done(&execution).await {
150 if let Some(connection) = &mut connection {
151 if expected_test.is_none() {
152 let mut command_size: [u8; 2] = [0, 0];
153 connection
154 .read_exact(&mut command_size)
155 .await
156 .expect("Failed to read IPC command size");
157 let mut command = vec![0; u16::from_le_bytes(command_size) as usize];
158 connection
159 .read_exact(&mut command)
160 .await
161 .expect("Failed to read IPC command");
162 let command: IpcCommand =
163 deserialize(&command).expect("Failed to decode IPC command");
164
165 let IpcCommand::RunTest {
166 name,
167 crate_name,
168 module_path,
169 } = command;
170 expected_test = Some((name, crate_name, module_path));
171 }
172 }
173
174 if let Some(next) = pick_next(&execution).await {
175 let skip = if let Some((name, crate_name, module_path)) = &expected_test {
176 next.test.name != *name
177 || next.test.crate_name != *crate_name
178 || next.test.module_path != *module_path
179 } else {
180 false
181 };
182
183 if !skip {
184 expected_test = None;
185
186 let ensure_time = get_ensure_time(&args, &next.test);
187
188 output.start_running_test(&next.test, next.index, count);
189 let result = run_test(
190 output.clone(),
191 next.index,
192 count,
193 args.nocapture,
194 args.include_ignored,
195 ensure_time,
196 next.deps.clone(),
197 &next.test,
198 &mut worker,
199 )
200 .await;
201 output.finished_running_test(&next.test, next.index, count, &result);
202
203 if let Some(connection) = &mut connection {
204 let finish_marker = Uuid::new_v4().to_string();
205 let finish_marker_line = format!("{finish_marker}\n");
206 tokio::io::stdout()
207 .write_all(finish_marker_line.as_bytes())
208 .await
209 .unwrap();
210 tokio::io::stderr()
211 .write_all(finish_marker_line.as_bytes())
212 .await
213 .unwrap();
214 tokio::io::stdout().flush().await.unwrap();
215 tokio::io::stderr().flush().await.unwrap();
216
217 let response = IpcResponse::TestFinished {
218 result: (&result).into(),
219 finish_marker,
220 };
221 let msg =
222 serialize_to_byte_vec(&response).expect("Failed to encode IPC response");
223 let message_size = (msg.len() as u16).to_le_bytes();
224 connection
225 .write_all(&message_size)
226 .await
227 .expect("Failed to write IPC response message size");
228 connection
229 .write_all(&msg)
230 .await
231 .expect("Failed to write response to IPC connection");
232 }
233
234 results.lock().await.push((next.test.clone(), result));
235 }
236 }
237 }
238}
239
240async fn is_done(execution: &Arc<Mutex<TestSuiteExecution>>) -> bool {
241 let execution = execution.lock().await;
242 execution.is_done()
243}
244
245async fn pick_next(execution: &Arc<Mutex<TestSuiteExecution>>) -> Option<TestExecution> {
246 let mut execution = execution.lock().await;
247 execution.pick_next().await
248}
249
250async fn run_with_flakiness_control<F>(
251 output: Arc<dyn TestRunnerOutput>,
252 test_description: &RegisteredTest,
253 idx: usize,
254 count: usize,
255 test: F,
256) -> Result<Result<(), FailureCause>, Box<dyn Any + Send>>
257where
258 F: Fn(
259 Instant,
260 )
261 -> Pin<Box<dyn Future<Output = Result<Result<(), FailureCause>, Box<dyn Any + Send>>>>>
262 + Send
263 + Sync,
264{
265 match &test_description.props.flakiness_control {
266 FlakinessControl::None => {
267 let start = Instant::now();
268 test(start).await
269 }
270 FlakinessControl::ProveNonFlaky(tries) => {
271 for n in 0..*tries {
272 if n > 0 {
273 output.repeat_running_test(
274 test_description,
275 idx,
276 count,
277 n + 1,
278 *tries,
279 "to ensure test is not flaky",
280 );
281 }
282 let start = Instant::now();
283 match test(start).await {
284 Ok(Ok(())) => {}
285 Ok(Err(e)) => return Ok(Err(e)),
286 Err(e) => return Err(e),
287 };
288 }
289 Ok(Ok(()))
290 }
291 FlakinessControl::RetryKnownFlaky(max_retries) => {
292 let mut tries = 1;
293 loop {
294 let start = Instant::now();
295 let result = test(start).await;
296
297 if result.is_err() && tries < *max_retries {
298 tries += 1;
299 output.repeat_running_test(
300 test_description,
301 idx,
302 count,
303 tries,
304 *max_retries,
305 "because test is known to be flaky",
306 );
307 } else {
308 break result;
309 }
310 }
311 }
312 }
313}
314
315#[allow(clippy::too_many_arguments)]
316async fn run_test(
317 output: Arc<dyn TestRunnerOutput>,
318 idx: usize,
319 count: usize,
320 nocapture: bool,
321 include_ignored: bool,
322 ensure_time: Option<TimeThreshold>,
323 dependency_view: Arc<dyn internal::DependencyView + Send + Sync>,
324 test: &RegisteredTest,
325 worker: &mut Option<Worker>,
326) -> TestResult {
327 if test.props.is_ignored && !include_ignored {
328 TestResult::ignored()
329 } else if let Some(worker) = worker.as_mut() {
330 worker.run_test(nocapture, test).await
331 } else {
332 let start = Instant::now();
333 let test = test.clone();
334 match &test.run {
335 TestFunction::Sync(_) => {
336 let handle = spawn_blocking(move || {
337 let test = test.clone();
338 crate::sync::run_sync_test_function(
339 output,
340 &test,
341 idx,
342 count,
343 ensure_time,
344 dependency_view,
345 )
346 });
347 handle.await.unwrap_or_else(|join_error| {
348 TestResult::failed(
349 start.elapsed(),
350 FailureCause::HarnessError(format!(
351 "Failed joining test task: {join_error}"
352 )),
353 )
354 })
355 }
356 TestFunction::Async(test_fn) => {
357 let timeout = test.props.timeout;
358 let test_fn = test_fn.clone();
359 let detached_panic_policy = test.props.detached_panic_policy.clone();
360 let result = run_with_flakiness_control(output, &test, idx, count, |start| {
361 let dependency_view = dependency_view.clone();
362 let test_fn = test_fn.clone();
363 Box::pin(async move {
364 let test_id = crate::panic_hook::next_test_id();
365 crate::panic_hook::set_current_test_id(test_id);
366 crate::panic_hook::create_detached_collector(test_id);
367 let result = AssertUnwindSafe(Box::pin(async move {
368 match timeout {
369 None => test_fn(dependency_view).await,
370 Some(duration) => {
371 let result =
372 tokio::time::timeout(duration, test_fn(dependency_view))
373 .await;
374 match result {
375 Ok(result) => result,
376 Err(_) => {
377 return Err(FailureCause::HarnessError(
378 "Test timed out".to_string(),
379 ))
380 }
381 }
382 }
383 }
384 .into_result()?;
385 if let Some(ensure_time) = ensure_time {
386 let elapsed = start.elapsed();
387 if ensure_time.is_critical(&elapsed) {
388 return Err(FailureCause::HarnessError(format!(
389 "Test run time exceeds critical threshold: {elapsed:?}"
390 )));
391 }
392 }
393 Ok(())
394 }))
395 .catch_unwind()
396 .await;
397 result
398 })
399 })
400 .await;
401 let mut test_result =
402 TestResult::from_result(&test.props.should_panic, start.elapsed(), result);
403 if let Some(test_id) = crate::panic_hook::current_test_id() {
404 if let Some(collector) = crate::panic_hook::take_detached_collector(test_id) {
405 let panics = match collector.lock() {
406 Ok(p) => p,
407 Err(poisoned) => poisoned.into_inner(),
408 };
409 if !panics.is_empty()
410 && detached_panic_policy == internal::DetachedPanicPolicy::FailTest
411 && test_result.is_passed()
412 {
413 let messages: Vec<String> = panics.iter().map(|p| p.render()).collect();
414 test_result = TestResult::failed(
415 start.elapsed(),
416 FailureCause::Panic(internal::PanicCause {
417 message: Some(format!(
418 "Detached task(s) panicked:\n{}",
419 messages.join("\n---\n")
420 )),
421 location: panics.first().and_then(|p| p.location.clone()),
422 backtrace: panics.first().and_then(|p| p.backtrace.clone()),
423 }),
424 );
425 }
426 }
427 }
428 crate::panic_hook::clear_current_test_id();
429 test_result
430 }
431 TestFunction::SyncBench(_) => {
432 let handle = spawn_blocking(move || {
433 let test = test.clone();
434 crate::sync::run_sync_test_function(
435 output,
436 &test,
437 idx,
438 count,
439 ensure_time,
440 dependency_view,
441 )
442 });
443 handle.await.unwrap_or_else(|join_error| {
444 TestResult::failed(
445 start.elapsed(),
446 FailureCause::HarnessError(format!(
447 "Failed joining test task: {join_error}"
448 )),
449 )
450 })
451 }
452 TestFunction::AsyncBench(bench_fn) => {
453 let mut bencher = AsyncBencher::new();
454 let test_id = crate::panic_hook::next_test_id();
455 crate::panic_hook::set_current_test_id(test_id);
456 let result = AssertUnwindSafe(async move {
457 bench_fn(&mut bencher, dependency_view).await;
458 (
459 bencher
460 .summary()
461 .expect("iter() was not called in bench function"),
462 bencher.bytes,
463 )
464 })
465 .catch_unwind()
466 .await;
467 let bytes = result.as_ref().map(|(_, bytes)| *bytes).unwrap_or_default();
468 let test_result = TestResult::from_summary(
469 &test.props.should_panic,
470 start.elapsed(),
471 result.map(|(summary, _)| summary),
472 bytes,
473 );
474 crate::panic_hook::clear_current_test_id();
475 test_result
476 }
477 }
478 }
479}
480
/// A spawned worker process that runs tests out-of-process, plus the plumbing to talk to it.
// NOTE(review): struct fields drop in declaration order (listener and child handle before
// the IPC `connection`); keep this order unless that is known not to matter.
struct Worker {
    /// Local-socket listener the worker connected to; owned so it lives as long as the worker.
    _listener: Listener,
    /// Handle to the spawned worker child process.
    _process: Child,
    /// Task reading the worker's stdout line by line.
    _out_handle: JoinHandle<()>,
    /// Task reading the worker's stderr line by line.
    _err_handle: JoinHandle<()>,
    /// Captured stdout lines waiting to be attached to a test result.
    out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
    /// Captured stderr lines waiting to be attached to a test result.
    err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
    /// When `true`, the reader tasks queue lines; when `false`, they print them directly.
    capture_enabled: Arc<Mutex<bool>>,
    /// IPC connection accepted from the worker, carrying commands and responses.
    connection: Stream,
}
491
492impl Worker {
493 pub async fn run_test(&mut self, nocapture: bool, test: &RegisteredTest) -> TestResult {
494 let mut capture_enabled = self.capture_enabled.lock().await;
495 *capture_enabled = test.props.capture_control.requires_capturing(!nocapture);
496 drop(capture_enabled);
497
498 let cmd = IpcCommand::RunTest {
500 name: test.name.clone(),
501 crate_name: test.crate_name.clone(),
502 module_path: test.module_path.clone(),
503 };
504
505 let dump_on_ipc_failure = self.dump_on_failure();
506
507 let msg = serialize_to_byte_vec(&cmd).expect("Failed to encode IPC command");
508 let message_size = (msg.len() as u16).to_le_bytes();
509 dump_on_ipc_failure
510 .run(self.connection.write_all(&message_size).await)
511 .await;
512 dump_on_ipc_failure
513 .run(self.connection.write_all(&msg).await)
514 .await;
515
516 let mut response_size: [u8; 2] = [0, 0];
517 dump_on_ipc_failure
518 .run(self.connection.read_exact(&mut response_size).await)
519 .await;
520 let mut response = vec![0; u16::from_le_bytes(response_size) as usize];
521 dump_on_ipc_failure
522 .run(self.connection.read_exact(&mut response).await)
523 .await;
524 let response: IpcResponse = dump_on_ipc_failure.run(deserialize(&response)).await;
525
526 let IpcResponse::TestFinished {
527 result,
528 finish_marker,
529 } = response;
530
531 if test.props.capture_control.requires_capturing(!nocapture) {
532 let out_lines: Vec<_> =
533 Self::drain_until(self.out_lines.clone(), finish_marker.clone()).await;
534 let err_lines: Vec<_> =
535 Self::drain_until(self.err_lines.clone(), finish_marker.clone()).await;
536 result.into_test_result(out_lines, err_lines)
537 } else {
538 result.into_test_result(Vec::new(), Vec::new())
539 }
540 }
541
542 fn dump_on_failure(&self) -> DumpOnFailure {
543 DumpOnFailure {
544 out_lines: self.out_lines.clone(),
545 err_lines: self.err_lines.clone(),
546 }
547 }
548
549 async fn drain_until(
550 source: Arc<Mutex<VecDeque<CapturedOutput>>>,
551 finish_marker: String,
552 ) -> Vec<CapturedOutput> {
553 let mut result = Vec::new();
554 loop {
555 let mut source = source.lock().await;
556 while let Some(line) = source.pop_front() {
557 if line.line() == finish_marker {
558 return result;
559 } else {
560 result.push(line.clone());
561 }
562 }
563 drop(source);
564
565 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
566 }
567 }
568}
569
570struct DumpOnFailure {
571 out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
572 err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
573}
574
575impl DumpOnFailure {
576 pub async fn run<T, E>(&self, result: Result<T, E>) -> T {
577 match result {
578 Ok(value) => value,
579 Err(_error) => {
580 let out_lines: Vec<_> = self.out_lines.lock().await.drain(..).collect();
581 let err_lines: Vec<_> = self.err_lines.lock().await.drain(..).collect();
582 let mut all_lines = [out_lines, err_lines].concat();
583 all_lines.sort();
584
585 for line in all_lines {
586 eprintln!("{}", line.line());
587 }
588
589 std::process::exit(1);
590 }
591 }
592 }
593}
594
595async fn spawn_worker_if_needed(args: &Arguments) -> Option<Worker> {
596 if args.spawn_workers {
597 let id = Uuid::new_v4();
598 let name_str = format!("{id}.sock");
599 let name = name_str
600 .clone()
601 .to_ns_name::<GenericNamespaced>()
602 .expect("Invalid local socket name");
603 let opts = ListenerOptions::new().name(name.clone());
604 let listener = opts
605 .create_tokio()
606 .expect("Failed to create local socket listener");
607
608 let exe = std::env::current_exe().expect("Failed to get current executable path");
609
610 let mut args = args.clone();
611 args.ipc = Some(name_str);
612 args.spawn_workers = false;
613 args.logfile = None;
614 let args = args.to_args();
615
616 let mut process = Command::new(exe)
617 .args(args)
618 .stdin(Stdio::inherit())
619 .stderr(Stdio::piped())
620 .stdout(Stdio::piped())
621 .spawn()
622 .expect("Failed to spawn worker process");
623
624 let stdout = process.stdout.take().unwrap();
625 let stderr = process.stderr.take().unwrap();
626
627 let out_lines = Arc::new(Mutex::new(VecDeque::new()));
628 let err_lines = Arc::new(Mutex::new(VecDeque::new()));
629 let capture_enabled = Arc::new(Mutex::new(true));
630
631 let out_lines_clone = out_lines.clone();
632 let capture_enabled_clone = capture_enabled.clone();
633 let out_handle = spawn(async move {
634 let reader = BufReader::new(stdout);
635 let mut lines = reader.lines();
636 while let Some(line) = lines
637 .next_line()
638 .await
639 .expect("Failed to read from worker stdout")
640 {
641 if *capture_enabled_clone.lock().await {
642 out_lines_clone
643 .lock()
644 .await
645 .push_back(CapturedOutput::stdout(line));
646 } else {
647 println!("{line}");
648 }
649 }
650 });
651
652 let err_lines_clone = err_lines.clone();
653 let capture_enabled_clone = capture_enabled.clone();
654 let err_handle = spawn(async move {
655 let reader = BufReader::new(stderr);
656 let mut lines = reader.lines();
657 while let Some(line) = lines
658 .next_line()
659 .await
660 .expect("Failed to read from worker stderr")
661 {
662 if *capture_enabled_clone.lock().await {
663 err_lines_clone
664 .lock()
665 .await
666 .push_back(CapturedOutput::stderr(line));
667 } else {
668 eprintln!("{line}");
669 }
670 }
671 });
672
673 let connection = listener
674 .accept()
675 .await
676 .expect("Failed to accept connection");
677
678 Some(Worker {
679 _listener: listener,
680 _process: process,
681 _out_handle: out_handle,
682 _err_handle: err_handle,
683 out_lines,
684 err_lines,
685 connection,
686 capture_enabled,
687 })
688 } else {
689 None
690 }
691}