// tanu_core/runner.rs

//! tanu's test runner.
2use backon::Retryable;
3use eyre::WrapErr;
4use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
5use once_cell::sync::Lazy;
6use std::{
7    collections::HashMap,
8    ops::Deref,
9    pin::Pin,
10    sync::{
11        atomic::{AtomicUsize, Ordering},
12        Arc, Mutex,
13    },
14    time::Duration,
15};
16use tokio::sync::{broadcast, Semaphore};
17use tracing::*;
18
19use crate::{
20    config::{self, get_tanu_config, ProjectConfig},
21    http,
22    reporter::Reporter,
23    Config, ModuleName, ProjectName,
24};
25
// Metadata of the test executing in the current task. `Runner::run` sets it with
// `TEST_INFO.scope(..)` around each test body; read it via `get_test_info`.
// Task-locals are not inherited by tasks spawned from inside a test.
tokio::task_local! {
    pub(crate) static TEST_INFO: TestInfo;
}
29
30pub(crate) fn get_test_info() -> TestInfo {
31    TEST_INFO.with(|info| info.clone())
32}
33
// NOTE: Keep the runner receiver alive here so that sender never fails to send.
// (`broadcast::Sender::send` errors when no receiver is alive. This file never reads
// the stored receiver, so it only lags behind; lagging does not block senders.)
/// Global broadcast channel (capacity 1000) carrying runner [`Event`]s.
///
/// The `Option` becomes `None` once a runner with `terminate_channel` enabled has
/// finished and closed it; [`publish`] and [`subscribe`] then return errors.
#[allow(clippy::type_complexity)]
pub(crate) static CHANNEL: Lazy<
    Mutex<Option<(broadcast::Sender<Event>, broadcast::Receiver<Event>)>>,
> = Lazy::new(|| Mutex::new(Some(broadcast::channel(1000))));
39
40pub fn publish(e: impl Into<Event>) -> eyre::Result<()> {
41    let Ok(guard) = CHANNEL.lock() else {
42        eyre::bail!("failed to acquire runner channel lock");
43    };
44    let Some((tx, _)) = guard.deref() else {
45        eyre::bail!("runner channel has been already closed");
46    };
47
48    tx.send(e.into())
49        .wrap_err("failed to publish message to the runner channel")?;
50
51    Ok(())
52}
53
54/// Subscribe to the channel to see the real-time test execution events.
55pub fn subscribe() -> eyre::Result<broadcast::Receiver<Event>> {
56    let Ok(guard) = CHANNEL.lock() else {
57        eyre::bail!("failed to acquire runner channel lock");
58    };
59    let Some((tx, _)) = guard.deref() else {
60        eyre::bail!("runner channel has been already closed");
61    };
62
63    Ok(tx.subscribe())
64}
65
/// Why a test case failed.
#[derive(Debug, Clone, thiserror::Error)]
pub enum Error {
    /// The test panicked; holds the `{:?}`-formatted report built from the panic message.
    #[error("panic: {0}")]
    Panicked(String),
    /// The test returned an `Err`; holds its `{:?}`-formatted report.
    #[error("error: {0}")]
    ErrorReturned(String),
}
73
/// Outcome of a single check made while a test runs.
#[derive(Debug, Clone)]
pub struct Check {
    /// `true` when the check passed.
    pub result: bool,
    /// Text describing what was checked (presumably the checked expression).
    pub expr: String,
}

impl Check {
    /// Builds a passing check for `expr`.
    pub fn success(expr: impl Into<String>) -> Check {
        Self::with_result(true, expr)
    }

    /// Builds a failing check for `expr`.
    pub fn error(expr: impl Into<String>) -> Check {
        Self::with_result(false, expr)
    }

    /// Shared constructor for [`Check::success`] and [`Check::error`].
    fn with_result(result: bool, expr: impl Into<String>) -> Check {
        let expr = expr.into();
        Check { expr, result }
    }
}
95
/// Runner event represents a test event that is published to the channel.
///
/// Usually built from an [`EventBody`] via `From`, which fills in the current
/// project and test from the task-local context.
#[derive(Debug, Clone)]
pub struct Event {
    /// Name of the project the test runs in.
    pub project: ProjectName,
    /// Module containing the test.
    pub module: ModuleName,
    /// Name of the test.
    // NOTE(review): typed as `ModuleName` although it holds the test name; presumably
    // both aliases resolve to `String` — confirm and consider a dedicated alias.
    pub test: ModuleName,
    /// What happened.
    pub body: EventBody,
}
104
/// Payload of a runner [`Event`].
#[derive(Debug, Clone)]
pub enum EventBody {
    /// The test started (published once, before its first attempt).
    Start,
    /// A check was evaluated.
    Check(Box<Check>),
    /// An HTTP log record (`http::Log`).
    Http(Box<http::Log>),
    /// An attempt failed while the project's configured retry count was not yet used up.
    Retry,
    /// The test finished; carries its final result.
    End(Test),
}
113
114impl From<EventBody> for Event {
115    fn from(body: EventBody) -> Self {
116        let project = crate::config::get_config();
117        let test_info = crate::runner::get_test_info();
118        Event {
119            project: project.name,
120            module: test_info.module,
121            test: test_info.name,
122            body,
123        }
124    }
125}
126
/// Final outcome of one test case run, carried by [`EventBody::End`].
#[derive(Debug, Clone)]
pub struct Test {
    /// Which test this outcome belongs to.
    pub info: TestInfo,
    /// Wall-clock time of the whole run, including every retry attempt and the
    /// delays between them.
    pub request_time: Duration,
    /// `Ok(())` on success, otherwise why the test failed.
    pub result: Result<(), Error>,
}
133
/// Identifies a registered test case by its module and test name.
#[derive(Debug, Clone, Default)]
pub struct TestInfo {
    /// Name of the module containing the test.
    pub module: String,
    /// Name of the test itself.
    pub name: String,
}

impl TestInfo {
    /// Test name qualified with its module: `module::name`.
    pub fn full_name(&self) -> String {
        [self.module.as_str(), self.name.as_str()].join("::")
    }

    /// Test name qualified with project and module: `project::module::name`.
    pub fn unique_name(&self, project: &str) -> String {
        [project, self.module.as_str(), self.name.as_str()].join("::")
    }
}
151
/// Builds a fresh future for one attempt of a test case.
///
/// The runner calls it once per attempt, so a retry re-runs the test from scratch.
type TestCaseFactory = Arc<
    dyn Fn() -> Pin<Box<dyn futures::Future<Output = eyre::Result<()>> + Send + 'static>>
        + Sync
        + Send
        + 'static,
>;
158
/// Runner options, toggled through the `Runner` setter methods.
#[derive(Debug, Clone, Default)]
pub struct Options {
    // NOTE(review): not read anywhere in this file — confirm it is used elsewhere.
    pub debug: bool,
    /// Set by `Runner::capture_http`; not read in this file (presumably consumed
    /// elsewhere — verify).
    pub capture_http: bool,
    /// Install a `tracing` fmt subscriber at the start of `Runner::run`.
    pub capture_rust: bool,
    /// Close the global runner channel once all tests have finished.
    pub terminate_channel: bool,
    /// Maximum number of tests running at once; `None` means no limit.
    pub concurrency: Option<usize>,
}
167
/// Test case filter trait.
///
/// Each test case is checked once per project it could run in.
pub trait Filter {
    /// Returns `true` to keep the test case `info` for `project`.
    fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool;
}
172
173/// Filter test cases by project name.
174pub struct ProjectFilter<'a> {
175    project_names: &'a [String],
176}
177
178impl Filter for ProjectFilter<'_> {
179    fn filter(&self, project: &ProjectConfig, _info: &TestInfo) -> bool {
180        if self.project_names.is_empty() {
181            return true;
182        }
183
184        self.project_names
185            .iter()
186            .any(|project_name| &project.name == project_name)
187    }
188}
189
190/// Filter test cases by module name.
191pub struct ModuleFilter<'a> {
192    module_names: &'a [String],
193}
194
195impl Filter for ModuleFilter<'_> {
196    fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
197        if self.module_names.is_empty() {
198            return true;
199        }
200
201        self.module_names
202            .iter()
203            .any(|module_name| &info.module == module_name)
204    }
205}
206
207/// Filter test cases by test name.
208pub struct TestNameFilter<'a> {
209    test_names: &'a [String],
210}
211
212impl Filter for TestNameFilter<'_> {
213    fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
214        if self.test_names.is_empty() {
215            return true;
216        }
217
218        self.test_names
219            .iter()
220            .any(|test_name| &info.full_name() == test_name)
221    }
222}
223
224/// Filter test cases by test ignore config.
225pub struct TestIgnoreFilter {
226    test_ignores: HashMap<String, Vec<String>>,
227}
228
229impl Default for TestIgnoreFilter {
230    fn default() -> TestIgnoreFilter {
231        TestIgnoreFilter {
232            test_ignores: get_tanu_config()
233                .projects
234                .iter()
235                .map(|proj| (proj.name.clone(), proj.test_ignore.clone()))
236                .collect(),
237        }
238    }
239}
240
241impl Filter for TestIgnoreFilter {
242    fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool {
243        let Some(test_ignore) = self.test_ignores.get(&project.name) else {
244            return true;
245        };
246
247        test_ignore
248            .iter()
249            .all(|test_name| &info.full_name() != test_name)
250    }
251}
252
/// Collects test cases and reporters and runs them.
///
/// Note: `Runner::default()` starts from `Config::default()`, whereas
/// `Runner::new()` clones the global tanu configuration.
#[derive(Default)]
pub struct Runner {
    cfg: Config,
    options: Options,
    /// Registered tests, each with the factory that builds one attempt.
    test_cases: Vec<(TestInfo, TestCaseFactory)>,
    /// Reporters driven alongside the tests by `run` (which moves them out).
    reporters: Vec<Box<dyn Reporter + Send>>,
}
260
261impl Runner {
262    pub fn new() -> Runner {
263        Runner::with_config(get_tanu_config().clone())
264    }
265
266    pub fn with_config(cfg: Config) -> Runner {
267        Runner {
268            cfg,
269            options: Options::default(),
270            test_cases: Vec::new(),
271            reporters: Vec::new(),
272        }
273    }
274
275    pub fn capture_http(&mut self) {
276        self.options.capture_http = true;
277    }
278
279    pub fn capture_rust(&mut self) {
280        self.options.capture_rust = true;
281    }
282
283    pub fn terminate_channel(&mut self) {
284        self.options.terminate_channel = true;
285    }
286
287    pub fn add_reporter(&mut self, reporter: impl Reporter + 'static + Send) {
288        self.reporters.push(Box::new(reporter));
289    }
290
291    pub fn add_boxed_reporter(&mut self, reporter: Box<dyn Reporter + 'static + Send>) {
292        self.reporters.push(reporter);
293    }
294
295    /// Add a test case to the runner.
296    pub fn add_test(&mut self, name: &str, module: &str, factory: TestCaseFactory) {
297        self.test_cases.push((
298            TestInfo {
299                name: name.into(),
300                module: module.into(),
301            },
302            factory,
303        ));
304    }
305
306    pub fn set_concurrency(&mut self, concurrency: usize) {
307        self.options.concurrency = Some(concurrency);
308    }
309
310    /// Run tanu runner.
311    #[allow(clippy::too_many_lines)]
312    pub async fn run(
313        &mut self,
314        project_names: &[String],
315        module_names: &[String],
316        test_names: &[String],
317    ) -> eyre::Result<()> {
318        if self.options.capture_rust {
319            tracing_subscriber::fmt::init();
320        }
321
322        let mut reporters = std::mem::take(&mut self.reporters);
323
324        let project_filter = ProjectFilter { project_names };
325        let module_filter = ModuleFilter { module_names };
326        let test_name_filter = TestNameFilter { test_names };
327        let test_ignore_filter = TestIgnoreFilter::default();
328
329        let start = std::time::Instant::now();
330        let handles: FuturesUnordered<_> = {
331            // Create a semaphore to limit concurrency if specified
332            let semaphore = Arc::new(tokio::sync::Semaphore::new(
333                self.options.concurrency.unwrap_or(Semaphore::MAX_PERMITS),
334            ));
335
336            self.test_cases
337                .iter()
338                .flat_map(|(info, factory)| {
339                    let projects = self.cfg.projects.clone();
340                    let projects = if projects.is_empty() {
341                        vec![ProjectConfig {
342                            name: "default".into(),
343                            ..Default::default()
344                        }]
345                    } else {
346                        projects
347                    };
348                    projects
349                        .into_iter()
350                        .map(move |project| (project.clone(), info.clone(), factory.clone()))
351                })
352                .filter(move |(project, info, _)| test_name_filter.filter(project, info))
353                .filter(move |(project, info, _)| module_filter.filter(project, info))
354                .filter(move |(project, info, _)| project_filter.filter(project, info))
355                .filter(move |(project, info, _)| test_ignore_filter.filter(project, info))
356                .map(|(project, info, factory)| {
357                    let semaphore = semaphore.clone();
358                    tokio::spawn(async move {
359                        let _permit = semaphore.acquire().await.unwrap();
360                        config::PROJECT
361                            .scope(project.clone(), async {
362                                TEST_INFO
363                                    .scope(info.clone(), async {
364                                        let test_name = info.name.clone();
365                                        publish(EventBody::Start)?;
366
367                                        let retry_count =
368                                            AtomicUsize::new(project.retry.count.unwrap_or(0));
369                                        let f = || async {
370                                            let res = factory().await;
371
372                                            if res.is_err()
373                                                && retry_count.load(Ordering::SeqCst) > 0
374                                            {
375                                                publish(EventBody::Retry)?;
376                                                retry_count.fetch_sub(1, Ordering::SeqCst);
377                                            };
378                                            res
379                                        };
380                                        let started = std::time::Instant::now();
381                                        let fut = f.retry(project.retry.backoff());
382                                        let fut = std::panic::AssertUnwindSafe(fut).catch_unwind();
383                                        let res = fut.await;
384                                        let request_time = started.elapsed();
385
386                                        let result = match res {
387                                            Ok(Ok(_)) => {
388                                                debug!("{test_name} ok");
389                                                Ok(())
390                                            }
391                                            Ok(Err(e)) => {
392                                                debug!("{test_name} failed: {e:#}");
393                                                Err(Error::ErrorReturned(format!("{e:?}")))
394                                            }
395                                            Err(e) => {
396                                                let panic_message = if let Some(panic_message) =
397                                                    e.downcast_ref::<&str>()
398                                                {
399                                                    format!(
400                                                "{test_name} failed with message: {panic_message}"
401                                            )
402                                                } else if let Some(panic_message) =
403                                                    e.downcast_ref::<String>()
404                                                {
405                                                    format!(
406                                                "{test_name} failed with message: {panic_message}"
407                                            )
408                                                } else {
409                                                    format!(
410                                                        "{test_name} failed with unknown message"
411                                                    )
412                                                };
413                                                let e = eyre::eyre!(panic_message);
414                                                Err(Error::Panicked(format!("{e:?}")))
415                                            }
416                                        };
417
418                                        let is_err = result.is_err();
419                                        publish(EventBody::End(Test {
420                                            info,
421                                            request_time,
422                                            result,
423                                        }))?;
424
425                                        eyre::ensure!(!is_err);
426                                        eyre::Ok(())
427                                    })
428                                    .await
429                            })
430                            .await
431                    })
432                })
433                .collect()
434        };
435        debug!(
436            "created handles for {} test cases; took {}s",
437            handles.len(),
438            start.elapsed().as_secs_f32()
439        );
440
441        let reporters =
442            futures::future::join_all(reporters.iter_mut().map(|reporter| reporter.run().boxed()));
443
444        let mut has_any_error = false;
445        let options = self.options.clone();
446        let runner = async move {
447            let results = handles.collect::<Vec<_>>().await;
448            if results.is_empty() {
449                console::Term::stdout().write_line("no test cases found")?;
450            }
451            for result in results {
452                match result {
453                    Ok(res) => {
454                        if let Err(e) = res {
455                            debug!("test case failed: {e:#}");
456                            has_any_error = true;
457                        }
458                    }
459                    Err(e) => {
460                        if e.is_panic() {
461                            // Resume the panic on the main task
462                            error!("{e}");
463                            has_any_error = true;
464                        }
465                    }
466                }
467            }
468            debug!("all test finished. sending stop signal to the background tasks.");
469
470            if options.terminate_channel {
471                let Ok(mut guard) = CHANNEL.lock() else {
472                    eyre::bail!("failed to acquire runner channel lock");
473                };
474                guard.take(); // closing the runner channel.
475            }
476
477            if has_any_error {
478                eyre::bail!("one or more tests failed");
479            }
480
481            eyre::Ok(())
482        };
483
484        let (handles, reporters) = tokio::join!(runner, reporters);
485        for reporter in reporters {
486            if let Err(e) = reporter {
487                error!("reporter failed: {e:#}");
488            }
489        }
490
491        debug!("runner stopped");
492
493        handles
494    }
495
496    pub fn list(&self) -> Vec<&TestInfo> {
497        self.test_cases
498            .iter()
499            .map(|(meta, _test)| meta)
500            .collect::<Vec<_>>()
501    }
502}
503
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::RetryConfig;

    /// Config with a single `default` project and no retry policy.
    fn create_config() -> Config {
        let project = ProjectConfig {
            name: "default".into(),
            ..Default::default()
        };
        Config {
            projects: vec![project],
            ..Default::default()
        }
    }

    /// Config with a single `default` project that may retry a failed test once.
    fn create_config_with_retry() -> Config {
        let retry = RetryConfig {
            count: Some(1),
            ..Default::default()
        };
        let project = ProjectConfig {
            name: "default".into(),
            retry,
            ..Default::default()
        };
        Config {
            projects: vec![project],
            ..Default::default()
        }
    }

    /// Test case that GETs `url` and fails unless the response status is 2xx.
    fn get_url_test(url: String) -> TestCaseFactory {
        Arc::new(move || {
            let url = url.clone();
            Box::pin(async move {
                let res = reqwest::get(url).await?;
                if res.status().is_success() {
                    Ok(())
                } else {
                    eyre::bail!("request failed")
                }
            })
        })
    }

    #[tokio::test]
    async fn runner_fail_because_no_retry_configured() -> eyre::Result<()> {
        let mut server = mockito::Server::new_async().await;
        // Without retries the failing endpoint is hit once and the succeeding one never.
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(0)
            .create_async()
            .await;

        let _runner_rx = subscribe()?;
        let mut runner = Runner::with_config(create_config());
        runner.add_test("retry_test", "module", get_url_test(server.url()));

        let result = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(result.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn runner_retry_successful_after_failure() -> eyre::Result<()> {
        let mut server = mockito::Server::new_async().await;
        // With one retry, the first attempt fails and the second succeeds.
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(1)
            .create_async()
            .await;

        let _runner_rx = subscribe()?;
        let mut runner = Runner::with_config(create_config_with_retry());
        runner.add_test("retry_test", "module", get_url_test(server.url()));

        let result = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(result.is_ok());
        Ok(())
    }
}