1use crate::args::{Arguments, TimeThreshold};
2use crate::bench::AsyncBencher;
3use crate::execution::{TestExecution, TestSuiteExecution};
4use crate::internal;
5use crate::internal::{
6 generate_tests, get_ensure_time, CapturedOutput, FlakinessControl, RegisteredTest, SuiteResult,
7 TestFunction, TestResult,
8};
9use crate::ipc::{ipc_name, IpcCommand, IpcResponse};
10use crate::output::{test_runner_output, TestRunnerOutput};
11use bincode::{decode_from_slice, encode_to_vec};
12use futures::FutureExt;
13use interprocess::local_socket::tokio::prelude::*;
14use interprocess::local_socket::tokio::{Listener, Stream};
15use interprocess::local_socket::{GenericNamespaced, ListenerOptions};
16use std::collections::VecDeque;
17use std::future::Future;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::process::{ExitCode, Stdio};
21use std::sync::Arc;
22use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
23use tokio::process::{Child, Command};
24use tokio::spawn;
25use tokio::sync::Mutex;
26use tokio::task::{spawn_blocking, JoinHandle, JoinSet};
27use tokio::time::Instant;
28use uuid::Uuid;
29
30pub fn test_runner() -> ExitCode {
31 tokio::runtime::Builder::new_multi_thread()
32 .enable_all()
33 .build()
34 .unwrap()
35 .block_on(async_test_runner())
36}
37
38#[allow(clippy::await_holding_lock)]
39async fn async_test_runner() -> ExitCode {
40 let mut args = Arguments::from_args();
41 let output = test_runner_output(&args);
42
43 let registered_tests = internal::REGISTERED_TESTS.lock().unwrap();
44 let registered_dependency_constructors =
45 internal::REGISTERED_DEPENDENCY_CONSTRUCTORS.lock().unwrap();
46 let registered_testsuite_props = internal::REGISTERED_TESTSUITE_PROPS.lock().unwrap();
47 let registered_test_generators = internal::REGISTERED_TEST_GENERATORS.lock().unwrap();
48
49 let generated_tests = generate_tests(®istered_test_generators).await;
50
51 let all_tests: Vec<RegisteredTest> = registered_tests
52 .iter()
53 .cloned()
54 .chain(generated_tests)
55 .collect();
56
57 if args.list {
58 output.test_list(&all_tests);
59 ExitCode::SUCCESS
60 } else {
61 let mut remaining_retries = args.flaky_run.unwrap_or(1);
62
63 let mut exit_code = ExitCode::from(101);
64 while remaining_retries > 0 {
65 let (mut execution, filtered_tests) = TestSuiteExecution::construct(
66 &args,
67 registered_dependency_constructors.as_slice(),
68 &all_tests,
69 registered_testsuite_props.as_slice(),
70 );
71 args.finalize_for_execution(&execution, output.clone());
72 if args.spawn_workers {
73 execution.skip_creating_dependencies();
74 }
75
76 let count = execution.remaining();
81 let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
82
83 let start = Instant::now();
84 output.start_suite(&filtered_tests);
85
86 let execution = Arc::new(Mutex::new(execution));
87 let mut join_set = JoinSet::new();
88 let threads = args.test_threads().get();
89
90 for _ in 0..threads {
91 let execution_clone = execution.clone();
92 let output_clone = output.clone();
93 let args_clone = args.clone();
94 let results_clone = results.clone();
95 let handle = tokio::runtime::Handle::current();
96 join_set.spawn_blocking(move || {
97 handle.block_on(test_thread(
98 args_clone,
99 execution_clone,
100 output_clone,
101 count,
102 results_clone,
103 ))
104 });
105 }
106
107 while let Some(res) = join_set.join_next().await {
108 res.expect("Failed to join task");
109 }
110
111 drop(execution);
112
113 let results = results.lock().await;
114 output.finished_suite(&all_tests, &results, start.elapsed());
115 exit_code = SuiteResult::exit_code(&results);
116
117 if exit_code == ExitCode::SUCCESS {
118 break;
119 } else {
120 remaining_retries -= 1;
121 }
122 }
123 exit_code
124 }
125}
126
127async fn test_thread(
128 args: Arguments,
129 execution: Arc<Mutex<TestSuiteExecution>>,
130 output: Arc<dyn TestRunnerOutput>,
131 count: usize,
132 results: Arc<Mutex<Vec<(RegisteredTest, TestResult)>>>,
133) {
134 let mut worker = spawn_worker_if_needed(&args).await;
135 let mut connection = if let Some(ref name) = args.ipc {
136 let name = ipc_name(name.clone());
137 let stream = Stream::connect(name)
138 .await
139 .expect("Failed to connect to IPC socket");
140 Some(stream)
141 } else {
142 None
143 };
144
145 let mut expected_test = None;
146
147 while !is_done(&execution).await {
148 if let Some(connection) = &mut connection {
149 if expected_test.is_none() {
150 let mut command_size: [u8; 2] = [0, 0];
151 connection
152 .read_exact(&mut command_size)
153 .await
154 .expect("Failed to read IPC command size");
155 let mut command = vec![0; u16::from_le_bytes(command_size) as usize];
156 connection
157 .read_exact(&mut command)
158 .await
159 .expect("Failed to read IPC command");
160 let (command, _): (IpcCommand, usize) =
161 decode_from_slice(&command, bincode::config::standard())
162 .expect("Failed to decode IPC command");
163
164 let IpcCommand::RunTest {
165 name,
166 crate_name,
167 module_path,
168 } = command;
169 expected_test = Some((name, crate_name, module_path));
170 }
171 }
172
173 if let Some(next) = pick_next(&execution).await {
174 let skip = if let Some((name, crate_name, module_path)) = &expected_test {
175 next.test.name != *name
176 || next.test.crate_name != *crate_name
177 || next.test.module_path != *module_path
178 } else {
179 false
180 };
181
182 if !skip {
183 expected_test = None;
184
185 let ensure_time = get_ensure_time(&args, &next.test);
186
187 output.start_running_test(&next.test, next.index, count);
188 let result = run_test(
189 output.clone(),
190 next.index,
191 count,
192 args.nocapture,
193 args.include_ignored,
194 ensure_time,
195 next.deps.clone(),
196 &next.test,
197 &mut worker,
198 )
199 .await;
200 output.finished_running_test(&next.test, next.index, count, &result);
201
202 if let Some(connection) = &mut connection {
203 let finish_marker = Uuid::new_v4().to_string();
204 let finish_marker_line = format!("{finish_marker}\n");
205 tokio::io::stdout()
206 .write_all(finish_marker_line.as_bytes())
207 .await
208 .unwrap();
209 tokio::io::stderr()
210 .write_all(finish_marker_line.as_bytes())
211 .await
212 .unwrap();
213 tokio::io::stdout().flush().await.unwrap();
214 tokio::io::stderr().flush().await.unwrap();
215
216 let response = IpcResponse::TestFinished {
217 result: (&result).into(),
218 finish_marker,
219 };
220 let msg = encode_to_vec(&response, bincode::config::standard())
221 .expect("Failed to encode IPC response");
222 let message_size = (msg.len() as u16).to_le_bytes();
223 connection
224 .write_all(&message_size)
225 .await
226 .expect("Failed to write IPC response message size");
227 connection
228 .write_all(&msg)
229 .await
230 .expect("Failed to write response to IPC connection");
231 }
232
233 results.lock().await.push((next.test.clone(), result));
234 }
235 }
236 }
237}
238
239async fn is_done(execution: &Arc<Mutex<TestSuiteExecution>>) -> bool {
240 let execution = execution.lock().await;
241 execution.is_done()
242}
243
244async fn pick_next(execution: &Arc<Mutex<TestSuiteExecution>>) -> Option<TestExecution> {
245 let mut execution = execution.lock().await;
246 execution.pick_next().await
247}
248
249async fn run_with_flakiness_control<F, R>(
250 output: Arc<dyn TestRunnerOutput>,
251 test_description: &RegisteredTest,
252 idx: usize,
253 count: usize,
254 test: F,
255) -> Result<(), R>
256where
257 F: Fn(Instant) -> Pin<Box<dyn Future<Output = Result<(), R>>>> + Send + Sync,
258{
259 match &test_description.props.flakiness_control {
260 FlakinessControl::None => {
261 let start = Instant::now();
262 test(start).await
263 }
264 FlakinessControl::ProveNonFlaky(tries) => {
265 for n in 0..*tries {
266 if n > 0 {
267 output.repeat_running_test(
268 test_description,
269 idx,
270 count,
271 n + 1,
272 *tries,
273 "to ensure test is not flaky",
274 );
275 }
276 let start = Instant::now();
277 test(start).await?;
278 }
279 Ok(())
280 }
281 FlakinessControl::RetryKnownFlaky(max_retries) => {
282 let mut tries = 1;
283 loop {
284 let start = Instant::now();
285 let result = test(start).await;
286
287 if result.is_err() && tries < *max_retries {
288 tries += 1;
289 output.repeat_running_test(
290 test_description,
291 idx,
292 count,
293 tries,
294 *max_retries,
295 "because test is known to be flaky",
296 );
297 } else {
298 break result;
299 }
300 }
301 }
302 }
303}
304
305#[allow(clippy::too_many_arguments)]
306async fn run_test(
307 output: Arc<dyn TestRunnerOutput>,
308 idx: usize,
309 count: usize,
310 nocapture: bool,
311 include_ignored: bool,
312 ensure_time: Option<TimeThreshold>,
313 dependency_view: Arc<dyn internal::DependencyView + Send + Sync>,
314 test: &RegisteredTest,
315 worker: &mut Option<Worker>,
316) -> TestResult {
317 if test.props.is_ignored && !include_ignored {
318 TestResult::ignored()
319 } else if let Some(worker) = worker.as_mut() {
320 worker.run_test(nocapture, test).await
321 } else {
322 let start = Instant::now();
323 let test = test.clone();
324 match &test.run {
325 TestFunction::Sync(_) => {
326 let handle = spawn_blocking(move || {
327 let test = test.clone();
328 crate::sync::run_sync_test_function(
329 output,
330 &test,
331 idx,
332 count,
333 ensure_time,
334 dependency_view,
335 )
336 });
337 handle.await.unwrap_or_else(|join_error| {
338 TestResult::failed(start.elapsed(), Box::new(join_error))
339 })
340 }
341 TestFunction::Async(test_fn) => {
342 let timeout = test.props.timeout;
343 let test_fn = test_fn.clone();
344 let result = run_with_flakiness_control(output, &test, idx, count, |start| {
345 let dependency_view = dependency_view.clone();
346 let test_fn = test_fn.clone();
347 Box::pin(async move {
348 AssertUnwindSafe(Box::pin(async move {
349 let result = match timeout {
350 None => test_fn(dependency_view).await,
351 Some(duration) => {
352 let result =
353 tokio::time::timeout(duration, test_fn(dependency_view))
354 .await;
355 match result {
356 Ok(result) => result,
357 Err(_) => panic!("Test timed out"),
358 }
359 }
360 };
361 match result.as_result() {
362 Ok(_) => (),
363 Err(message) => panic!("{message}"),
364 };
365 if let Some(ensure_time) = ensure_time {
366 let elapsed = start.elapsed();
367 if ensure_time.is_critical(&elapsed) {
368 panic!("Test run time exceeds critical threshold: {elapsed:?}");
369 }
370 }
371 }))
372 .catch_unwind()
373 .await
374 })
375 })
376 .await;
377 TestResult::from_result(&test.props.should_panic, start.elapsed(), result)
378 }
379 TestFunction::SyncBench(_) => {
380 let handle = spawn_blocking(move || {
381 let test = test.clone();
382 crate::sync::run_sync_test_function(
383 output,
384 &test,
385 idx,
386 count,
387 ensure_time,
388 dependency_view,
389 )
390 });
391 handle.await.unwrap_or_else(|join_error| {
392 TestResult::failed(start.elapsed(), Box::new(join_error))
393 })
394 }
395 TestFunction::AsyncBench(bench_fn) => {
396 let mut bencher = AsyncBencher::new();
397 let result = AssertUnwindSafe(async move {
398 bench_fn(&mut bencher, dependency_view).await;
399 (
400 bencher
401 .summary()
402 .expect("iter() was not called in bench function"),
403 bencher.bytes,
404 )
405 })
406 .catch_unwind()
407 .await;
408 let bytes = result.as_ref().map(|(_, bytes)| *bytes).unwrap_or_default();
409 TestResult::from_summary(
410 &test.props.should_panic,
411 start.elapsed(),
412 result.map(|(summary, _)| summary),
413 bytes,
414 )
415 }
416 }
417 }
418}
419
420struct Worker {
421 _listener: Listener,
422 _process: Child,
423 _out_handle: JoinHandle<()>,
424 _err_handle: JoinHandle<()>,
425 out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
426 err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
427 capture_enabled: Arc<Mutex<bool>>,
428 connection: Stream,
429}
430
431impl Worker {
432 pub async fn run_test(&mut self, nocapture: bool, test: &RegisteredTest) -> TestResult {
433 let mut capture_enabled = self.capture_enabled.lock().await;
434 *capture_enabled = test.props.capture_control.requires_capturing(!nocapture);
435 drop(capture_enabled);
436
437 let cmd = IpcCommand::RunTest {
439 name: test.name.clone(),
440 crate_name: test.crate_name.clone(),
441 module_path: test.module_path.clone(),
442 };
443
444 let dump_on_ipc_failure = self.dump_on_failure();
445
446 let msg =
447 encode_to_vec(&cmd, bincode::config::standard()).expect("Failed to encode IPC command");
448 let message_size = (msg.len() as u16).to_le_bytes();
449 dump_on_ipc_failure
450 .run(self.connection.write_all(&message_size).await)
451 .await;
452 dump_on_ipc_failure
453 .run(self.connection.write_all(&msg).await)
454 .await;
455
456 let mut response_size: [u8; 2] = [0, 0];
457 dump_on_ipc_failure
458 .run(self.connection.read_exact(&mut response_size).await)
459 .await;
460 let mut response = vec![0; u16::from_le_bytes(response_size) as usize];
461 dump_on_ipc_failure
462 .run(self.connection.read_exact(&mut response).await)
463 .await;
464 let (response, _): (IpcResponse, usize) = dump_on_ipc_failure
465 .run(decode_from_slice(&response, bincode::config::standard()))
466 .await;
467
468 let IpcResponse::TestFinished {
469 result,
470 finish_marker,
471 } = response;
472
473 if test.props.capture_control.requires_capturing(!nocapture) {
474 let out_lines: Vec<_> =
475 Self::drain_until(self.out_lines.clone(), finish_marker.clone()).await;
476 let err_lines: Vec<_> =
477 Self::drain_until(self.err_lines.clone(), finish_marker.clone()).await;
478 result.into_test_result(out_lines, err_lines)
479 } else {
480 result.into_test_result(Vec::new(), Vec::new())
481 }
482 }
483
484 fn dump_on_failure(&self) -> DumpOnFailure {
485 DumpOnFailure {
486 out_lines: self.out_lines.clone(),
487 err_lines: self.err_lines.clone(),
488 }
489 }
490
491 async fn drain_until(
492 source: Arc<Mutex<VecDeque<CapturedOutput>>>,
493 finish_marker: String,
494 ) -> Vec<CapturedOutput> {
495 let mut result = Vec::new();
496 loop {
497 let mut source = source.lock().await;
498 while let Some(line) = source.pop_front() {
499 if line.line() == finish_marker {
500 return result;
501 } else {
502 result.push(line.clone());
503 }
504 }
505 drop(source);
506
507 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
508 }
509 }
510}
511
512struct DumpOnFailure {
513 out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
514 err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
515}
516
517impl DumpOnFailure {
518 pub async fn run<T, E>(&self, result: Result<T, E>) -> T {
519 match result {
520 Ok(value) => value,
521 Err(_error) => {
522 let out_lines: Vec<_> = self.out_lines.lock().await.drain(..).collect();
523 let err_lines: Vec<_> = self.err_lines.lock().await.drain(..).collect();
524 let mut all_lines = [out_lines, err_lines].concat();
525 all_lines.sort();
526
527 for line in all_lines {
528 eprintln!("{}", line.line());
529 }
530
531 std::process::exit(1);
532 }
533 }
534 }
535}
536
537async fn spawn_worker_if_needed(args: &Arguments) -> Option<Worker> {
538 if args.spawn_workers {
539 let id = Uuid::new_v4();
540 let name_str = format!("{id}.sock");
541 let name = name_str
542 .clone()
543 .to_ns_name::<GenericNamespaced>()
544 .expect("Invalid local socket name");
545 let opts = ListenerOptions::new().name(name.clone());
546 let listener = opts
547 .create_tokio()
548 .expect("Failed to create local socket listener");
549
550 let exe = std::env::current_exe().expect("Failed to get current executable path");
551
552 let mut args = args.clone();
553 args.ipc = Some(name_str);
554 args.spawn_workers = false;
555 args.logfile = None;
556 let args = args.to_args();
557
558 let mut process = Command::new(exe)
559 .args(args)
560 .stdin(Stdio::inherit())
561 .stderr(Stdio::piped())
562 .stdout(Stdio::piped())
563 .spawn()
564 .expect("Failed to spawn worker process");
565
566 let stdout = process.stdout.take().unwrap();
567 let stderr = process.stderr.take().unwrap();
568
569 let out_lines = Arc::new(Mutex::new(VecDeque::new()));
570 let err_lines = Arc::new(Mutex::new(VecDeque::new()));
571 let capture_enabled = Arc::new(Mutex::new(true));
572
573 let out_lines_clone = out_lines.clone();
574 let capture_enabled_clone = capture_enabled.clone();
575 let out_handle = spawn(async move {
576 let reader = BufReader::new(stdout);
577 let mut lines = reader.lines();
578 while let Some(line) = lines
579 .next_line()
580 .await
581 .expect("Failed to read from worker stdout")
582 {
583 if *capture_enabled_clone.lock().await {
584 out_lines_clone
585 .lock()
586 .await
587 .push_back(CapturedOutput::stdout(line));
588 } else {
589 println!("{line}");
590 }
591 }
592 });
593
594 let err_lines_clone = err_lines.clone();
595 let capture_enabled_clone = capture_enabled.clone();
596 let err_handle = spawn(async move {
597 let reader = BufReader::new(stderr);
598 let mut lines = reader.lines();
599 while let Some(line) = lines
600 .next_line()
601 .await
602 .expect("Failed to read from worker stderr")
603 {
604 if *capture_enabled_clone.lock().await {
605 err_lines_clone
606 .lock()
607 .await
608 .push_back(CapturedOutput::stderr(line));
609 } else {
610 eprintln!("{line}");
611 }
612 }
613 });
614
615 let connection = listener
616 .accept()
617 .await
618 .expect("Failed to accept connection");
619
620 Some(Worker {
621 _listener: listener,
622 _process: process,
623 _out_handle: out_handle,
624 _err_handle: err_handle,
625 out_lines,
626 err_lines,
627 connection,
628 capture_enabled,
629 })
630 } else {
631 None
632 }
633}