1use crate::args::{Arguments, TimeThreshold};
2use crate::bench::AsyncBencher;
3use crate::execution::{TestExecution, TestSuiteExecution};
4use crate::internal;
5use crate::internal::{
6 generate_tests, get_ensure_time, CapturedOutput, FlakinessControl, RegisteredTest, SuiteResult,
7 TestFunction, TestResult,
8};
9use crate::ipc::{ipc_name, IpcCommand, IpcResponse};
10use crate::output::{test_runner_output, TestRunnerOutput};
11use bincode::{decode_from_slice, encode_to_vec};
12use futures::FutureExt;
13use interprocess::local_socket::tokio::prelude::*;
14use interprocess::local_socket::tokio::{Listener, Stream};
15use interprocess::local_socket::{GenericNamespaced, ListenerOptions};
16use std::collections::VecDeque;
17use std::future::Future;
18use std::panic::AssertUnwindSafe;
19use std::pin::Pin;
20use std::process::{ExitCode, Stdio};
21use std::sync::Arc;
22use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
23use tokio::process::{Child, Command};
24use tokio::spawn;
25use tokio::sync::Mutex;
26use tokio::task::{spawn_blocking, JoinHandle, JoinSet};
27use tokio::time::Instant;
28use uuid::Uuid;
29
30pub fn test_runner() -> ExitCode {
31 tokio::runtime::Builder::new_multi_thread()
32 .enable_all()
33 .build()
34 .unwrap()
35 .block_on(async_test_runner())
36}
37
38#[allow(clippy::await_holding_lock)]
39async fn async_test_runner() -> ExitCode {
40 let mut args = Arguments::from_args();
41 let output = test_runner_output(&args);
42
43 let registered_tests = internal::REGISTERED_TESTS.lock().unwrap();
44 let registered_dependency_constructors =
45 internal::REGISTERED_DEPENDENCY_CONSTRUCTORS.lock().unwrap();
46 let registered_testsuite_props = internal::REGISTERED_TESTSUITE_PROPS.lock().unwrap();
47 let registered_test_generators = internal::REGISTERED_TEST_GENERATORS.lock().unwrap();
48
49 let generated_tests = generate_tests(®istered_test_generators).await;
50
51 let all_tests: Vec<RegisteredTest> = registered_tests
52 .iter()
53 .cloned()
54 .chain(generated_tests)
55 .collect();
56
57 if args.list {
58 output.test_list(&all_tests);
59 ExitCode::SUCCESS
60 } else {
61 let (mut execution, filtered_tests) = TestSuiteExecution::construct(
62 &args,
63 registered_dependency_constructors.as_slice(),
64 &all_tests,
65 registered_testsuite_props.as_slice(),
66 );
67 args.finalize_for_execution(&execution, output.clone());
68 if args.spawn_workers {
69 execution.skip_creating_dependencies();
70 }
71
72 let count = execution.remaining();
77 let results = Arc::new(Mutex::new(Vec::with_capacity(count)));
78
79 let start = Instant::now();
80 output.start_suite(&filtered_tests);
81
82 let execution = Arc::new(Mutex::new(execution));
83 let mut join_set = JoinSet::new();
84 let threads = args.test_threads().get();
85
86 for _ in 0..threads {
87 let execution_clone = execution.clone();
88 let output_clone = output.clone();
89 let args_clone = args.clone();
90 let results_clone = results.clone();
91 let handle = tokio::runtime::Handle::current();
92 join_set.spawn_blocking(move || {
93 handle.block_on(test_thread(
94 args_clone,
95 execution_clone,
96 output_clone,
97 count,
98 results_clone,
99 ))
100 });
101 }
102
103 while let Some(res) = join_set.join_next().await {
104 res.expect("Failed to join task");
105 }
106
107 let results = results.lock().await;
108 output.finished_suite(&all_tests, &results, start.elapsed());
109 SuiteResult::exit_code(&results)
110 }
111}
112
113async fn test_thread(
114 args: Arguments,
115 execution: Arc<Mutex<TestSuiteExecution>>,
116 output: Arc<dyn TestRunnerOutput>,
117 count: usize,
118 results: Arc<Mutex<Vec<(RegisteredTest, TestResult)>>>,
119) {
120 let mut worker = spawn_worker_if_needed(&args).await;
121 let mut connection = if let Some(ref name) = args.ipc {
122 let name = ipc_name(name.clone());
123 let stream = Stream::connect(name)
124 .await
125 .expect("Failed to connect to IPC socket");
126 Some(stream)
127 } else {
128 None
129 };
130
131 let mut expected_test = None;
132
133 while !is_done(&execution).await {
134 if let Some(connection) = &mut connection {
135 if expected_test.is_none() {
136 let mut command_size: [u8; 2] = [0, 0];
137 connection
138 .read_exact(&mut command_size)
139 .await
140 .expect("Failed to read IPC command size");
141 let mut command = vec![0; u16::from_le_bytes(command_size) as usize];
142 connection
143 .read_exact(&mut command)
144 .await
145 .expect("Failed to read IPC command");
146 let (command, _): (IpcCommand, usize) =
147 decode_from_slice(&command, bincode::config::standard())
148 .expect("Failed to decode IPC command");
149
150 let IpcCommand::RunTest {
151 name,
152 crate_name,
153 module_path,
154 } = command;
155 expected_test = Some((name, crate_name, module_path));
156 }
157 }
158
159 if let Some(next) = pick_next(&execution).await {
160 let skip = if let Some((name, crate_name, module_path)) = &expected_test {
161 next.test.name != *name
162 || next.test.crate_name != *crate_name
163 || next.test.module_path != *module_path
164 } else {
165 false
166 };
167
168 if !skip {
169 expected_test = None;
170
171 let ensure_time = get_ensure_time(&args, &next.test);
172
173 output.start_running_test(&next.test, next.index, count);
174 let result = run_test(
175 output.clone(),
176 next.index,
177 count,
178 args.nocapture,
179 args.include_ignored,
180 ensure_time,
181 next.deps.clone(),
182 &next.test,
183 &mut worker,
184 )
185 .await;
186 output.finished_running_test(&next.test, next.index, count, &result);
187
188 if let Some(connection) = &mut connection {
189 let finish_marker = Uuid::new_v4().to_string();
190 let finish_marker_line = format!("{finish_marker}\n");
191 tokio::io::stdout()
192 .write_all(finish_marker_line.as_bytes())
193 .await
194 .unwrap();
195 tokio::io::stderr()
196 .write_all(finish_marker_line.as_bytes())
197 .await
198 .unwrap();
199 tokio::io::stdout().flush().await.unwrap();
200 tokio::io::stderr().flush().await.unwrap();
201
202 let response = IpcResponse::TestFinished {
203 result: (&result).into(),
204 finish_marker,
205 };
206 let msg = encode_to_vec(&response, bincode::config::standard())
207 .expect("Failed to encode IPC response");
208 let message_size = (msg.len() as u16).to_le_bytes();
209 connection
210 .write_all(&message_size)
211 .await
212 .expect("Failed to write IPC response message size");
213 connection
214 .write_all(&msg)
215 .await
216 .expect("Failed to write response to IPC connection");
217 }
218
219 results.lock().await.push((next.test.clone(), result));
220 }
221 }
222 }
223}
224
225async fn is_done(execution: &Arc<Mutex<TestSuiteExecution>>) -> bool {
226 let execution = execution.lock().await;
227 execution.is_done()
228}
229
230async fn pick_next(execution: &Arc<Mutex<TestSuiteExecution>>) -> Option<TestExecution> {
231 let mut execution = execution.lock().await;
232 execution.pick_next().await
233}
234
235async fn run_with_flakiness_control<F, R>(
236 output: Arc<dyn TestRunnerOutput>,
237 test_description: &RegisteredTest,
238 idx: usize,
239 count: usize,
240 test: F,
241) -> Result<(), R>
242where
243 F: Fn(Instant) -> Pin<Box<dyn Future<Output = Result<(), R>>>> + Send + Sync,
244{
245 match &test_description.props.flakiness_control {
246 FlakinessControl::None => {
247 let start = Instant::now();
248 test(start).await
249 }
250 FlakinessControl::ProveNonFlaky(tries) => {
251 for n in 0..*tries {
252 if n > 0 {
253 output.repeat_running_test(
254 test_description,
255 idx,
256 count,
257 n + 1,
258 *tries,
259 "to ensure test is not flaky",
260 );
261 }
262 let start = Instant::now();
263 test(start).await?;
264 }
265 Ok(())
266 }
267 FlakinessControl::RetryKnownFlaky(max_retries) => {
268 let mut tries = 1;
269 loop {
270 let start = Instant::now();
271 let result = test(start).await;
272
273 if result.is_err() && tries < *max_retries {
274 tries += 1;
275 output.repeat_running_test(
276 test_description,
277 idx,
278 count,
279 tries,
280 *max_retries,
281 "because test is known to be flaky",
282 );
283 } else {
284 break result;
285 }
286 }
287 }
288 }
289}
290
291#[allow(clippy::too_many_arguments)]
292async fn run_test(
293 output: Arc<dyn TestRunnerOutput>,
294 idx: usize,
295 count: usize,
296 nocapture: bool,
297 include_ignored: bool,
298 ensure_time: Option<TimeThreshold>,
299 dependency_view: Arc<dyn internal::DependencyView + Send + Sync>,
300 test: &RegisteredTest,
301 worker: &mut Option<Worker>,
302) -> TestResult {
303 if test.is_ignored && !include_ignored {
304 TestResult::ignored()
305 } else if let Some(worker) = worker.as_mut() {
306 worker.run_test(nocapture, test).await
307 } else {
308 let start = Instant::now();
309 let test = test.clone();
310 match &test.run {
311 TestFunction::Sync(_) => {
312 let handle = spawn_blocking(move || {
313 let test = test.clone();
314 crate::sync::run_sync_test_function(
315 output,
316 &test,
317 idx,
318 count,
319 ensure_time,
320 dependency_view,
321 )
322 });
323 handle.await.unwrap_or_else(|join_error| {
324 TestResult::failed(start.elapsed(), Box::new(join_error))
325 })
326 }
327 TestFunction::Async(test_fn) => {
328 let timeout = test.props.timeout;
329 let test_fn = test_fn.clone();
330 let result = run_with_flakiness_control(output, &test, idx, count, |start| {
331 let dependency_view = dependency_view.clone();
332 let test_fn = test_fn.clone();
333 Box::pin(async move {
334 AssertUnwindSafe(Box::pin(async move {
335 let result = match timeout {
336 None => test_fn(dependency_view).await,
337 Some(duration) => {
338 let result =
339 tokio::time::timeout(duration, test_fn(dependency_view))
340 .await;
341 match result {
342 Ok(result) => result,
343 Err(_) => panic!("Test timed out"),
344 }
345 }
346 };
347 match result.as_result() {
348 Ok(_) => (),
349 Err(message) => panic!("{message}"),
350 };
351 if let Some(ensure_time) = ensure_time {
352 let elapsed = start.elapsed();
353 if ensure_time.is_critical(&elapsed) {
354 panic!(
355 "Test run time exceeds critical threshold: {:?}",
356 elapsed
357 );
358 }
359 }
360 }))
361 .catch_unwind()
362 .await
363 })
364 })
365 .await;
366 TestResult::from_result(&test.props.should_panic, start.elapsed(), result)
367 }
368 TestFunction::SyncBench(_) => {
369 let handle = spawn_blocking(move || {
370 let test = test.clone();
371 crate::sync::run_sync_test_function(
372 output,
373 &test,
374 idx,
375 count,
376 ensure_time,
377 dependency_view,
378 )
379 });
380 handle.await.unwrap_or_else(|join_error| {
381 TestResult::failed(start.elapsed(), Box::new(join_error))
382 })
383 }
384 TestFunction::AsyncBench(bench_fn) => {
385 let mut bencher = AsyncBencher::new();
386 let result = AssertUnwindSafe(async move {
387 bench_fn(&mut bencher, dependency_view).await;
388 (
389 bencher
390 .summary()
391 .expect("iter() was not called in bench function"),
392 bencher.bytes,
393 )
394 })
395 .catch_unwind()
396 .await;
397 let bytes = result.as_ref().map(|(_, bytes)| *bytes).unwrap_or_default();
398 TestResult::from_summary(
399 &test.props.should_panic,
400 start.elapsed(),
401 result.map(|(summary, _)| summary),
402 bytes,
403 )
404 }
405 }
406 }
407}
408
409struct Worker {
410 _listener: Listener,
411 _process: Child,
412 _out_handle: JoinHandle<()>,
413 _err_handle: JoinHandle<()>,
414 out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
415 err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
416 capture_enabled: Arc<Mutex<bool>>,
417 connection: Stream,
418}
419
420impl Worker {
421 pub async fn run_test(&mut self, nocapture: bool, test: &RegisteredTest) -> TestResult {
422 let mut capture_enabled = self.capture_enabled.lock().await;
423 *capture_enabled = test.props.capture_control.requires_capturing(!nocapture);
424 drop(capture_enabled);
425
426 let cmd = IpcCommand::RunTest {
428 name: test.name.clone(),
429 crate_name: test.crate_name.clone(),
430 module_path: test.module_path.clone(),
431 };
432
433 let dump_on_ipc_failure = self.dump_on_failure();
434
435 let msg =
436 encode_to_vec(&cmd, bincode::config::standard()).expect("Failed to encode IPC command");
437 let message_size = (msg.len() as u16).to_le_bytes();
438 dump_on_ipc_failure
439 .run(self.connection.write_all(&message_size).await)
440 .await;
441 dump_on_ipc_failure
442 .run(self.connection.write_all(&msg).await)
443 .await;
444
445 let mut response_size: [u8; 2] = [0, 0];
446 dump_on_ipc_failure
447 .run(self.connection.read_exact(&mut response_size).await)
448 .await;
449 let mut response = vec![0; u16::from_le_bytes(response_size) as usize];
450 dump_on_ipc_failure
451 .run(self.connection.read_exact(&mut response).await)
452 .await;
453 let (response, _): (IpcResponse, usize) = dump_on_ipc_failure
454 .run(decode_from_slice(&response, bincode::config::standard()))
455 .await;
456
457 let IpcResponse::TestFinished {
458 result,
459 finish_marker,
460 } = response;
461
462 if test.props.capture_control.requires_capturing(!nocapture) {
463 let out_lines: Vec<_> =
464 Self::drain_until(self.out_lines.clone(), finish_marker.clone()).await;
465 let err_lines: Vec<_> =
466 Self::drain_until(self.err_lines.clone(), finish_marker.clone()).await;
467 result.into_test_result(out_lines, err_lines)
468 } else {
469 result.into_test_result(Vec::new(), Vec::new())
470 }
471 }
472
473 fn dump_on_failure(&self) -> DumpOnFailure {
474 DumpOnFailure {
475 out_lines: self.out_lines.clone(),
476 err_lines: self.err_lines.clone(),
477 }
478 }
479
480 async fn drain_until(
481 source: Arc<Mutex<VecDeque<CapturedOutput>>>,
482 finish_marker: String,
483 ) -> Vec<CapturedOutput> {
484 let mut result = Vec::new();
485 loop {
486 let mut source = source.lock().await;
487 while let Some(line) = source.pop_front() {
488 if line.line() == finish_marker {
489 return result;
490 } else {
491 result.push(line.clone());
492 }
493 }
494 drop(source);
495
496 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
497 }
498 }
499}
500
501struct DumpOnFailure {
502 out_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
503 err_lines: Arc<Mutex<VecDeque<CapturedOutput>>>,
504}
505
506impl DumpOnFailure {
507 pub async fn run<T, E>(&self, result: Result<T, E>) -> T {
508 match result {
509 Ok(value) => value,
510 Err(_error) => {
511 let out_lines: Vec<_> = self.out_lines.lock().await.drain(..).collect();
512 let err_lines: Vec<_> = self.err_lines.lock().await.drain(..).collect();
513 let mut all_lines = [out_lines, err_lines].concat();
514 all_lines.sort();
515
516 for line in all_lines {
517 eprintln!("{}", line.line());
518 }
519
520 std::process::exit(1);
521 }
522 }
523 }
524}
525
526async fn spawn_worker_if_needed(args: &Arguments) -> Option<Worker> {
527 if args.spawn_workers {
528 let id = Uuid::new_v4();
529 let name_str = format!("{id}.sock");
530 let name = name_str
531 .clone()
532 .to_ns_name::<GenericNamespaced>()
533 .expect("Invalid local socket name");
534 let opts = ListenerOptions::new().name(name.clone());
535 let listener = opts
536 .create_tokio()
537 .expect("Failed to create local socket listener");
538
539 let exe = std::env::current_exe().expect("Failed to get current executable path");
540
541 let mut args = args.clone();
542 args.ipc = Some(name_str);
543 args.spawn_workers = false;
544 args.logfile = None;
545 let args = args.to_args();
546
547 let mut process = Command::new(exe)
548 .args(args)
549 .stdin(Stdio::inherit())
550 .stderr(Stdio::piped())
551 .stdout(Stdio::piped())
552 .spawn()
553 .expect("Failed to spawn worker process");
554
555 let stdout = process.stdout.take().unwrap();
556 let stderr = process.stderr.take().unwrap();
557
558 let out_lines = Arc::new(Mutex::new(VecDeque::new()));
559 let err_lines = Arc::new(Mutex::new(VecDeque::new()));
560 let capture_enabled = Arc::new(Mutex::new(true));
561
562 let out_lines_clone = out_lines.clone();
563 let capture_enabled_clone = capture_enabled.clone();
564 let out_handle = spawn(async move {
565 let reader = BufReader::new(stdout);
566 let mut lines = reader.lines();
567 while let Some(line) = lines
568 .next_line()
569 .await
570 .expect("Failed to read from worker stdout")
571 {
572 if *capture_enabled_clone.lock().await {
573 out_lines_clone
574 .lock()
575 .await
576 .push_back(CapturedOutput::stdout(line));
577 } else {
578 println!("{line}");
579 }
580 }
581 });
582
583 let err_lines_clone = err_lines.clone();
584 let capture_enabled_clone = capture_enabled.clone();
585 let err_handle = spawn(async move {
586 let reader = BufReader::new(stderr);
587 let mut lines = reader.lines();
588 while let Some(line) = lines
589 .next_line()
590 .await
591 .expect("Failed to read from worker stderr")
592 {
593 if *capture_enabled_clone.lock().await {
594 err_lines_clone
595 .lock()
596 .await
597 .push_back(CapturedOutput::stderr(line));
598 } else {
599 eprintln!("{line}");
600 }
601 }
602 });
603
604 let connection = listener
605 .accept()
606 .await
607 .expect("Failed to accept connection");
608
609 Some(Worker {
610 _listener: listener,
611 _process: process,
612 _out_handle: out_handle,
613 _err_handle: err_handle,
614 out_lines,
615 err_lines,
616 connection,
617 capture_enabled,
618 })
619 } else {
620 None
621 }
622}