1use chrono::{DateTime, Utc};
93use once_cell::sync::Lazy;
94use serde::{de::DeserializeOwned, Deserialize};
95use std::{collections::HashMap, io::Read, path::Path, sync::Arc, time::Duration};
96use toml::Value as TomlValue;
97use tracing::*;
98
99use crate::{Error, Result};
100
/// HTTP capture mode for the test runner (`runner.capture_http` in `tanu.toml`).
///
/// Deserializes from a boolean (`true` → [`CaptureHttpMode::All`],
/// `false` → [`CaptureHttpMode::Off`]) or from one of the strings `"all"` /
/// `"on-failure"`. Being a fieldless enum it is `Copy` and `Eq`, so it can be
/// passed by value and compared without cloning.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CaptureHttpMode {
    /// Capture disabled (`false`).
    Off,
    /// Capture everything (`true` or `"all"`).
    All,
    /// Capture on failure only (`"on-failure"`); this is the default.
    #[default]
    OnFailure,
}
112
113impl<'de> serde::Deserialize<'de> for CaptureHttpMode {
114 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
115 where
116 D: serde::Deserializer<'de>,
117 {
118 struct CaptureHttpModeVisitor;
119
120 impl<'de> serde::de::Visitor<'de> for CaptureHttpModeVisitor {
121 type Value = CaptureHttpMode;
122
123 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
124 formatter.write_str(r#"a boolean or one of "all", "on-failure""#)
125 }
126
127 fn visit_bool<E: serde::de::Error>(
128 self,
129 v: bool,
130 ) -> std::result::Result<Self::Value, E> {
131 Ok(if v {
132 CaptureHttpMode::All
133 } else {
134 CaptureHttpMode::Off
135 })
136 }
137
138 fn visit_str<E: serde::de::Error>(
139 self,
140 v: &str,
141 ) -> std::result::Result<Self::Value, E> {
142 match v {
143 "all" => Ok(CaptureHttpMode::All),
144 "on-failure" => Ok(CaptureHttpMode::OnFailure),
145 _ => Err(E::invalid_value(serde::de::Unexpected::Str(v), &self)),
146 }
147 }
148 }
149
150 deserializer.deserialize_any(CaptureHttpModeVisitor)
151 }
152}
153
154const TANU_CONFIG_ENV: &str = "TANU_CONFIG";
156
157static CONFIG: Lazy<Config> = Lazy::new(|| {
158 let _ = dotenv::dotenv();
159 Config::load().unwrap_or_default()
160});
161
162tokio::task_local! {
163 pub static PROJECT: Arc<ProjectConfig>;
164}
165
166#[doc(hidden)]
167pub fn get_tanu_config() -> &'static Config {
168 &CONFIG
169}
170
171pub fn get_config() -> Arc<ProjectConfig> {
174 PROJECT.get()
175}
176
177#[derive(Debug, Clone)]
179pub struct Config {
180 pub projects: Vec<Arc<ProjectConfig>>,
181 pub tui: Tui,
183 pub runner: Runner,
185}
186
187impl Default for Config {
188 fn default() -> Self {
189 Config {
190 projects: vec![Arc::new(ProjectConfig {
191 name: "default".to_string(),
192 ..Default::default()
193 })],
194 tui: Tui::default(),
195 runner: Runner::default(),
196 }
197 }
198}
199
200#[derive(Debug, Clone, Default, Deserialize)]
202pub struct Tui {
203 #[serde(default)]
204 pub payload: Payload,
205}
206
/// Settings from the `[tui.payload]` table of `tanu.toml`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Payload {
    /// Name of the color theme, if configured; exposed via `Config::color_theme`.
    pub color_theme: Option<String>,
}
212
213#[derive(Debug, Clone, Default, Deserialize)]
215pub struct Runner {
216 #[serde(default)]
218 pub capture_http: Option<CaptureHttpMode>,
219 #[serde(default)]
221 pub capture_rust: Option<bool>,
222 #[serde(default)]
224 pub show_sensitive: Option<bool>,
225 #[serde(default)]
227 pub concurrency: Option<usize>,
228 #[serde(default)]
230 pub fail_fast: Option<bool>,
231}
232
233impl Config {
234 fn load_from(path: &Path) -> Result<Config> {
236 let Ok(mut file) = std::fs::File::open(path) else {
237 return Ok(Config::default());
238 };
239
240 let mut buf = String::new();
241 file.read_to_string(&mut buf)
242 .map_err(|e| Error::LoadError(e.to_string()))?;
243
244 #[derive(Deserialize)]
245 struct ConfigHelper {
246 #[serde(default)]
247 projects: Vec<ProjectConfig>,
248 #[serde(default)]
249 tui: Tui,
250 #[serde(default)]
251 runner: Runner,
252 }
253
254 let helper: ConfigHelper = toml::from_str(&buf).map_err(|e| {
255 Error::LoadError(format!(
256 "failed to deserialize tanu.toml into tanu::Config: {e}"
257 ))
258 })?;
259
260 let mut cfg = Config {
261 projects: helper.projects.into_iter().map(Arc::new).collect(),
262 tui: helper.tui,
263 runner: helper.runner,
264 };
265
266 debug!("tanu.toml was successfully loaded: {cfg:#?}");
267
268 cfg.load_env();
269
270 Ok(cfg)
271 }
272
273 fn load() -> Result<Config> {
279 match std::env::var(TANU_CONFIG_ENV) {
280 Ok(path) => {
281 let path = Path::new(&path);
282
283 if path.extension().is_none_or(|ext| ext != "toml")
285 && !path.to_string_lossy().contains(std::path::MAIN_SEPARATOR)
286 && !path.to_string_lossy().contains('/')
287 {
288 return Err(Error::LoadError(format!(
289 "{TANU_CONFIG_ENV} should be a path to a config file, not a config value. \
290 Got: {:?}. Use TANU_<KEY>=value for config values instead.",
291 path
292 )));
293 }
294
295 if !path.exists() {
296 return Err(Error::LoadError(format!(
297 "Config file specified by {TANU_CONFIG_ENV} not found: {:?}",
298 path
299 )));
300 }
301
302 debug!("Loading config from {TANU_CONFIG_ENV}={:?}", path);
303 Config::load_from(path)
304 }
305 Err(_) => Config::load_from(Path::new("tanu.toml")),
306 }
307 }
308
309 fn load_env(&mut self) {
319 static PREFIX: &str = "TANU";
320
321 let global_prefix = format!("{PREFIX}_");
322 let project_prefixes: Vec<_> = self
323 .projects
324 .iter()
325 .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
326 .collect();
327 debug!("Loading global configuration from env");
328 let global_vars: HashMap<_, _> = std::env::vars()
329 .filter_map(|(k, v)| {
330 if k == TANU_CONFIG_ENV {
332 let path = Path::new(&v);
334 if path.extension().is_none_or(|ext| ext != "toml")
335 && !v.contains(std::path::MAIN_SEPARATOR)
336 && !v.contains('/')
337 {
338 error!(
339 "{TANU_CONFIG_ENV} is reserved for specifying the config file path, \
340 not a config value. Use TANU_<KEY>=value for config values instead. \
341 Got: {TANU_CONFIG_ENV}={v:?}"
342 );
343 }
344 return None;
345 }
346
347 let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
348 if is_project_var {
349 return None;
350 }
351
352 k.find(&global_prefix)?;
353 Some((
354 k[global_prefix.len()..].to_string().to_lowercase(),
355 TomlValue::String(v),
356 ))
357 })
358 .collect();
359
360 debug!("Loading project configuration from env");
361 for project_arc in &mut self.projects {
362 let project_prefix = format!("{PREFIX}_{}_", project_arc.name.to_uppercase());
363 let vars: HashMap<_, _> = std::env::vars()
364 .filter_map(|(k, v)| {
365 k.find(&project_prefix)?;
366 Some((
367 k[project_prefix.len()..].to_string().to_lowercase(),
368 TomlValue::String(v),
369 ))
370 })
371 .collect();
372 let project = Arc::make_mut(project_arc);
373 project.data.extend(vars);
374 project.data.extend(global_vars.clone());
375 }
376
377 debug!("tanu configuration loaded from env: {self:#?}");
378 }
379
380 pub fn color_theme(&self) -> Option<&str> {
382 self.tui.payload.color_theme.as_deref()
383 }
384}
385
/// Configuration of a single project (one entry of the `projects` array in
/// `tanu.toml`).
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ProjectConfig {
    /// Project name; its upper-cased form selects the `TANU_<NAME>_`
    /// environment-variable prefix.
    pub name: String,
    /// Every key of the entry not captured by another field, plus values merged
    /// in from `TANU_*` environment variables; read via the `get_*` accessors.
    #[serde(flatten)]
    pub data: HashMap<String, TomlValue>,
    /// Entries describing tests to ignore; empty when omitted.
    /// NOTE(review): matching semantics are defined by the runner — confirm.
    #[serde(default)]
    pub test_ignore: Vec<String>,
    /// Retry policy; `RetryConfig::default()` when omitted.
    #[serde(default)]
    pub retry: RetryConfig,
}
400
401impl ProjectConfig {
402 pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
403 let key = key.as_ref();
404 self.data
405 .get(key)
406 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
407 }
408
409 pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
410 let key = key.as_ref();
411 self.get(key)?
412 .as_str()
413 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
414 }
415
416 pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
417 self.get_str(key)?
418 .parse()
419 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
420 }
421
422 pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
423 self.get_str(key)?
424 .parse()
425 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
426 }
427
428 pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
429 self.get_str(key)?
430 .parse()
431 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
432 }
433
434 pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
435 self.get_str(key)?
436 .parse::<DateTime<Utc>>()
437 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
438 }
439
440 pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
441 serde_json::from_str(self.get_str(key)?)
442 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
443 }
444
445 pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
446 serde_json::from_str(self.get_str(key)?)
447 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
448 }
449}
450
451#[derive(Debug, Clone, Deserialize)]
452pub struct RetryConfig {
453 #[serde(default)]
455 pub count: Option<usize>,
456 #[serde(default)]
458 pub factor: Option<f32>,
459 #[serde(default)]
461 pub jitter: Option<bool>,
462 #[serde(default)]
464 #[serde(with = "humantime_serde")]
465 pub min_delay: Option<Duration>,
466 #[serde(default)]
468 #[serde(with = "humantime_serde")]
469 pub max_delay: Option<Duration>,
470}
471
472impl Default for RetryConfig {
473 fn default() -> Self {
474 RetryConfig {
475 count: Some(0),
476 factor: Some(2.0),
477 jitter: Some(false),
478 min_delay: Some(Duration::from_secs(1)),
479 max_delay: Some(Duration::from_secs(60)),
480 }
481 }
482}
483
484impl RetryConfig {
485 pub fn backoff(&self) -> backon::ExponentialBuilder {
486 let builder = backon::ExponentialBuilder::new()
487 .with_max_times(self.count.unwrap_or_default())
488 .with_factor(self.factor.unwrap_or(2.0))
489 .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
490 .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
491
492 if self.jitter.unwrap_or_default() {
493 builder.with_jitter()
494 } else {
495 builder
496 }
497 }
498}
499
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{
        sync::{Mutex, MutexGuard},
        time::Duration,
        vec,
    };
    use test_case::test_case;

    /// Serializes tests that mutate process-wide environment variables.
    ///
    /// The test harness runs tests concurrently, and the tests below set `TANU_*`
    /// variables (including `TANU_CONFIG`) that other tests read. Without this
    /// lock they race and fail intermittently. A poisoned lock is recovered so
    /// that one failing test does not cascade into unrelated failures.
    fn env_lock() -> MutexGuard<'static, ()> {
        static ENV_LOCK: Mutex<()> = Mutex::new(());
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    mod capture_http_mode {
        use super::CaptureHttpMode;
        use pretty_assertions::assert_eq;

        fn from_toml(s: &str) -> Result<CaptureHttpMode, toml::de::Error> {
            #[derive(serde::Deserialize)]
            struct Wrapper {
                mode: CaptureHttpMode,
            }
            let w: Wrapper = toml::from_str(&format!("mode = {s}"))?;
            Ok(w.mode)
        }

        #[test]
        fn bool_true_maps_to_all() {
            assert_eq!(from_toml("true").unwrap(), CaptureHttpMode::All);
        }

        #[test]
        fn bool_false_maps_to_off() {
            assert_eq!(from_toml("false").unwrap(), CaptureHttpMode::Off);
        }

        #[test]
        fn string_all_maps_to_all() {
            assert_eq!(from_toml(r#""all""#).unwrap(), CaptureHttpMode::All);
        }

        #[test]
        fn string_on_failure_maps_to_on_failure() {
            assert_eq!(
                from_toml(r#""on-failure""#).unwrap(),
                CaptureHttpMode::OnFailure
            );
        }

        #[test]
        fn invalid_string_returns_error() {
            assert!(from_toml(r#""invalid""#).is_err());
        }

        #[test]
        fn runner_capture_http_field_accepts_bool() {
            #[derive(serde::Deserialize)]
            struct R {
                capture_http: Option<CaptureHttpMode>,
            }
            let r: R = toml::from_str("capture_http = true").unwrap();
            assert_eq!(r.capture_http, Some(CaptureHttpMode::All));

            let r: R = toml::from_str("capture_http = false").unwrap();
            assert_eq!(r.capture_http, Some(CaptureHttpMode::Off));
        }

        #[test]
        fn runner_capture_http_field_accepts_string() {
            #[derive(serde::Deserialize)]
            struct R {
                capture_http: Option<CaptureHttpMode>,
            }
            let r: R = toml::from_str(r#"capture_http = "all""#).unwrap();
            assert_eq!(r.capture_http, Some(CaptureHttpMode::All));

            let r: R = toml::from_str(r#"capture_http = "on-failure""#).unwrap();
            assert_eq!(r.capture_http, Some(CaptureHttpMode::OnFailure));
        }

        #[test]
        fn default_is_on_failure() {
            assert_eq!(CaptureHttpMode::default(), CaptureHttpMode::OnFailure);
        }
    }

    fn load_test_config() -> eyre::Result<Config> {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let config_path = Path::new(manifest_dir).join("../tanu-sample.toml");
        Ok(super::Config::load_from(&config_path)?)
    }

    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        Ok(Arc::try_unwrap(load_test_config()?.projects.remove(0)).unwrap())
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let cfg = load_test_config()?;
        assert_eq!(cfg.projects.len(), 1);

        let project = &cfg.projects[0];
        assert_eq!(project.name, "default");
        assert_eq!(project.test_ignore, Vec::<String>::new());
        assert_eq!(project.retry.count, Some(0));
        assert_eq!(project.retry.factor, Some(2.0));
        assert_eq!(project.retry.jitter, Some(false));
        assert_eq!(project.retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(project.retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(key: &str) -> eyre::Result<()> {
        let _guard = env_lock();
        std::env::set_var(key, "example_string");
        let project = load_test_project_config()?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(key: &str) -> eyre::Result<()> {
        let _guard = env_lock();
        std::env::set_var(key, "42");
        let project = load_test_project_config()?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        let _guard = env_lock();
        std::env::set_var(format!("{prefix}_FLOAT_KEY"), "5.5");
        let project = load_test_project_config()?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(key: &str) -> eyre::Result<()> {
        let _guard = env_lock();
        std::env::set_var(key, "true");
        let project = load_test_project_config()?;
        assert!(project.get_bool("bool_key")?);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(key: &str) -> eyre::Result<()> {
        let _guard = env_lock();
        let datetime_str = "2025-03-08T12:00:00Z";
        std::env::set_var(key, datetime_str);
        let project = load_test_project_config()?;
        assert_eq!(
            project
                .get_datetime("datetime_key")?
                .to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            datetime_str
        );
        Ok(())
    }

    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(key: &str) -> eyre::Result<()> {
        let _guard = env_lock();
        std::env::set_var(key, "[1, 2, 3]");
        let project = load_test_project_config()?;
        let array: Vec<i64> = project.get_array("array_key")?;
        assert_eq!(array, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }
        let _guard = env_lock();
        std::env::set_var(
            format!("{prefix}_OBJECT_KEY"),
            "{\"foo\": [\"bar\", \"baz\"]}",
        );
        let project = load_test_project_config()?;
        let obj: Foo = project.get_object("object_key")?;
        assert_eq!(obj.foo, vec!["bar", "baz"]);
        Ok(())
    }

    mod tanu_config_env {
        use super::{env_lock, Config, Path, TANU_CONFIG_ENV};
        use pretty_assertions::assert_eq;
        use test_case::test_case;

        /// Runs `Config::load()` with `TANU_CONFIG` set to `value` while holding
        /// the env lock. The variable is removed before returning, even when
        /// loading fails, so later tests start from a clean environment.
        fn load_with_tanu_config(value: &str) -> crate::Result<Config> {
            let _guard = env_lock();
            std::env::set_var(TANU_CONFIG_ENV, value);
            let result = Config::load();
            std::env::remove_var(TANU_CONFIG_ENV);
            result
        }

        #[test]
        fn load_from_tanu_config_env() {
            let manifest_dir = env!("CARGO_MANIFEST_DIR");
            let config_path = Path::new(manifest_dir).join("../tanu-sample.toml");

            let cfg = load_with_tanu_config(config_path.to_str().unwrap()).unwrap();

            assert_eq!(cfg.projects.len(), 1);
            assert_eq!(cfg.projects[0].name, "default");
        }

        #[test]
        fn error_when_file_not_found() {
            let result = load_with_tanu_config("/nonexistent/path/tanu.toml");

            assert!(result.is_err());
            let err = result.unwrap_err().to_string();
            assert!(
                err.contains("not found"),
                "error should mention file not found: {err}"
            );
        }

        #[test_case("true"; "boolean value")]
        #[test_case("123"; "numeric value")]
        #[test_case("some_value"; "string value")]
        fn error_when_value_looks_like_config_value(value: &str) {
            let result = load_with_tanu_config(value);

            assert!(result.is_err());
            let err = result.unwrap_err().to_string();
            assert!(
                err.contains("should be a path"),
                "error should guide user: {err}"
            );
        }

        #[test_case("config.toml"; "toml extension")]
        #[test_case("./tanu.toml"; "relative path with dot")]
        #[test_case("configs/tanu.toml"; "path with separator")]
        fn accepts_valid_path_patterns(value: &str) {
            let result = load_with_tanu_config(value);

            // These paths pass validation but do not exist, so loading must fail
            // with "not found" rather than with the path-validation error.
            assert!(result.is_err());
            let err = result.unwrap_err().to_string();
            assert!(
                err.contains("not found"),
                "valid path pattern should fail with 'not found', not path validation: {err}"
            );
        }
    }
}