1use backon::Retryable;
3use eyre::WrapErr;
4use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
5use once_cell::sync::Lazy;
6use std::{
7 collections::HashMap,
8 ops::Deref,
9 pin::Pin,
10 sync::{
11 atomic::{AtomicUsize, Ordering},
12 Arc, Mutex,
13 },
14 time::Duration,
15};
16use tokio::sync::{broadcast, Semaphore};
17use tracing::*;
18
19use crate::{
20 config::{self, get_tanu_config, ProjectConfig},
21 http,
22 reporter::Reporter,
23 Config, ModuleName, ProjectName,
24};
25
// Task-local slot holding the `TestInfo` of the test case that drives the
// current tokio task. `Runner::run` enters a scope for it around each test
// case, and `get_test_info` reads it back.
tokio::task_local! {
    pub(crate) static TEST_INFO: TestInfo;
}
29
30pub(crate) fn get_test_info() -> TestInfo {
31 TEST_INFO.with(|info| info.clone())
32}
33
// Process-wide broadcast channel (capacity 1000) carrying runner `Event`s.
//
// The pair is wrapped in `Option` so that `Runner::run` can `take()` it to
// close the channel when `Options::terminate_channel` is set; `publish` and
// `subscribe` report an error once it is gone.
// NOTE(review): the receiver half is stored but never read in this file —
// presumably it is kept so the channel always has at least one receiver and
// `send` does not fail; confirm.
#[allow(clippy::type_complexity)]
pub(crate) static CHANNEL: Lazy<
    Mutex<Option<(broadcast::Sender<Event>, broadcast::Receiver<Event>)>>,
> = Lazy::new(|| Mutex::new(Some(broadcast::channel(1000))));
39
40pub fn publish(e: impl Into<Event>) -> eyre::Result<()> {
41 let Ok(guard) = CHANNEL.lock() else {
42 eyre::bail!("failed to acquire runner channel lock");
43 };
44 let Some((tx, _)) = guard.deref() else {
45 eyre::bail!("runner channel has been already closed");
46 };
47
48 tx.send(e.into())
49 .wrap_err("failed to publish message to the runner channel")?;
50
51 Ok(())
52}
53
54pub fn subscribe() -> eyre::Result<broadcast::Receiver<Event>> {
56 let Ok(guard) = CHANNEL.lock() else {
57 eyre::bail!("failed to acquire runner channel lock");
58 };
59 let Some((tx, _)) = guard.deref() else {
60 eyre::bail!("runner channel has been already closed");
61 };
62
63 Ok(tx.subscribe())
64}
65
/// Failure outcome of a single test case, as recorded in [`Test::result`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum Error {
    /// The test case panicked; the payload is the formatted panic report
    /// built by `Runner::run`.
    #[error("panic: {0}")]
    Panicked(String),
    /// The test case returned an `Err`; the payload is the debug-formatted
    /// error.
    #[error("error: {0}")]
    ErrorReturned(String),
}
73
/// Outcome of a single checked expression.
#[derive(Debug, Clone)]
pub struct Check {
    /// `true` when the check passed, `false` when it failed.
    pub result: bool,
    /// Textual form of the checked expression.
    pub expr: String,
}

impl Check {
    /// Builds a passing check for `expr`.
    pub fn success(expr: impl Into<String>) -> Check {
        Self::with_result(true, expr)
    }

    /// Builds a failing check for `expr`.
    pub fn error(expr: impl Into<String>) -> Check {
        Self::with_result(false, expr)
    }

    /// Shared constructor behind [`Check::success`] and [`Check::error`].
    fn with_result(result: bool, expr: impl Into<String>) -> Self {
        Self {
            result,
            expr: expr.into(),
        }
    }
}
95
/// Message broadcast on the runner channel, tagged with the project, module
/// and test it originates from.
#[derive(Debug, Clone)]
pub struct Event {
    /// Name of the project the test ran under.
    pub project: ProjectName,
    /// Module the test belongs to.
    pub module: ModuleName,
    /// Name of the test case (filled from `TestInfo::name`).
    // NOTE(review): typed as `ModuleName` although it carries a test name —
    // confirm whether a dedicated alias was intended.
    pub test: ModuleName,
    /// What happened.
    pub body: EventBody,
}
104
/// Payload of an [`Event`].
#[derive(Debug, Clone)]
pub enum EventBody {
    /// Published by `Runner::run` right before a test case is executed.
    Start,
    /// Result of a checked expression.
    // NOTE(review): not published in this file — presumably emitted by the
    // assertion helpers; confirm.
    Check(Box<Check>),
    /// Log of an HTTP exchange.
    // NOTE(review): not published in this file — presumably emitted by the
    // `http` module; confirm.
    Http(Box<http::Log>),
    /// Published by `Runner::run` when an attempt failed and the remaining
    /// retry budget is greater than zero.
    Retry,
    /// Published by `Runner::run` once a test case has finished, carrying its
    /// final outcome.
    End(Test),
}
113
114impl From<EventBody> for Event {
115 fn from(body: EventBody) -> Self {
116 let project = crate::config::get_config();
117 let test_info = crate::runner::get_test_info();
118 Event {
119 project: project.name,
120 module: test_info.module,
121 test: test_info.name,
122 body,
123 }
124 }
125}
126
/// Final report of one executed test case, carried by [`EventBody::End`].
#[derive(Debug, Clone)]
pub struct Test {
    /// Identity of the test case.
    pub info: TestInfo,
    /// Wall-clock time measured by `Runner::run` around the test future,
    /// including any retries.
    pub request_time: Duration,
    /// `Ok(())` on success, otherwise the reason for the failure.
    pub result: Result<(), Error>,
}
133
/// Identity of a registered test case.
#[derive(Debug, Clone, Default)]
pub struct TestInfo {
    /// Module the test case belongs to.
    pub module: String,
    /// Name of the test case within its module.
    pub name: String,
}

impl TestInfo {
    /// Returns the test name qualified by its module: `module::name`.
    pub fn full_name(&self) -> String {
        let Self { module, name } = self;
        format!("{module}::{name}")
    }

    /// Returns the test name qualified by project and module:
    /// `project::module::name`.
    pub fn unique_name(&self, project: &str) -> String {
        [project, self.module.as_str(), self.name.as_str()].join("::")
    }
}
151
// Shared, thread-safe factory that produces a fresh boxed future for one
// execution of a test case. `Runner::run` calls it once per attempt, so a
// retried test gets a new future each time.
type TestCaseFactory = Arc<
    dyn Fn() -> Pin<Box<dyn futures::Future<Output = eyre::Result<()>> + Send + 'static>>
        + Sync
        + Send
        + 'static,
>;
158
/// Behaviour switches of a [`Runner`]. All default to off / unset.
#[derive(Debug, Clone, Default)]
pub struct Options {
    // NOTE(review): not read in the visible part of this file — confirm its
    // consumer.
    pub debug: bool,
    /// Set by `Runner::capture_http`.
    // NOTE(review): not read in the visible part of this file — presumably
    // consulted by the HTTP layer or reporters; confirm.
    pub capture_http: bool,
    /// When set, `Runner::run` installs a `tracing_subscriber` formatter
    /// before executing tests.
    pub capture_rust: bool,
    /// When set, `Runner::run` closes the runner channel after all test
    /// cases have finished.
    pub terminate_channel: bool,
    /// Number of semaphore permits that gate test execution in `Runner::run`;
    /// `None` uses `Semaphore::MAX_PERMITS`.
    pub concurrency: Option<usize>,
}
167
/// Predicate deciding whether a test case should run for a given project.
pub trait Filter {
    /// Returns `true` to keep the `(project, info)` combination and `false`
    /// to drop it; `Runner::run` only spawns combinations that every filter
    /// keeps.
    fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool;
}
172
173pub struct ProjectFilter<'a> {
175 project_names: &'a [String],
176}
177
178impl Filter for ProjectFilter<'_> {
179 fn filter(&self, project: &ProjectConfig, _info: &TestInfo) -> bool {
180 if self.project_names.is_empty() {
181 return true;
182 }
183
184 self.project_names
185 .iter()
186 .any(|project_name| &project.name == project_name)
187 }
188}
189
190pub struct ModuleFilter<'a> {
192 module_names: &'a [String],
193}
194
195impl Filter for ModuleFilter<'_> {
196 fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
197 if self.module_names.is_empty() {
198 return true;
199 }
200
201 self.module_names
202 .iter()
203 .any(|module_name| &info.module == module_name)
204 }
205}
206
207pub struct TestNameFilter<'a> {
209 test_names: &'a [String],
210}
211
212impl Filter for TestNameFilter<'_> {
213 fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
214 if self.test_names.is_empty() {
215 return true;
216 }
217
218 self.test_names
219 .iter()
220 .any(|test_name| &info.full_name() == test_name)
221 }
222}
223
224pub struct TestIgnoreFilter {
226 test_ignores: HashMap<String, Vec<String>>,
227}
228
229impl Default for TestIgnoreFilter {
230 fn default() -> TestIgnoreFilter {
231 TestIgnoreFilter {
232 test_ignores: get_tanu_config()
233 .projects
234 .iter()
235 .map(|proj| (proj.name.clone(), proj.test_ignore.clone()))
236 .collect(),
237 }
238 }
239}
240
241impl Filter for TestIgnoreFilter {
242 fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool {
243 let Some(test_ignore) = self.test_ignores.get(&project.name) else {
244 return true;
245 };
246
247 test_ignore
248 .iter()
249 .all(|test_name| &info.full_name() != test_name)
250 }
251}
252
/// Collects test cases and reporters, then executes the tests and publishes
/// their events on the runner channel.
#[derive(Default)]
pub struct Runner {
    /// Configuration supplying the projects each test case is run under.
    cfg: Config,
    /// Behaviour switches applied by `run`.
    options: Options,
    /// Registered test cases together with the factory creating their future.
    test_cases: Vec<(TestInfo, TestCaseFactory)>,
    /// Reporters driven concurrently with the tests; `run` moves them out of
    /// this field.
    reporters: Vec<Box<dyn Reporter + Send>>,
}
260
261impl Runner {
262 pub fn new() -> Runner {
263 Runner::with_config(get_tanu_config().clone())
264 }
265
266 pub fn with_config(cfg: Config) -> Runner {
267 Runner {
268 cfg,
269 options: Options::default(),
270 test_cases: Vec::new(),
271 reporters: Vec::new(),
272 }
273 }
274
275 pub fn capture_http(&mut self) {
276 self.options.capture_http = true;
277 }
278
279 pub fn capture_rust(&mut self) {
280 self.options.capture_rust = true;
281 }
282
283 pub fn terminate_channel(&mut self) {
284 self.options.terminate_channel = true;
285 }
286
287 pub fn add_reporter(&mut self, reporter: impl Reporter + 'static + Send) {
288 self.reporters.push(Box::new(reporter));
289 }
290
291 pub fn add_boxed_reporter(&mut self, reporter: Box<dyn Reporter + 'static + Send>) {
292 self.reporters.push(reporter);
293 }
294
295 pub fn add_test(&mut self, name: &str, module: &str, factory: TestCaseFactory) {
297 self.test_cases.push((
298 TestInfo {
299 name: name.into(),
300 module: module.into(),
301 },
302 factory,
303 ));
304 }
305
306 pub fn set_concurrency(&mut self, concurrency: usize) {
307 self.options.concurrency = Some(concurrency);
308 }
309
310 #[allow(clippy::too_many_lines)]
312 pub async fn run(
313 &mut self,
314 project_names: &[String],
315 module_names: &[String],
316 test_names: &[String],
317 ) -> eyre::Result<()> {
318 if self.options.capture_rust {
319 tracing_subscriber::fmt::init();
320 }
321
322 let mut reporters = std::mem::take(&mut self.reporters);
323
324 let project_filter = ProjectFilter { project_names };
325 let module_filter = ModuleFilter { module_names };
326 let test_name_filter = TestNameFilter { test_names };
327 let test_ignore_filter = TestIgnoreFilter::default();
328
329 let start = std::time::Instant::now();
330 let handles: FuturesUnordered<_> = {
331 let semaphore = Arc::new(tokio::sync::Semaphore::new(
333 self.options.concurrency.unwrap_or(Semaphore::MAX_PERMITS),
334 ));
335
336 self.test_cases
337 .iter()
338 .flat_map(|(info, factory)| {
339 let projects = self.cfg.projects.clone();
340 let projects = if projects.is_empty() {
341 vec![ProjectConfig {
342 name: "default".into(),
343 ..Default::default()
344 }]
345 } else {
346 projects
347 };
348 projects
349 .into_iter()
350 .map(move |project| (project.clone(), info.clone(), factory.clone()))
351 })
352 .filter(move |(project, info, _)| test_name_filter.filter(project, info))
353 .filter(move |(project, info, _)| module_filter.filter(project, info))
354 .filter(move |(project, info, _)| project_filter.filter(project, info))
355 .filter(move |(project, info, _)| test_ignore_filter.filter(project, info))
356 .map(|(project, info, factory)| {
357 let semaphore = semaphore.clone();
358 tokio::spawn(async move {
359 let _permit = semaphore.acquire().await.unwrap();
360 config::PROJECT
361 .scope(project.clone(), async {
362 TEST_INFO
363 .scope(info.clone(), async {
364 let test_name = info.name.clone();
365 publish(EventBody::Start)?;
366
367 let retry_count =
368 AtomicUsize::new(project.retry.count.unwrap_or(0));
369 let f = || async {
370 let res = factory().await;
371
372 if res.is_err()
373 && retry_count.load(Ordering::SeqCst) > 0
374 {
375 publish(EventBody::Retry)?;
376 retry_count.fetch_sub(1, Ordering::SeqCst);
377 };
378 res
379 };
380 let started = std::time::Instant::now();
381 let fut = f.retry(project.retry.backoff());
382 let fut = std::panic::AssertUnwindSafe(fut).catch_unwind();
383 let res = fut.await;
384 let request_time = started.elapsed();
385
386 let result = match res {
387 Ok(Ok(_)) => {
388 debug!("{test_name} ok");
389 Ok(())
390 }
391 Ok(Err(e)) => {
392 debug!("{test_name} failed: {e:#}");
393 Err(Error::ErrorReturned(format!("{e:?}")))
394 }
395 Err(e) => {
396 let panic_message = if let Some(panic_message) =
397 e.downcast_ref::<&str>()
398 {
399 format!(
400 "{test_name} failed with message: {panic_message}"
401 )
402 } else if let Some(panic_message) =
403 e.downcast_ref::<String>()
404 {
405 format!(
406 "{test_name} failed with message: {panic_message}"
407 )
408 } else {
409 format!(
410 "{test_name} failed with unknown message"
411 )
412 };
413 let e = eyre::eyre!(panic_message);
414 Err(Error::Panicked(format!("{e:?}")))
415 }
416 };
417
418 let is_err = result.is_err();
419 publish(EventBody::End(Test {
420 info,
421 request_time,
422 result,
423 }))?;
424
425 eyre::ensure!(!is_err);
426 eyre::Ok(())
427 })
428 .await
429 })
430 .await
431 })
432 })
433 .collect()
434 };
435 debug!(
436 "created handles for {} test cases; took {}s",
437 handles.len(),
438 start.elapsed().as_secs_f32()
439 );
440
441 let reporters =
442 futures::future::join_all(reporters.iter_mut().map(|reporter| reporter.run().boxed()));
443
444 let mut has_any_error = false;
445 let options = self.options.clone();
446 let runner = async move {
447 let results = handles.collect::<Vec<_>>().await;
448 if results.is_empty() {
449 console::Term::stdout().write_line("no test cases found")?;
450 }
451 for result in results {
452 match result {
453 Ok(res) => {
454 if let Err(e) = res {
455 debug!("test case failed: {e:#}");
456 has_any_error = true;
457 }
458 }
459 Err(e) => {
460 if e.is_panic() {
461 error!("{e}");
463 has_any_error = true;
464 }
465 }
466 }
467 }
468 debug!("all test finished. sending stop signal to the background tasks.");
469
470 if options.terminate_channel {
471 let Ok(mut guard) = CHANNEL.lock() else {
472 eyre::bail!("failed to acquire runner channel lock");
473 };
474 guard.take(); }
476
477 if has_any_error {
478 eyre::bail!("one or more tests failed");
479 }
480
481 eyre::Ok(())
482 };
483
484 let (handles, reporters) = tokio::join!(runner, reporters);
485 for reporter in reporters {
486 if let Err(e) = reporter {
487 error!("reporter failed: {e:#}");
488 }
489 }
490
491 debug!("runner stopped");
492
493 handles
494 }
495
496 pub fn list(&self) -> Vec<&TestInfo> {
497 self.test_cases
498 .iter()
499 .map(|(meta, _test)| meta)
500 .collect::<Vec<_>>()
501 }
502}
503
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::RetryConfig;

    /// Configuration with a single `default` project and default retry
    /// settings.
    fn create_config() -> Config {
        let project = ProjectConfig {
            name: "default".into(),
            ..Default::default()
        };
        Config {
            projects: vec![project],
            ..Default::default()
        }
    }

    /// Configuration with a single `default` project whose retry count is 1.
    fn create_config_with_retry() -> Config {
        let project = ProjectConfig {
            name: "default".into(),
            retry: RetryConfig {
                count: Some(1),
                ..Default::default()
            },
            ..Default::default()
        };
        Config {
            projects: vec![project],
            ..Default::default()
        }
    }

    /// Sends a GET request to `url` and fails unless the response status is
    /// a success.
    async fn expect_success(url: String) -> eyre::Result<()> {
        let response = reqwest::get(url).await?;
        if !response.status().is_success() {
            eyre::bail!("request failed")
        }
        Ok(())
    }

    /// Without retries the failing response is hit exactly once, the
    /// succeeding one never, and the run reports an error.
    #[tokio::test]
    async fn runner_fail_because_no_retry_configured() -> eyre::Result<()> {
        let mut server = mockito::Server::new_async().await;
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(0)
            .create_async()
            .await;

        // The factory owns the server so it stays up for the whole run.
        let factory: TestCaseFactory =
            Arc::new(move || Box::pin(expect_success(server.url())));

        let _runner_rx = subscribe()?;
        let mut runner = Runner::with_config(create_config());
        runner.add_test("retry_test", "module", factory);

        let outcome = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(outcome.is_err());
        Ok(())
    }

    /// With one retry configured the first attempt fails, the second
    /// succeeds, and the run reports success.
    #[tokio::test]
    async fn runner_retry_successful_after_failure() -> eyre::Result<()> {
        let mut server = mockito::Server::new_async().await;
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(1)
            .create_async()
            .await;

        // The factory owns the server so it stays up for the whole run.
        let factory: TestCaseFactory =
            Arc::new(move || Box::pin(expect_success(server.url())));

        let _runner_rx = subscribe()?;
        let mut runner = Runner::with_config(create_config_with_retry());
        runner.add_test("retry_test", "module", factory);

        let outcome = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(outcome.is_ok());

        Ok(())
    }
}