tanu_core/
config.rs

1use chrono::{DateTime, Utc};
2use once_cell::sync::Lazy;
3use serde::{de::DeserializeOwned, Deserialize};
4use std::{collections::HashMap, io::Read, path::Path, time::Duration};
5use toml::Value as TomlValue;
6use tracing::*;
7
8use crate::{Error, Result};
9
10static CONFIG: Lazy<Config> = Lazy::new(|| {
11    let _ = dotenv::dotenv();
12    Config::load().expect("failed to load tanu.toml")
13});
14
tokio::task_local! {
    /// Task-local configuration of the project the current tokio task belongs to.
    ///
    /// Expected to be set by the tanu runner; read it through [`get_config`], which
    /// panics when called from a task where this value has not been set.
    pub static PROJECT: ProjectConfig;
}
18
19pub fn get_tanu_config() -> &'static Config {
20    &CONFIG
21}
22
23/// Get configuration for the current project. This function has to be called in the tokio
24/// task created by tanu runner. Otherwise, calling this function will panic.
25pub fn get_config() -> ProjectConfig {
26    PROJECT.get()
27}
28
29/// tanu's configuration.
30#[derive(Debug, Clone, Default, Deserialize)]
31pub struct Config {
32    pub projects: Vec<ProjectConfig>,
33    /// Global tanu configuration
34    #[serde(default)]
35    pub tui: Tui,
36}
37
38/// Global tanu configuration
39#[derive(Debug, Clone, Default, Deserialize)]
40pub struct Tui {
41    #[serde(default)]
42    pub payload: Payload,
43}
44
45#[derive(Debug, Clone, Default, Deserialize)]
46pub struct Payload {
47    /// Optional color theme for terminal output
48    pub color_theme: Option<String>,
49}
50
51impl Config {
52    /// Load tanu configuration from path.
53    fn load_from(path: &Path) -> Result<Config> {
54        let Ok(mut file) = std::fs::File::open(path) else {
55            return Ok(Config::default());
56        };
57
58        let mut buf = String::new();
59        file.read_to_string(&mut buf)
60            .map_err(|e| Error::LoadError(e.to_string()))?;
61        let mut cfg: Config = toml::from_str(&buf).map_err(|e| {
62            Error::LoadError(format!(
63                "failed to deserialize tanu.toml into tanu::Config: {e}"
64            ))
65        })?;
66
67        debug!("tanu.toml was successfully loaded: {cfg:#?}");
68
69        cfg.load_env();
70
71        Ok(cfg)
72    }
73
74    /// Load tanu configuration from tanu.toml in the current directory.
75    fn load() -> Result<Config> {
76        Config::load_from(Path::new("tanu.toml"))
77    }
78
79    /// Load tanu configuration from environment variables.
80    ///
81    /// Global environment variables: tanu automatically detects environment variables prefixed
82    /// with tanu_XXX and maps them to the corresponding configuration variable as "xxx". This
83    /// global configuration can be accessed in any project.
84    ///
85    /// Project environment variables: tanu automatically detects environment variables prefixed
86    /// with tanu_PROJECT_ZZZ_XXX and maps them to the corresponding configuration variable as
87    /// "xxx" for project "ZZZ". This configuration is isolated within the project.
88    fn load_env(&mut self) {
89        static PREFIX: &str = "TANU";
90
91        let global_prefix = format!("{PREFIX}_");
92        let project_prefixes: Vec<_> = self
93            .projects
94            .iter()
95            .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
96            .collect();
97        debug!("Loading global configuration from env");
98        let global_vars: HashMap<_, _> = std::env::vars()
99            .filter_map(|(k, v)| {
100                let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
101                if is_project_var {
102                    return None;
103                }
104
105                k.find(&global_prefix)?;
106                Some((
107                    k[global_prefix.len()..].to_string().to_lowercase(),
108                    TomlValue::String(v),
109                ))
110            })
111            .collect();
112
113        debug!("Loading project configuration from env");
114        for project in &mut self.projects {
115            let project_prefix = format!("{PREFIX}_{}_", project.name.to_uppercase());
116            let vars: HashMap<_, _> = std::env::vars()
117                .filter_map(|(k, v)| {
118                    k.find(&project_prefix)?;
119                    Some((
120                        k[project_prefix.len()..].to_string().to_lowercase(),
121                        TomlValue::String(v),
122                    ))
123                })
124                .collect();
125            project.data.extend(vars);
126            project.data.extend(global_vars.clone());
127        }
128
129        debug!("tanu configuration loaded from env: {self:#?}");
130    }
131
132    /// Get the current color theme
133    pub fn color_theme(&self) -> Option<&str> {
134        self.tui.payload.color_theme.as_deref()
135    }
136}
137
138/// tanu's project configuration.
139#[derive(Debug, Clone, Default, Deserialize)]
140pub struct ProjectConfig {
141    /// Project name specified by user.
142    pub name: String,
143    /// Keys and values specified by user.
144    #[serde(flatten)]
145    pub data: HashMap<String, TomlValue>,
146    /// List of files to ignore in the project.
147    #[serde(default)]
148    pub test_ignore: Vec<String>,
149    #[serde(default)]
150    pub retry: RetryConfig,
151}
152
153impl ProjectConfig {
154    pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
155        let key = key.as_ref();
156        self.data
157            .get(key)
158            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
159    }
160
161    pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
162        let key = key.as_ref();
163        self.get(key)?
164            .as_str()
165            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
166    }
167
168    pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
169        self.get_str(key)?
170            .parse()
171            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
172    }
173
174    pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
175        self.get_str(key)?
176            .parse()
177            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
178    }
179
180    pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
181        self.get_str(key)?
182            .parse()
183            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
184    }
185
186    pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
187        self.get_str(key)?
188            .parse::<DateTime<Utc>>()
189            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
190    }
191
192    pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
193        serde_json::from_str(self.get_str(key)?)
194            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
195    }
196
197    pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
198        serde_json::from_str(self.get_str(key)?)
199            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
200    }
201}
202
203#[derive(Debug, Clone, Deserialize)]
204pub struct RetryConfig {
205    /// Number of retries.
206    pub count: Option<usize>,
207    /// Factor to multiply the delay between retries.
208    pub factor: Option<f32>,
209    /// Whether to add jitter to the delay between retries.
210    pub jitter: Option<bool>,
211    /// Minimum delay between retries.
212    #[serde(with = "humantime_serde")]
213    pub min_delay: Option<Duration>,
214    /// Maximum delay between retries.
215    #[serde(with = "humantime_serde")]
216    pub max_delay: Option<Duration>,
217}
218
219impl Default for RetryConfig {
220    fn default() -> Self {
221        RetryConfig {
222            count: Some(0),
223            factor: Some(2.0),
224            jitter: Some(false),
225            min_delay: Some(Duration::from_secs(1)),
226            max_delay: Some(Duration::from_secs(60)),
227        }
228    }
229}
230
231impl RetryConfig {
232    pub fn backoff(&self) -> backon::ExponentialBuilder {
233        let builder = backon::ExponentialBuilder::new()
234            .with_max_times(self.count.unwrap_or_default())
235            .with_factor(self.factor.unwrap_or(2.0))
236            .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
237            .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
238
239        if self.jitter.unwrap_or_default() {
240            builder.with_jitter()
241        } else {
242            builder
243        }
244    }
245}
246
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{time::Duration, vec};
    use test_case::test_case;

    /// Parses the sample configuration that lives next to this crate's manifest directory.
    fn load_test_config() -> eyre::Result<Config> {
        let sample = Path::new(env!("CARGO_MANIFEST_DIR")).join("../tanu-sample.toml");
        let config = super::Config::load_from(&sample)?;
        Ok(config)
    }

    /// Returns the first project declared in the sample configuration.
    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        let mut config = load_test_config()?;
        Ok(config.projects.remove(0))
    }

    /// Sets environment variable `var` to `value`, then loads the sample project.
    fn project_with_env(var: &str, value: &str) -> eyre::Result<ProjectConfig> {
        std::env::set_var(var, value);
        load_test_project_config()
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let cfg = load_test_config()?;
        let [project] = cfg.projects.as_slice() else {
            panic!("expected exactly one project, found {}", cfg.projects.len());
        };

        assert_eq!(project.name, "default");
        assert!(project.test_ignore.is_empty());

        let retry = &project.retry;
        assert_eq!(retry.count, Some(0));
        assert_eq!(retry.factor, Some(2.0));
        assert_eq!(retry.jitter, Some(false));
        assert_eq!(retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(var: &str) -> eyre::Result<()> {
        let project = project_with_env(var, "example_string")?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(var: &str) -> eyre::Result<()> {
        let project = project_with_env(var, "42")?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_FLOAT_KEY"; "project config")]
    #[test_case("TANU_FLOAT_KEY"; "global config")]
    fn get_float(var: &str) -> eyre::Result<()> {
        let project = project_with_env(var, "5.5")?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(var: &str) -> eyre::Result<()> {
        let project = project_with_env(var, "true")?;
        assert!(project.get_bool("bool_key")?);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(var: &str) -> eyre::Result<()> {
        const EXPECTED: &str = "2025-03-08T12:00:00Z";
        let project = project_with_env(var, EXPECTED)?;
        let parsed = project.get_datetime("datetime_key")?;
        assert_eq!(
            parsed.to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            EXPECTED
        );
        Ok(())
    }

    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(var: &str) -> eyre::Result<()> {
        let project = project_with_env(var, "[1, 2, 3]")?;
        let numbers: Vec<i64> = project.get_array("array_key")?;
        assert_eq!(numbers, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_OBJECT_KEY"; "project config")]
    #[test_case("TANU_OBJECT_KEY"; "global config")]
    fn get_object(var: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }

        let project = project_with_env(var, "{\"foo\": [\"bar\", \"baz\"]}")?;
        let decoded: Foo = project.get_object("object_key")?;
        assert_eq!(decoded.foo, vec!["bar", "baz"]);
        Ok(())
    }
}