1use backon::Retryable;
3use eyre::WrapErr;
4use futures::{stream::FuturesUnordered, FutureExt, StreamExt};
5use once_cell::sync::Lazy;
6use std::{
7 collections::HashMap,
8 ops::Deref,
9 pin::Pin,
10 sync::{Arc, Mutex},
11};
12use tokio::sync::broadcast;
13use tracing::*;
14
15use crate::{
16 config::{self, get_config, get_tanu_config, ProjectConfig},
17 http,
18 reporter::Reporter,
19 Config, ModuleName, ProjectName, TestName,
20};
21
22pub static CHANNEL: Lazy<Mutex<Option<broadcast::Sender<Message>>>> =
23 Lazy::new(|| Mutex::new(Some(broadcast::channel(1000).0)));
24
25pub fn publish(msg: Message) -> eyre::Result<()> {
26 let Ok(guard) = CHANNEL.lock() else {
27 eyre::bail!("failed to acquire runner channel lock");
28 };
29 let Some(tx) = guard.deref() else {
30 eyre::bail!("runner channel has been already closed");
31 };
32
33 tx.send(msg)
34 .wrap_err("failed to publish message to the runner channel")?;
35
36 Ok(())
37}
38
39pub fn subscribe() -> eyre::Result<broadcast::Receiver<Message>> {
41 let Ok(guard) = CHANNEL.lock() else {
42 eyre::bail!("failed to acquire runner channel lock");
43 };
44 let Some(tx) = guard.deref() else {
45 eyre::bail!("runner channel has been already closed");
46 };
47
48 Ok(tx.subscribe())
49}
50
51#[derive(Debug, Clone, thiserror::Error)]
52pub enum Error {
53 #[error("panic: {0}")]
54 Panicked(String),
55 #[error("error: {0}")]
56 ErrorReturned(String),
57}
58
59#[derive(Debug, Clone)]
60pub enum Message {
61 Start(ProjectName, ModuleName, TestName),
62 HttpLog(ProjectName, ModuleName, TestName, Box<http::Log>),
63 End(ProjectName, ModuleName, TestName, Test),
64}
65
66#[derive(Debug, Clone)]
67pub struct Test {
68 pub metadata: TestMetadata,
69 pub result: Result<(), Error>,
70}
71
/// Identity of a registered test case.
#[derive(Debug, Clone)]
pub struct TestMetadata {
    /// Test name as passed to `Runner::add_test`.
    pub name: String,
    /// Module name as passed to `Runner::add_test`.
    pub module: String,
}

impl TestMetadata {
    /// Returns the qualified identifier `module::name`.
    pub fn full_name(&self) -> String {
        [self.module.as_str(), self.name.as_str()].join("::")
    }
}
84
85type TestCaseFactory = Arc<
86 dyn Fn() -> Pin<Box<dyn futures::Future<Output = eyre::Result<()>> + Send + 'static>>
87 + Sync
88 + Send
89 + 'static,
90>;
91
/// Switches controlling how the runner behaves; every flag defaults to `false`.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// NOTE(review): not read anywhere in this module — confirm where it is consumed.
    pub debug: bool,
    /// Set by `Runner::capture_http`. NOTE(review): not read in this module — confirm consumer.
    pub capture_http: bool,
    /// When set, `Runner::run` installs a `tracing_subscriber` fmt subscriber.
    pub capture_rust: bool,
    /// When set, `Runner::run` closes the global runner channel once all tests have finished.
    pub terminate_channel: bool,
}
99
100pub trait Filter {
102 fn filter(&self, project: &ProjectConfig, metadata: &TestMetadata) -> bool;
103}
104
105pub struct ProjectFilter<'a> {
107 project_names: &'a [String],
108}
109
110impl Filter for ProjectFilter<'_> {
111 fn filter(&self, project: &ProjectConfig, _metadata: &TestMetadata) -> bool {
112 if self.project_names.is_empty() {
113 return true;
114 }
115
116 self.project_names
117 .iter()
118 .any(|project_name| &project.name == project_name)
119 }
120}
121
122pub struct ModuleFilter<'a> {
124 module_names: &'a [String],
125}
126
127impl Filter for ModuleFilter<'_> {
128 fn filter(&self, _project: &ProjectConfig, metadata: &TestMetadata) -> bool {
129 if self.module_names.is_empty() {
130 return true;
131 }
132
133 self.module_names
134 .iter()
135 .any(|module_name| &metadata.module == module_name)
136 }
137}
138
139pub struct TestNameFilter<'a> {
141 test_names: &'a [String],
142}
143
144impl Filter for TestNameFilter<'_> {
145 fn filter(&self, _project: &ProjectConfig, metadata: &TestMetadata) -> bool {
146 if self.test_names.is_empty() {
147 return true;
148 }
149
150 self.test_names
151 .iter()
152 .any(|test_name| &metadata.full_name() == test_name)
153 }
154}
155
156pub struct TestIgnoreFilter {
158 test_ignores: HashMap<String, Vec<String>>,
159}
160
161impl Default for TestIgnoreFilter {
162 fn default() -> TestIgnoreFilter {
163 TestIgnoreFilter {
164 test_ignores: get_tanu_config()
165 .projects
166 .iter()
167 .map(|proj| (proj.name.clone(), proj.test_ignore.clone()))
168 .collect(),
169 }
170 }
171}
172
173impl Filter for TestIgnoreFilter {
174 fn filter(&self, project: &ProjectConfig, metadata: &TestMetadata) -> bool {
175 let Some(test_ignore) = self.test_ignores.get(&project.name) else {
176 return true;
177 };
178
179 test_ignore
180 .iter()
181 .all(|test_name| &metadata.full_name() != test_name)
182 }
183}
184
185#[derive(Default)]
186pub struct Runner {
187 cfg: Config,
188 options: Options,
189 test_cases: Vec<(TestMetadata, TestCaseFactory)>,
190 reporters: Vec<Box<dyn Reporter + Send>>,
191}
192
193impl Runner {
194 pub fn new() -> Runner {
195 Runner::with_config(get_tanu_config().clone())
196 }
197
198 pub fn with_config(cfg: Config) -> Runner {
199 Runner {
200 cfg,
201 options: Options::default(),
202 test_cases: Vec::new(),
203 reporters: Vec::new(),
204 }
205 }
206
207 pub fn capture_http(&mut self) {
208 self.options.capture_http = true;
209 }
210
211 pub fn capture_rust(&mut self) {
212 self.options.capture_rust = true;
213 }
214
215 pub fn terminate_channel(&mut self) {
216 self.options.terminate_channel = true;
217 }
218
219 pub fn add_reporter(&mut self, reporter: impl Reporter + 'static + Send) {
220 self.reporters.push(Box::new(reporter));
221 }
222
223 pub fn add_test(&mut self, name: &str, module: &str, factory: TestCaseFactory) {
225 self.test_cases.push((
226 TestMetadata {
227 name: name.into(),
228 module: module.into(),
229 },
230 factory,
231 ));
232 }
233
234 #[allow(clippy::too_many_lines)]
236 pub async fn run(
237 &mut self,
238 project_names: &[String],
239 module_names: &[String],
240 test_names: &[String],
241 ) -> eyre::Result<()> {
242 if self.options.capture_rust {
243 tracing_subscriber::fmt::init();
244 }
245
246 let mut reporters = std::mem::take(&mut self.reporters);
247
248 let project_filter = ProjectFilter { project_names };
249 let module_filter = ModuleFilter { module_names };
250 let test_name_filter = TestNameFilter { test_names };
251 let test_ignore_filter = TestIgnoreFilter::default();
252
253 let handles: FuturesUnordered<_> = self
254 .test_cases
255 .iter()
256 .flat_map(|(metadata, factory)| {
257 let projects = self.cfg.projects.clone();
258 let projects = if projects.is_empty() {
259 vec![ProjectConfig {
260 name: "default".into(),
261 ..Default::default()
262 }]
263 } else {
264 projects
265 };
266 projects
267 .into_iter()
268 .map(move |project| {
269 (project.clone(), metadata.clone(), factory.clone())
270 })
271 })
272 .filter(move |(project, metadata, _)| {
273 test_name_filter.filter(project, metadata)
274 })
275 .filter(move |(project, metadata, _)| {
276 module_filter.filter(project, metadata)
277 })
278 .filter(move |(project, metadata, _)| {
279 project_filter.filter(project, metadata)
280 })
281 .filter(move |(project, metadata, _)| {
282 test_ignore_filter.filter(project, metadata)
283 })
284 .map(|(project, metadata, factory)| {
285 tokio::spawn(async move {
286 config::PROJECT
287 .scope(project.clone(), async {
288 http::CHANNEL
289 .scope(
290 Arc::new(Mutex::new(Some(broadcast::channel(1000).0))),
291 async {
292 let test_name = &metadata.name;
293 let mut http_rx = http::subscribe()?;
294
295 let f= || async {factory().await};
296 let fut = f.retry(project.retry.backoff());
297 let fut =
298 std::panic::AssertUnwindSafe(fut).catch_unwind();
299 let res = fut.await;
300
301 publish(Message::Start(project.name.clone(), metadata.module.clone(), test_name.to_string()))?;
302
303 let result = match res {
304 Ok(Ok(_)) => {
305 debug!("{test_name} ok");
306 Ok(())
307 }
308 Ok(Err(e)) => {
309 debug!("{test_name} failed: {e:#}");
310 Err(Error::ErrorReturned(format!("{e:?}")))
311 }
312 Err(e) => {
313 let panic_message =
314 if let Some(panic_message) = e.downcast_ref::<&str>() {
315 format!(
316 "{test_name} failed with message: {panic_message}"
317 )
318 } else if let Some(panic_message) =
319 e.downcast_ref::<String>()
320 {
321 format!(
322 "{test_name} failed with message: {panic_message}"
323 )
324 } else {
325 format!("{test_name} failed with unknown message")
326 };
327 let e = eyre::eyre!(panic_message);
328 Err(Error::Panicked(format!("{e:?}")))
329 }
330 };
331
332 while let Ok(log) = http_rx.try_recv() {
333 publish(Message::HttpLog(
334 project.name.clone(),
335 metadata.module.clone(),
336 test_name.clone(),
337 Box::new(log),
338 ))?;
339 }
340
341 let project = get_config();
342 publish(Message::End(
343 project.name,
344 metadata.module.clone(),
345 test_name.clone(),
346 Test { metadata, result },
347 ))?;
348
349 eyre::Ok(())
350 },
351 )
352 .await
353 })
354 .await
355 })
356 })
357 .collect();
358
359 let reporters =
360 futures::future::join_all(reporters.iter_mut().map(|reporter| reporter.run().boxed()));
361
362 let options = self.options.clone();
363 let runner = async move {
364 let results = handles.collect::<Vec<_>>().await;
365 for result in results {
366 if let Err(e) = result {
367 if e.is_panic() {
368 error!("{e}");
370 }
371 }
372 }
373 debug!("all test finished. sending stop signal to the background tasks.");
374
375 if options.terminate_channel {
376 let Ok(mut guard) = CHANNEL.lock() else {
377 eyre::bail!("failed to acquire runner channel lock");
378 };
379 guard.take(); }
381
382 eyre::Ok(())
383 };
384
385 let (_handles, _reporters) = tokio::join!(runner, reporters);
386
387 debug!("runner stopped");
388
389 Ok(())
390 }
391
392 pub fn list(&self) -> Vec<&TestMetadata> {
393 self.test_cases
394 .iter()
395 .map(|(meta, _test)| meta)
396 .collect::<Vec<_>>()
397 }
398}
399
#[cfg(test)]
mod test {
    use super::*;
    use crate::config::RetryConfig;

    /// Config with a single `default` project and default retry settings.
    fn create_config() -> Config {
        Config {
            projects: vec![ProjectConfig {
                name: "default".into(),
                ..Default::default()
            }],
        }
    }

    /// Config with a single `default` project that retries once.
    fn create_config_with_retry() -> Config {
        Config {
            projects: vec![ProjectConfig {
                name: "default".into(),
                retry: RetryConfig {
                    count: Some(1),
                    ..Default::default()
                },
                ..Default::default()
            }],
        }
    }

    /// Test case that GETs `url` and fails unless the response has a success status.
    fn get_expecting_success(url: String) -> TestCaseFactory {
        Arc::new(move || {
            let url = url.clone();
            Box::pin(async move {
                let res = reqwest::get(url).await?;
                if res.status().is_success() {
                    Ok(())
                } else {
                    eyre::bail!("request failed")
                }
            })
        })
    }

    #[tokio::test]
    async fn runner_fail_because_no_retry_configured() {
        let mut server = mockito::Server::new_async().await;
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(0)
            .create_async()
            .await;

        let mut runner = Runner::with_config(create_config());
        runner.add_test("retry_test", "module", get_expecting_success(server.url()));

        let result = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn runner_retry_after_failure() {
        let mut server = mockito::Server::new_async().await;
        let failing = server
            .mock("GET", "/")
            .with_status(500)
            .expect(1)
            .create_async()
            .await;
        let succeeding = server
            .mock("GET", "/")
            .with_status(200)
            .expect(1)
            .create_async()
            .await;

        let mut runner = Runner::with_config(create_config_with_retry());
        runner.add_test("retry_test", "module", get_expecting_success(server.url()));

        let result = runner.run(&[], &[], &[]).await;
        failing.assert_async().await;
        succeeding.assert_async().await;

        assert!(result.is_ok());
    }
}