tanu_core/
runner.rs

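//! tanu's test runner: executes registered test cases for each configured project and broadcasts their progress to subscribers.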
2use backon::Retryable;
3use eyre::WrapErr;
4use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
5use once_cell::sync::Lazy;
6use std::{
7    collections::HashMap,
8    ops::Deref,
9    pin::Pin,
10    sync::{Arc, Mutex},
11};
12use tokio::sync::broadcast;
13use tracing::*;
14
15use crate::{
16    config::{self, get_config, get_tanu_config, ProjectConfig},
17    http,
18    reporter::Reporter,
19    Config, ModuleName, ProjectName, TestName,
20};
21
22pub static CHANNEL: Lazy<Mutex<Option<broadcast::Sender<Message>>>> =
23    Lazy::new(|| Mutex::new(Some(broadcast::channel(1000).0)));
24
25pub fn publish(msg: Message) -> eyre::Result<()> {
26    let Ok(guard) = CHANNEL.lock() else {
27        eyre::bail!("failed to acquire runner channel lock");
28    };
29    let Some(tx) = guard.deref() else {
30        eyre::bail!("runner channel has been already closed");
31    };
32
33    tx.send(msg)
34        .wrap_err("failed to publish message to the runner channel")?;
35
36    Ok(())
37}
38
39/// Subscribe to the channel to see the real-time test execution events.
40pub fn subscribe() -> eyre::Result<broadcast::Receiver<Message>> {
41    let Ok(guard) = CHANNEL.lock() else {
42        eyre::bail!("failed to acquire runner channel lock");
43    };
44    let Some(tx) = guard.deref() else {
45        eyre::bail!("runner channel has been already closed");
46    };
47
48    Ok(tx.subscribe())
49}
50
51#[derive(Debug, Clone, thiserror::Error)]
52pub enum Error {
53    #[error("panic: {0}")]
54    Panicked(String),
55    #[error("error: {0}")]
56    ErrorReturned(String),
57}
58
59#[derive(Debug, Clone)]
60pub enum Message {
61    Start(ProjectName, ModuleName, TestName),
62    HttpLog(ProjectName, ModuleName, TestName, Box<http::Log>),
63    End(ProjectName, ModuleName, TestName, Test),
64}
65
66#[derive(Debug, Clone)]
67pub struct Test {
68    pub info: TestInfo,
69    pub result: Result<(), Error>,
70}
71
/// Identity of a registered test case.
#[derive(Debug, Clone)]
pub struct TestInfo {
    /// Module the test function lives in.
    pub module: String,
    /// Test function name.
    pub name: String,
}

impl TestInfo {
    /// Test name qualified by its module: `module::name`.
    pub fn full_name(&self) -> String {
        [self.module.as_str(), self.name.as_str()].join("::")
    }

    /// Test name qualified by project and module: `project::module::name`.
    pub fn unique_name(&self, project: &str) -> String {
        [project, self.module.as_str(), self.name.as_str()].join("::")
    }
}
89
90type TestCaseFactory = Arc<
91    dyn Fn() -> Pin<Box<dyn futures::Future<Output = eyre::Result<()>> + Send + 'static>>
92        + Sync
93        + Send
94        + 'static,
95>;
96
/// Behaviour toggles for [`Runner`]; all default to `false`.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// Debug flag (not read by `Runner::run` in this module).
    pub debug: bool,
    /// Requests HTTP capture; set by `Runner::capture_http`.
    pub capture_http: bool,
    /// Installs a `tracing_subscriber` fmt subscriber when `Runner::run` starts.
    pub capture_rust: bool,
    /// Closes the global runner channel once every test has finished.
    pub terminate_channel: bool,
}
104
105/// Test case filter trait.
106pub trait Filter {
107    fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool;
108}
109
110/// Filter test cases by project name.
111pub struct ProjectFilter<'a> {
112    project_names: &'a [String],
113}
114
115impl Filter for ProjectFilter<'_> {
116    fn filter(&self, project: &ProjectConfig, _info: &TestInfo) -> bool {
117        if self.project_names.is_empty() {
118            return true;
119        }
120
121        self.project_names
122            .iter()
123            .any(|project_name| &project.name == project_name)
124    }
125}
126
127/// Filter test cases by module name.
128pub struct ModuleFilter<'a> {
129    module_names: &'a [String],
130}
131
132impl Filter for ModuleFilter<'_> {
133    fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
134        if self.module_names.is_empty() {
135            return true;
136        }
137
138        self.module_names
139            .iter()
140            .any(|module_name| &info.module == module_name)
141    }
142}
143
144/// Filter test cases by test name.
145pub struct TestNameFilter<'a> {
146    test_names: &'a [String],
147}
148
149impl Filter for TestNameFilter<'_> {
150    fn filter(&self, _project: &ProjectConfig, info: &TestInfo) -> bool {
151        if self.test_names.is_empty() {
152            return true;
153        }
154
155        self.test_names
156            .iter()
157            .any(|test_name| &info.full_name() == test_name)
158    }
159}
160
161/// Filter test cases by test ignore config.
162pub struct TestIgnoreFilter {
163    test_ignores: HashMap<String, Vec<String>>,
164}
165
166impl Default for TestIgnoreFilter {
167    fn default() -> TestIgnoreFilter {
168        TestIgnoreFilter {
169            test_ignores: get_tanu_config()
170                .projects
171                .iter()
172                .map(|proj| (proj.name.clone(), proj.test_ignore.clone()))
173                .collect(),
174        }
175    }
176}
177
178impl Filter for TestIgnoreFilter {
179    fn filter(&self, project: &ProjectConfig, info: &TestInfo) -> bool {
180        let Some(test_ignore) = self.test_ignores.get(&project.name) else {
181            return true;
182        };
183
184        test_ignore
185            .iter()
186            .all(|test_name| &info.full_name() != test_name)
187    }
188}
189
190#[derive(Default)]
191pub struct Runner {
192    cfg: Config,
193    options: Options,
194    test_cases: Vec<(TestInfo, TestCaseFactory)>,
195    reporters: Vec<Box<dyn Reporter + Send>>,
196}
197
198impl Runner {
199    pub fn new() -> Runner {
200        Runner::with_config(get_tanu_config().clone())
201    }
202
203    pub fn with_config(cfg: Config) -> Runner {
204        Runner {
205            cfg,
206            options: Options::default(),
207            test_cases: Vec::new(),
208            reporters: Vec::new(),
209        }
210    }
211
212    pub fn capture_http(&mut self) {
213        self.options.capture_http = true;
214    }
215
216    pub fn capture_rust(&mut self) {
217        self.options.capture_rust = true;
218    }
219
220    pub fn terminate_channel(&mut self) {
221        self.options.terminate_channel = true;
222    }
223
224    pub fn add_reporter(&mut self, reporter: impl Reporter + 'static + Send) {
225        self.reporters.push(Box::new(reporter));
226    }
227
228    /// Add a test case to the runner.
229    pub fn add_test(&mut self, name: &str, module: &str, factory: TestCaseFactory) {
230        self.test_cases.push((
231            TestInfo {
232                name: name.into(),
233                module: module.into(),
234            },
235            factory,
236        ));
237    }
238
239    /// Run tanu runner.
240    #[allow(clippy::too_many_lines)]
241    pub async fn run(
242        &mut self,
243        project_names: &[String],
244        module_names: &[String],
245        test_names: &[String],
246    ) -> eyre::Result<()> {
247        if self.options.capture_rust {
248            tracing_subscriber::fmt::init();
249        }
250
251        let mut reporters = std::mem::take(&mut self.reporters);
252
253        let project_filter = ProjectFilter { project_names };
254        let module_filter = ModuleFilter { module_names };
255        let test_name_filter = TestNameFilter { test_names };
256        let test_ignore_filter = TestIgnoreFilter::default();
257
258        let handles: FuturesUnordered<_> = self
259                .test_cases
260                .iter()
261                .flat_map(|(info, factory)| {
262                    let projects = self.cfg.projects.clone();
263                    let projects = if projects.is_empty() {
264                        vec![ProjectConfig {
265                            name: "default".into(),
266                            ..Default::default()
267                        }]
268                    } else {
269                        projects
270                    };
271                    projects
272                        .into_iter()
273                        .map(move |project| {
274                            (project.clone(), info.clone(), factory.clone())
275                        })
276                })
277                .filter(move |(project, info, _)| {
278                    test_name_filter.filter(project, info)
279                })
280                .filter(move |(project, info, _)| {
281                    module_filter.filter(project, info)
282                })
283                .filter(move |(project, info, _)| {
284                    project_filter.filter(project, info)
285                })
286                .filter(move |(project, info, _)| {
287                    test_ignore_filter.filter(project, info)
288                })
289                .map(|(project, info, factory)| {
290                    tokio::spawn(async move {
291                        config::PROJECT
292                            .scope(project.clone(), async {
293                                http::CHANNEL
294                                    .scope(
295                                        Arc::new(Mutex::new(Some(broadcast::channel(1000).0))),
296                                        async {
297                                            let test_name = &info.name;
298                                            let mut http_rx = http::subscribe()?;
299
300                                            let f= || async {factory().await};
301                                            let fut = f.retry(project.retry.backoff());
302                                            let fut =
303                                                std::panic::AssertUnwindSafe(fut).catch_unwind();
304                                            let res = fut.await;
305
306                                            publish(Message::Start(project.name.clone(), info.module.clone(), test_name.to_string())).wrap_err("failed to send Message::Start to the channel")?;
307
308                                            let result = match res {
309                                                Ok(Ok(_)) => {
310                                                    debug!("{test_name} ok");
311                                                    Ok(())
312                                                }
313                                                Ok(Err(e)) => {
314                                                    debug!("{test_name} failed: {e:#}");
315                                                    Err(Error::ErrorReturned(format!("{e:?}")))
316                                                }
317                                                Err(e) => {
318                                                    let panic_message =
319                                                        if let Some(panic_message) = e.downcast_ref::<&str>() {
320                                                            format!(
321                                                            "{test_name} failed with message: {panic_message}"
322                                                        )
323                                                        } else if let Some(panic_message) =
324                                                            e.downcast_ref::<String>()
325                                                        {
326                                                            format!(
327                                                            "{test_name} failed with message: {panic_message}"
328                                                        )
329                                                        } else {
330                                                            format!("{test_name} failed with unknown message")
331                                                        };
332                                                    let e = eyre::eyre!(panic_message);
333                                                    Err(Error::Panicked(format!("{e:?}")))
334                                                }
335                                            };
336
337                                            while let Ok(log) = http_rx.try_recv() {
338                                                publish(Message::HttpLog(project.name.clone(), info.module.clone(), test_name.clone(), Box::new(log))).wrap_err("failed to send Message::HttpLog to the channel")?;
339                                            }
340
341                                            let project = get_config();
342                                            let is_err = result.is_err();
343                                            publish(Message::End(project.name, info.module.clone(), test_name.clone(), Test { info, result })).wrap_err("failed to send Message::End to the channel")?;
344
345                                            eyre::ensure!(!is_err);
346                                            eyre::Ok(())
347                                        },
348                                    )
349                                    .await
350                            })
351                            .await
352                    })
353                })
354                .collect();
355
356        let reporters =
357            futures::future::join_all(reporters.iter_mut().map(|reporter| reporter.run().boxed()));
358
359        let mut has_any_error = false;
360        let options = self.options.clone();
361        let runner = async move {
362            let results = handles.collect::<Vec<_>>().await;
363            for result in results {
364                match result {
365                    Ok(res) => {
366                        if res.is_err() {
367                            has_any_error = true;
368                        }
369                    }
370                    Err(e) => {
371                        if e.is_panic() {
372                            // Resume the panic on the main task
373                            error!("{e}");
374                            has_any_error = true;
375                            println!("e={e:?}");
376                        }
377                    }
378                }
379            }
380            debug!("all test finished. sending stop signal to the background tasks.");
381
382            if options.terminate_channel {
383                let Ok(mut guard) = CHANNEL.lock() else {
384                    eyre::bail!("failed to acquire runner channel lock");
385                };
386                guard.take(); // closing the runner channel.
387            }
388
389            if has_any_error {
390                eyre::bail!("one or more tests failed");
391            }
392
393            eyre::Ok(())
394        };
395
396        let (handles, _reporters) = tokio::join!(runner, reporters);
397
398        debug!("runner stopped");
399
400        handles
401    }
402
403    pub fn list(&self) -> Vec<&TestInfo> {
404        self.test_cases
405            .iter()
406            .map(|(meta, _test)| meta)
407            .collect::<Vec<_>>()
408    }
409}
410
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::RetryConfig;

    /// Config with a single "default" project and default retry settings.
    fn create_config() -> Config {
        let project = ProjectConfig {
            name: "default".into(),
            ..Default::default()
        };
        Config {
            projects: vec![project],
            ..Default::default()
        }
    }

    /// Config with a single "default" project that is allowed one retry.
    fn create_config_with_retry() -> Config {
        let retry = RetryConfig {
            count: Some(1),
            ..Default::default()
        };
        let project = ProjectConfig {
            name: "default".into(),
            retry,
            ..Default::default()
        };
        Config {
            projects: vec![project],
            ..Default::default()
        }
    }

    /// Test case that GETs `url` and fails unless the response status is successful.
    fn status_check_factory(url: String) -> TestCaseFactory {
        Arc::new(move || {
            let url = url.clone();
            Box::pin(async move {
                let res = reqwest::get(url).await?;
                if res.status().is_success() {
                    Ok(())
                } else {
                    eyre::bail!("request failed")
                }
            })
        })
    }

    #[tokio::test]
    async fn runner_fail_because_no_retry_configured() -> eyre::Result<()> {
        let mut server = mockito::Server::new_async().await;
        // The first request fails; without retries the success mock must never be hit.
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(0)
            .create_async()
            .await;

        let _runner_rx = subscribe()?;
        let mut runner = Runner::with_config(create_config());
        runner.add_test("retry_test", "module", status_check_factory(server.url()));

        let outcome = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(outcome.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn runner_retry_successful_after_failure() -> eyre::Result<()> {
        let mut server = mockito::Server::new_async().await;
        // The first request fails and the single configured retry succeeds.
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(1)
            .create_async()
            .await;

        let _runner_rx = subscribe()?;
        let mut runner = Runner::with_config(create_config_with_retry());
        runner.add_test("retry_test", "module", status_check_factory(server.url()));

        let outcome = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(outcome.is_ok());
        Ok(())
    }
}