tanu_core/
config.rs

1//! # Configuration Module
2//!
3//! Handles loading and managing tanu configuration from `tanu.toml` files.
4//! Supports project-specific configurations, environment variables, and
5//! various test execution settings.
6//!
7//! ## Configuration Structure
8//!
9//! Tanu uses TOML configuration files with the following structure:
10//!
11//! ```toml
12//! [[projects]]
13//! name = "staging"
14//! base_url = "https://staging.api.example.com"
15//! timeout = 30000
16//! retry.count = 3
17//! retry.factor = 2.0
18//! test_ignore = ["slow_test", "flaky_test"]
19//!
20//! [[projects]]
21//! name = "production"
22//! base_url = "https://api.example.com"
23//! timeout = 10000
24//! ```
25//!
26//! ## Usage
27//!
28//! ```rust,ignore
29//! use tanu::{get_config, get_tanu_config};
30//!
31//! // Get global configuration
32//! let config = get_tanu_config();
33//!
34//! // Get current project configuration (within test context)
35//! let project_config = get_config();
36//! let base_url = project_config.get_str("base_url").unwrap_or_default();
37//! ```
38
39use chrono::{DateTime, Utc};
40use once_cell::sync::Lazy;
41use serde::{de::DeserializeOwned, Deserialize};
42use std::{collections::HashMap, io::Read, path::Path, sync::Arc, time::Duration};
43use toml::Value as TomlValue;
44use tracing::*;
45
46use crate::{Error, Result};
47
48static CONFIG: Lazy<Config> = Lazy::new(|| {
49    let _ = dotenv::dotenv();
50    Config::load().unwrap_or_default()
51});
52
tokio::task_local! {
    /// Configuration of the project associated with the current tokio task.
    ///
    /// Read it through `get_config`, which panics when this task-local has not been set
    /// for the current task (per `get_config`'s docs, the tanu runner sets it for the
    /// tasks it creates).
    pub static PROJECT: Arc<ProjectConfig>;
}
56
57#[doc(hidden)]
58pub fn get_tanu_config() -> &'static Config {
59    &CONFIG
60}
61
62/// Get configuration for the current project. This function has to be called in the tokio
63/// task created by tanu runner. Otherwise, calling this function will panic.
64pub fn get_config() -> Arc<ProjectConfig> {
65    PROJECT.get()
66}
67
/// tanu's configuration.
///
/// Built from `tanu.toml` and then overlaid with `TANU_*` environment variables
/// (see `Config::load_from` and `Config::load_env`).
#[derive(Debug, Clone)]
pub struct Config {
    /// Configured projects, one per `[[projects]]` entry. Each project is wrapped in `Arc`
    /// so it can be shared cheaply; the `PROJECT` task-local holds the same type.
    pub projects: Vec<Arc<ProjectConfig>>,
    /// Terminal UI settings (the `[tui]` table of `tanu.toml`).
    pub tui: Tui,
}
75
76impl Default for Config {
77    fn default() -> Self {
78        Config {
79            projects: vec![Arc::new(ProjectConfig {
80                name: "default".to_string(),
81                ..Default::default()
82            })],
83            tui: Tui::default(),
84        }
85    }
86}
87
/// Terminal UI settings, deserialized from the `[tui]` table of `tanu.toml`.
///
/// When the table is absent, `Tui::default()` is used.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Tui {
    /// Settings from the `[tui.payload]` table; defaulted when the table is absent.
    #[serde(default)]
    pub payload: Payload,
}
94
/// Settings from the `[tui.payload]` table of `tanu.toml`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Payload {
    /// Optional color theme for terminal output; exposed via `Config::color_theme`.
    pub color_theme: Option<String>,
}
100
101impl Config {
102    /// Load tanu configuration from path.
103    fn load_from(path: &Path) -> Result<Config> {
104        let Ok(mut file) = std::fs::File::open(path) else {
105            return Ok(Config::default());
106        };
107
108        let mut buf = String::new();
109        file.read_to_string(&mut buf)
110            .map_err(|e| Error::LoadError(e.to_string()))?;
111
112        #[derive(Deserialize)]
113        struct ConfigHelper {
114            #[serde(default)]
115            projects: Vec<ProjectConfig>,
116            #[serde(default)]
117            tui: Tui,
118        }
119
120        let helper: ConfigHelper = toml::from_str(&buf).map_err(|e| {
121            Error::LoadError(format!(
122                "failed to deserialize tanu.toml into tanu::Config: {e}"
123            ))
124        })?;
125
126        let mut cfg = Config {
127            projects: helper.projects.into_iter().map(Arc::new).collect(),
128            tui: helper.tui,
129        };
130
131        debug!("tanu.toml was successfully loaded: {cfg:#?}");
132
133        cfg.load_env();
134
135        Ok(cfg)
136    }
137
138    /// Load tanu configuration from tanu.toml in the current directory.
139    fn load() -> Result<Config> {
140        Config::load_from(Path::new("tanu.toml"))
141    }
142
143    /// Load tanu configuration from environment variables.
144    ///
145    /// Global environment variables: tanu automatically detects environment variables prefixed
146    /// with tanu_XXX and maps them to the corresponding configuration variable as "xxx". This
147    /// global configuration can be accessed in any project.
148    ///
149    /// Project environment variables: tanu automatically detects environment variables prefixed
150    /// with tanu_PROJECT_ZZZ_XXX and maps them to the corresponding configuration variable as
151    /// "xxx" for project "ZZZ". This configuration is isolated within the project.
152    fn load_env(&mut self) {
153        static PREFIX: &str = "TANU";
154
155        let global_prefix = format!("{PREFIX}_");
156        let project_prefixes: Vec<_> = self
157            .projects
158            .iter()
159            .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
160            .collect();
161        debug!("Loading global configuration from env");
162        let global_vars: HashMap<_, _> = std::env::vars()
163            .filter_map(|(k, v)| {
164                let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
165                if is_project_var {
166                    return None;
167                }
168
169                k.find(&global_prefix)?;
170                Some((
171                    k[global_prefix.len()..].to_string().to_lowercase(),
172                    TomlValue::String(v),
173                ))
174            })
175            .collect();
176
177        debug!("Loading project configuration from env");
178        for project_arc in &mut self.projects {
179            let project_prefix = format!("{PREFIX}_{}_", project_arc.name.to_uppercase());
180            let vars: HashMap<_, _> = std::env::vars()
181                .filter_map(|(k, v)| {
182                    k.find(&project_prefix)?;
183                    Some((
184                        k[project_prefix.len()..].to_string().to_lowercase(),
185                        TomlValue::String(v),
186                    ))
187                })
188                .collect();
189            let project = Arc::make_mut(project_arc);
190            project.data.extend(vars);
191            project.data.extend(global_vars.clone());
192        }
193
194        debug!("tanu configuration loaded from env: {self:#?}");
195    }
196
197    /// Get the current color theme
198    pub fn color_theme(&self) -> Option<&str> {
199        self.tui.payload.color_theme.as_deref()
200    }
201}
202
/// tanu's project configuration.
///
/// Deserialized from one `[[projects]]` entry of `tanu.toml`. Keys that are not one of the
/// typed fields below end up in `data`, where `TANU_*` environment variables are also
/// merged in as string values.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ProjectConfig {
    /// Project name specified by user (required in `tanu.toml`). Its uppercased form also
    /// selects the `TANU_<NAME>_*` environment variables that apply only to this project.
    pub name: String,
    /// Keys and values specified by user: the remaining keys of the project entry (via
    /// `#[serde(flatten)]`), plus values merged in from environment variables.
    #[serde(flatten)]
    pub data: HashMap<String, TomlValue>,
    /// List of files to ignore in the project.
    // NOTE(review): the module-level example sets `test_ignore = ["slow_test", "flaky_test"]`,
    // which look like test names rather than files — confirm which is meant.
    #[serde(default)]
    pub test_ignore: Vec<String>,
    /// Retry policy for this project (the `retry.*` keys); `RetryConfig::default()` when
    /// no retry settings are given.
    #[serde(default)]
    pub retry: RetryConfig,
}
217
218impl ProjectConfig {
219    pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
220        let key = key.as_ref();
221        self.data
222            .get(key)
223            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
224    }
225
226    pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
227        let key = key.as_ref();
228        self.get(key)?
229            .as_str()
230            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
231    }
232
233    pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
234        self.get_str(key)?
235            .parse()
236            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
237    }
238
239    pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
240        self.get_str(key)?
241            .parse()
242            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
243    }
244
245    pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
246        self.get_str(key)?
247            .parse()
248            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
249    }
250
251    pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
252        self.get_str(key)?
253            .parse::<DateTime<Utc>>()
254            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
255    }
256
257    pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
258        serde_json::from_str(self.get_str(key)?)
259            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
260    }
261
262    pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
263        serde_json::from_str(self.get_str(key)?)
264            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
265    }
266}
267
/// Retry policy for a project, deserialized from its `retry.*` keys.
///
/// Every field is `#[serde(default)]`, so a key omitted while other `retry` keys are given
/// deserializes to `None` — not to the value from `RetryConfig::default()`, which only
/// applies when no retry settings are present at all. `RetryConfig::backoff` substitutes
/// the same fallback values for `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct RetryConfig {
    /// Number of retries.
    #[serde(default)]
    pub count: Option<usize>,
    /// Factor to multiply the delay between retries.
    #[serde(default)]
    pub factor: Option<f32>,
    /// Whether to add jitter to the delay between retries.
    #[serde(default)]
    pub jitter: Option<bool>,
    /// Minimum delay between retries, written as a `humantime_serde` duration string.
    #[serde(default)]
    #[serde(with = "humantime_serde")]
    pub min_delay: Option<Duration>,
    /// Maximum delay between retries, written as a `humantime_serde` duration string.
    #[serde(default)]
    #[serde(with = "humantime_serde")]
    pub max_delay: Option<Duration>,
}
288
289impl Default for RetryConfig {
290    fn default() -> Self {
291        RetryConfig {
292            count: Some(0),
293            factor: Some(2.0),
294            jitter: Some(false),
295            min_delay: Some(Duration::from_secs(1)),
296            max_delay: Some(Duration::from_secs(60)),
297        }
298    }
299}
300
301impl RetryConfig {
302    pub fn backoff(&self) -> backon::ExponentialBuilder {
303        let builder = backon::ExponentialBuilder::new()
304            .with_max_times(self.count.unwrap_or_default())
305            .with_factor(self.factor.unwrap_or(2.0))
306            .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
307            .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
308
309        if self.jitter.unwrap_or_default() {
310            builder.with_jitter()
311        } else {
312            builder
313        }
314    }
315}
316
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{time::Duration, vec};
    use test_case::test_case;

    // These tests mutate process-wide environment variables and never unset them. Each
    // variable name is written by exactly one case, so cases running in parallel never
    // observe conflicting values for the same variable.

    /// Loads `tanu-sample.toml` located one directory above this crate's manifest.
    fn load_test_config() -> eyre::Result<Config> {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let config_path = Path::new(manifest_dir).join("../tanu-sample.toml");
        Ok(super::Config::load_from(&config_path)?)
    }

    /// Loads the sample configuration and takes ownership of its first project.
    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        Ok(Arc::try_unwrap(load_test_config()?.projects.remove(0))
            .expect("a freshly loaded project config has no other owners"))
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let cfg = load_test_config()?;
        assert_eq!(cfg.projects.len(), 1);

        let project = &cfg.projects[0];
        assert_eq!(project.name, "default");
        assert_eq!(project.test_ignore, Vec::<String>::new());
        assert_eq!(project.retry.count, Some(0));
        assert_eq!(project.retry.factor, Some(2.0));
        assert_eq!(project.retry.jitter, Some(false));
        assert_eq!(project.retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(project.retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_str(prefix: &str) -> eyre::Result<()> {
        std::env::set_var(format!("{prefix}_STR_KEY"), "example_string");
        let project = load_test_project_config()?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_int(prefix: &str) -> eyre::Result<()> {
        std::env::set_var(format!("{prefix}_INT_KEY"), "42");
        let project = load_test_project_config()?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        std::env::set_var(format!("{prefix}_FLOAT_KEY"), "5.5");
        let project = load_test_project_config()?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_bool(prefix: &str) -> eyre::Result<()> {
        std::env::set_var(format!("{prefix}_BOOL_KEY"), "true");
        let project = load_test_project_config()?;
        assert!(project.get_bool("bool_key")?);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_datetime(prefix: &str) -> eyre::Result<()> {
        let datetime_str = "2025-03-08T12:00:00Z";
        std::env::set_var(format!("{prefix}_DATETIME_KEY"), datetime_str);
        let project = load_test_project_config()?;
        assert_eq!(
            project
                .get_datetime("datetime_key")?
                .to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            datetime_str
        );
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_array(prefix: &str) -> eyre::Result<()> {
        std::env::set_var(format!("{prefix}_ARRAY_KEY"), "[1, 2, 3]");
        let project = load_test_project_config()?;
        let array: Vec<i64> = project.get_array("array_key")?;
        assert_eq!(array, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }
        std::env::set_var(
            format!("{prefix}_OBJECT_KEY"),
            "{\"foo\": [\"bar\", \"baz\"]}",
        );
        let project = load_test_project_config()?;
        let obj: Foo = project.get_object("object_key")?;
        assert_eq!(obj.foo, vec!["bar", "baz"]);
        Ok(())
    }
}