tanu_core/
runner.rs

//! tanu's test runner.
2use backon::Retryable;
3use eyre::WrapErr;
4use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
5use once_cell::sync::Lazy;
6use std::{
7    collections::HashMap,
8    ops::Deref,
9    pin::Pin,
10    sync::{Arc, Mutex},
11};
12use tokio::sync::broadcast;
13use tracing::*;
14
15use crate::{
16    config::{self, get_config, get_tanu_config, ProjectConfig},
17    http,
18    reporter::Reporter,
19    Config, ModuleName, ProjectName, TestName,
20};
21
22pub static CHANNEL: Lazy<Mutex<Option<broadcast::Sender<Message>>>> =
23    Lazy::new(|| Mutex::new(Some(broadcast::channel(1000).0)));
24
25pub fn publish(msg: Message) -> eyre::Result<()> {
26    let Ok(guard) = CHANNEL.lock() else {
27        eyre::bail!("failed to acquire runner channel lock");
28    };
29    let Some(tx) = guard.deref() else {
30        eyre::bail!("runner channel has been already closed");
31    };
32
33    tx.send(msg)
34        .wrap_err("failed to publish message to the runner channel")?;
35
36    Ok(())
37}
38
39/// Subscribe to the channel to see the real-time test execution events.
40pub fn subscribe() -> eyre::Result<broadcast::Receiver<Message>> {
41    let Ok(guard) = CHANNEL.lock() else {
42        eyre::bail!("failed to acquire runner channel lock");
43    };
44    let Some(tx) = guard.deref() else {
45        eyre::bail!("runner channel has been already closed");
46    };
47
48    Ok(tx.subscribe())
49}
50
51#[derive(Debug, Clone, thiserror::Error)]
52pub enum Error {
53    #[error("panic: {0}")]
54    Panicked(String),
55    #[error("error: {0}")]
56    ErrorReturned(String),
57}
58
59#[derive(Debug, Clone)]
60pub enum Message {
61    Start(ProjectName, ModuleName, TestName),
62    HttpLog(ProjectName, ModuleName, TestName, Box<http::Log>),
63    End(ProjectName, ModuleName, TestName, Test),
64}
65
66#[derive(Debug, Clone)]
67pub struct Test {
68    pub metadata: TestMetadata,
69    pub result: Result<(), Error>,
70}
71
/// Identity of a registered test case.
#[derive(Debug, Clone)]
pub struct TestMetadata {
    /// Test name within its module.
    pub name: String,

    /// Module the test belongs to.
    pub module: String,
}

impl TestMetadata {
    /// Fully qualified test name in the form `module::name`.
    pub fn full_name(&self) -> String {
        [self.module.as_str(), self.name.as_str()].join("::")
    }
}
84
85type TestCaseFactory = Arc<
86    dyn Fn() -> Pin<Box<dyn futures::Future<Output = eyre::Result<()>> + Send + 'static>>
87        + Sync
88        + Send
89        + 'static,
90>;
91
/// Runtime switches for [`Runner`]; all default to `false`.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// Debug mode flag.
    pub debug: bool,

    /// Set by `Runner::capture_http`.
    pub capture_http: bool,

    /// Install a `tracing_subscriber` fmt subscriber when the runner starts.
    pub capture_rust: bool,

    /// Close the runner channel once every test has finished.
    pub terminate_channel: bool,
}
99
100/// Test case filter trait.
101pub trait Filter {
102    fn filter(&self, project: &ProjectConfig, metadata: &TestMetadata) -> bool;
103}
104
105/// Filter test cases by project name.
106pub struct ProjectFilter<'a> {
107    project_names: &'a [String],
108}
109
110impl Filter for ProjectFilter<'_> {
111    fn filter(&self, project: &ProjectConfig, _metadata: &TestMetadata) -> bool {
112        if self.project_names.is_empty() {
113            return true;
114        }
115
116        self.project_names
117            .iter()
118            .any(|project_name| &project.name == project_name)
119    }
120}
121
122/// Filter test cases by module name.
123pub struct ModuleFilter<'a> {
124    module_names: &'a [String],
125}
126
127impl Filter for ModuleFilter<'_> {
128    fn filter(&self, _project: &ProjectConfig, metadata: &TestMetadata) -> bool {
129        if self.module_names.is_empty() {
130            return true;
131        }
132
133        self.module_names
134            .iter()
135            .any(|module_name| &metadata.module == module_name)
136    }
137}
138
139/// Filter test cases by test name.
140pub struct TestNameFilter<'a> {
141    test_names: &'a [String],
142}
143
144impl Filter for TestNameFilter<'_> {
145    fn filter(&self, _project: &ProjectConfig, metadata: &TestMetadata) -> bool {
146        if self.test_names.is_empty() {
147            return true;
148        }
149
150        self.test_names
151            .iter()
152            .any(|test_name| &metadata.full_name() == test_name)
153    }
154}
155
156/// Filter test cases by test ignore config.
157pub struct TestIgnoreFilter {
158    test_ignores: HashMap<String, Vec<String>>,
159}
160
161impl Default for TestIgnoreFilter {
162    fn default() -> TestIgnoreFilter {
163        TestIgnoreFilter {
164            test_ignores: get_tanu_config()
165                .projects
166                .iter()
167                .map(|proj| (proj.name.clone(), proj.test_ignore.clone()))
168                .collect(),
169        }
170    }
171}
172
173impl Filter for TestIgnoreFilter {
174    fn filter(&self, project: &ProjectConfig, metadata: &TestMetadata) -> bool {
175        let Some(test_ignore) = self.test_ignores.get(&project.name) else {
176            return true;
177        };
178
179        test_ignore
180            .iter()
181            .all(|test_name| &metadata.full_name() != test_name)
182    }
183}
184
185#[derive(Default)]
186pub struct Runner {
187    cfg: Config,
188    options: Options,
189    test_cases: Vec<(TestMetadata, TestCaseFactory)>,
190    reporters: Vec<Box<dyn Reporter + Send>>,
191}
192
193impl Runner {
194    pub fn new() -> Runner {
195        Runner::with_config(get_tanu_config().clone())
196    }
197
198    pub fn with_config(cfg: Config) -> Runner {
199        Runner {
200            cfg,
201            options: Options::default(),
202            test_cases: Vec::new(),
203            reporters: Vec::new(),
204        }
205    }
206
207    pub fn capture_http(&mut self) {
208        self.options.capture_http = true;
209    }
210
211    pub fn capture_rust(&mut self) {
212        self.options.capture_rust = true;
213    }
214
215    pub fn terminate_channel(&mut self) {
216        self.options.terminate_channel = true;
217    }
218
219    pub fn add_reporter(&mut self, reporter: impl Reporter + 'static + Send) {
220        self.reporters.push(Box::new(reporter));
221    }
222
223    /// Add a test case to the runner.
224    pub fn add_test(&mut self, name: &str, module: &str, factory: TestCaseFactory) {
225        self.test_cases.push((
226            TestMetadata {
227                name: name.into(),
228                module: module.into(),
229            },
230            factory,
231        ));
232    }
233
234    /// Run tanu runner.
235    #[allow(clippy::too_many_lines)]
236    pub async fn run(
237        &mut self,
238        project_names: &[String],
239        module_names: &[String],
240        test_names: &[String],
241    ) -> eyre::Result<()> {
242        if self.options.capture_rust {
243            tracing_subscriber::fmt::init();
244        }
245
246        let mut reporters = std::mem::take(&mut self.reporters);
247
248        let project_filter = ProjectFilter { project_names };
249        let module_filter = ModuleFilter { module_names };
250        let test_name_filter = TestNameFilter { test_names };
251        let test_ignore_filter = TestIgnoreFilter::default();
252
253        let handles: FuturesUnordered<_> = self
254                .test_cases
255                .iter()
256                .flat_map(|(metadata, factory)| {
257                    let projects = self.cfg.projects.clone();
258                    let projects = if projects.is_empty() {
259                        vec![ProjectConfig {
260                            name: "default".into(),
261                            ..Default::default()
262                        }]
263                    } else {
264                        projects
265                    };
266                    projects
267                        .into_iter()
268                        .map(move |project| {
269                            (project.clone(), metadata.clone(), factory.clone())
270                        })
271                })
272                .filter(move |(project, metadata, _)| {
273                    test_name_filter.filter(project, metadata)
274                })
275                .filter(move |(project, metadata, _)| {
276                    module_filter.filter(project, metadata)
277                })
278                .filter(move |(project, metadata, _)| {
279                    project_filter.filter(project, metadata)
280                })
281                .filter(move |(project, metadata, _)| {
282                    test_ignore_filter.filter(project, metadata)
283                })
284                .map(|(project, metadata, factory)| {
285                    tokio::spawn(async move {
286                        config::PROJECT
287                            .scope(project.clone(), async {
288                                http::CHANNEL
289                                    .scope(
290                                        Arc::new(Mutex::new(Some(broadcast::channel(1000).0))),
291                                        async {
292                                            let test_name = &metadata.name;
293                                            let mut http_rx = http::subscribe()?;
294
295                                            let f= || async {factory().await};
296                                            let fut = f.retry(project.retry.backoff());
297                                            let fut =
298                                                std::panic::AssertUnwindSafe(fut).catch_unwind();
299                                            let res = fut.await;
300
301                                            publish(Message::Start(project.name.clone(), metadata.module.clone(), test_name.to_string()))?;
302
303                                            let result = match res {
304                                                Ok(Ok(_)) => {
305                                                    debug!("{test_name} ok");
306                                                    Ok(())
307                                                }
308                                                Ok(Err(e)) => {
309                                                    debug!("{test_name} failed: {e:#}");
310                                                    Err(Error::ErrorReturned(format!("{e:?}")))
311                                                }
312                                                Err(e) => {
313                                                    let panic_message =
314                                                        if let Some(panic_message) = e.downcast_ref::<&str>() {
315                                                            format!(
316                                                            "{test_name} failed with message: {panic_message}"
317                                                        )
318                                                        } else if let Some(panic_message) =
319                                                            e.downcast_ref::<String>()
320                                                        {
321                                                            format!(
322                                                            "{test_name} failed with message: {panic_message}"
323                                                        )
324                                                        } else {
325                                                            format!("{test_name} failed with unknown message")
326                                                        };
327                                                    let e = eyre::eyre!(panic_message);
328                                                    Err(Error::Panicked(format!("{e:?}")))
329                                                }
330                                            };
331
332                                            while let Ok(log) = http_rx.try_recv() {
333                                                publish(Message::HttpLog(
334                                                    project.name.clone(),
335                                                    metadata.module.clone(),
336                                                    test_name.clone(),
337                                                    Box::new(log),
338                                                ))?;
339                                            }
340
341                                            let project = get_config();
342                                            publish(Message::End(
343                                                project.name,
344                                                metadata.module.clone(),
345                                                test_name.clone(),
346                                                Test { metadata, result },
347                                            ))?;
348
349                                            eyre::Ok(())
350                                        },
351                                    )
352                                    .await
353                            })
354                            .await
355                    })
356                })
357                .collect();
358
359        let reporters =
360            futures::future::join_all(reporters.iter_mut().map(|reporter| reporter.run().boxed()));
361
362        let options = self.options.clone();
363        let runner = async move {
364            let results = handles.collect::<Vec<_>>().await;
365            for result in results {
366                if let Err(e) = result {
367                    if e.is_panic() {
368                        // Resume the panic on the main task
369                        error!("{e}");
370                    }
371                }
372            }
373            debug!("all test finished. sending stop signal to the background tasks.");
374
375            if options.terminate_channel {
376                let Ok(mut guard) = CHANNEL.lock() else {
377                    eyre::bail!("failed to acquire runner channel lock");
378                };
379                guard.take(); // closing the runner channel.
380            }
381
382            eyre::Ok(())
383        };
384
385        let (_handles, _reporters) = tokio::join!(runner, reporters);
386
387        debug!("runner stopped");
388
389        Ok(())
390    }
391
392    pub fn list(&self) -> Vec<&TestMetadata> {
393        self.test_cases
394            .iter()
395            .map(|(meta, _test)| meta)
396            .collect::<Vec<_>>()
397    }
398}
399
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::RetryConfig;

    /// Config with one project named "default" and default settings.
    fn create_config() -> Config {
        Config {
            projects: vec![ProjectConfig {
                name: "default".into(),
                ..Default::default()
            }],
        }
    }

    /// Config with one project named "default" that retries a failed test once.
    fn create_config_with_retry() -> Config {
        Config {
            projects: vec![ProjectConfig {
                name: "default".into(),
                retry: RetryConfig {
                    count: Some(1),
                    ..Default::default()
                },
                ..Default::default()
            }],
        }
    }

    /// Test case that GETs `url` and fails unless the response status is a success.
    fn get_expecting_success(url: String) -> TestCaseFactory {
        Arc::new(move || {
            let url = url.clone();
            Box::pin(async move {
                let res = reqwest::get(url).await?;
                if !res.status().is_success() {
                    eyre::bail!("request failed");
                }
                Ok(())
            })
        })
    }

    #[tokio::test]
    async fn runner_fail_because_no_retry_configured() {
        let mut server = mockito::Server::new_async().await;
        // Without retries only the failing response may be requested.
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(0)
            .create_async()
            .await;

        let mut runner = Runner::with_config(create_config());
        runner.add_test("retry_test", "module", get_expecting_success(server.url()));

        let outcome = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn runner_retry_after_failure() {
        let mut server = mockito::Server::new_async().await;
        // The first attempt fails, the single retry succeeds.
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(1)
            .create_async()
            .await;

        let mut runner = Runner::with_config(create_config_with_retry());
        runner.add_test("retry_test", "module", get_expecting_success(server.url()));

        let outcome = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(outcome.is_ok());
    }
}