tanu_core/
config.rs

1use chrono::{DateTime, Utc};
2use once_cell::sync::Lazy;
3use serde::{de::DeserializeOwned, Deserialize};
4use std::{collections::HashMap, io::Read, path::Path, time::Duration};
5use toml::Value as TomlValue;
6use tracing::*;
7
8use crate::{Error, Result};
9
10static CONFIG: Lazy<Config> = Lazy::new(|| {
11    let _ = dotenv::dotenv();
12    Config::load().unwrap_or_default()
13});
14
tokio::task_local! {
    /// Configuration of the project the current tokio task is running for.
    ///
    /// Only set inside tasks created by the tanu runner; read it via `get_config`, which panics
    /// when this value has not been set for the current task.
    pub static PROJECT: ProjectConfig;
}
18
19pub fn get_tanu_config() -> &'static Config {
20    &CONFIG
21}
22
23/// Get configuration for the current project. This function has to be called in the tokio
24/// task created by tanu runner. Otherwise, calling this function will panic.
25pub fn get_config() -> ProjectConfig {
26    PROJECT.get()
27}
28
/// tanu's configuration, deserialized from `tanu.toml`.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Per-project settings, one entry per element of the `projects` array.
    pub projects: Vec<ProjectConfig>,
    /// TUI settings (the `tui` table); falls back to `Tui::default()` when the table is absent.
    #[serde(default)]
    pub tui: Tui,
}
37
38impl Default for Config {
39    fn default() -> Self {
40        Config {
41            projects: vec![ProjectConfig {
42                name: "default".to_string(),
43                ..Default::default()
44            }],
45            tui: Tui::default(),
46        }
47    }
48}
49
/// TUI configuration, read from the `tui` table of `tanu.toml`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Tui {
    /// Payload display settings (the `tui.payload` table); defaults when absent.
    #[serde(default)]
    pub payload: Payload,
}
56
/// Payload-related TUI settings, read from the `tui.payload` table.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Payload {
    /// Optional color theme for terminal output; `None` when not configured.
    pub color_theme: Option<String>,
}
62
63impl Config {
64    /// Load tanu configuration from path.
65    fn load_from(path: &Path) -> Result<Config> {
66        let Ok(mut file) = std::fs::File::open(path) else {
67            return Ok(Config::default());
68        };
69
70        let mut buf = String::new();
71        file.read_to_string(&mut buf)
72            .map_err(|e| Error::LoadError(e.to_string()))?;
73        let mut cfg: Config = toml::from_str(&buf).map_err(|e| {
74            Error::LoadError(format!(
75                "failed to deserialize tanu.toml into tanu::Config: {e}"
76            ))
77        })?;
78
79        debug!("tanu.toml was successfully loaded: {cfg:#?}");
80
81        cfg.load_env();
82
83        Ok(cfg)
84    }
85
86    /// Load tanu configuration from tanu.toml in the current directory.
87    fn load() -> Result<Config> {
88        Config::load_from(Path::new("tanu.toml"))
89    }
90
91    /// Load tanu configuration from environment variables.
92    ///
93    /// Global environment variables: tanu automatically detects environment variables prefixed
94    /// with tanu_XXX and maps them to the corresponding configuration variable as "xxx". This
95    /// global configuration can be accessed in any project.
96    ///
97    /// Project environment variables: tanu automatically detects environment variables prefixed
98    /// with tanu_PROJECT_ZZZ_XXX and maps them to the corresponding configuration variable as
99    /// "xxx" for project "ZZZ". This configuration is isolated within the project.
100    fn load_env(&mut self) {
101        static PREFIX: &str = "TANU";
102
103        let global_prefix = format!("{PREFIX}_");
104        let project_prefixes: Vec<_> = self
105            .projects
106            .iter()
107            .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
108            .collect();
109        debug!("Loading global configuration from env");
110        let global_vars: HashMap<_, _> = std::env::vars()
111            .filter_map(|(k, v)| {
112                let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
113                if is_project_var {
114                    return None;
115                }
116
117                k.find(&global_prefix)?;
118                Some((
119                    k[global_prefix.len()..].to_string().to_lowercase(),
120                    TomlValue::String(v),
121                ))
122            })
123            .collect();
124
125        debug!("Loading project configuration from env");
126        for project in &mut self.projects {
127            let project_prefix = format!("{PREFIX}_{}_", project.name.to_uppercase());
128            let vars: HashMap<_, _> = std::env::vars()
129                .filter_map(|(k, v)| {
130                    k.find(&project_prefix)?;
131                    Some((
132                        k[project_prefix.len()..].to_string().to_lowercase(),
133                        TomlValue::String(v),
134                    ))
135                })
136                .collect();
137            project.data.extend(vars);
138            project.data.extend(global_vars.clone());
139        }
140
141        debug!("tanu configuration loaded from env: {self:#?}");
142    }
143
144    /// Get the current color theme
145    pub fn color_theme(&self) -> Option<&str> {
146        self.tui.payload.color_theme.as_deref()
147    }
148}
149
/// tanu's project configuration: one element of the `projects` array in `tanu.toml`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ProjectConfig {
    /// Project name specified by user. Upper-cased, it also forms the `TANU_<NAME>_` prefix
    /// used for project-scoped environment variables.
    pub name: String,
    /// Keys and values specified by user: every key of the project table not consumed by the
    /// other fields (via `#[serde(flatten)]`), plus values merged in from `TANU_*` environment
    /// variables after loading.
    #[serde(flatten)]
    pub data: HashMap<String, TomlValue>,
    /// List of files to ignore in the project. Empty when not specified.
    #[serde(default)]
    pub test_ignore: Vec<String>,
    /// Retry settings for the project; `RetryConfig::default()` when the `retry` key is absent.
    #[serde(default)]
    pub retry: RetryConfig,
}
164
165impl ProjectConfig {
166    pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
167        let key = key.as_ref();
168        self.data
169            .get(key)
170            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
171    }
172
173    pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
174        let key = key.as_ref();
175        self.get(key)?
176            .as_str()
177            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
178    }
179
180    pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
181        self.get_str(key)?
182            .parse()
183            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
184    }
185
186    pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
187        self.get_str(key)?
188            .parse()
189            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
190    }
191
192    pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
193        self.get_str(key)?
194            .parse()
195            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
196    }
197
198    pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
199        self.get_str(key)?
200            .parse::<DateTime<Utc>>()
201            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
202    }
203
204    pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
205        serde_json::from_str(self.get_str(key)?)
206            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
207    }
208
209    pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
210        serde_json::from_str(self.get_str(key)?)
211            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
212    }
213}
214
215#[derive(Debug, Clone, Deserialize)]
216pub struct RetryConfig {
217    /// Number of retries.
218    #[serde(default)]
219    pub count: Option<usize>,
220    /// Factor to multiply the delay between retries.
221    #[serde(default)]
222    pub factor: Option<f32>,
223    /// Whether to add jitter to the delay between retries.
224    #[serde(default)]
225    pub jitter: Option<bool>,
226    /// Minimum delay between retries.
227    #[serde(default)]
228    #[serde(with = "humantime_serde")]
229    pub min_delay: Option<Duration>,
230    /// Maximum delay between retries.
231    #[serde(default)]
232    #[serde(with = "humantime_serde")]
233    pub max_delay: Option<Duration>,
234}
235
236impl Default for RetryConfig {
237    fn default() -> Self {
238        RetryConfig {
239            count: Some(0),
240            factor: Some(2.0),
241            jitter: Some(false),
242            min_delay: Some(Duration::from_secs(1)),
243            max_delay: Some(Duration::from_secs(60)),
244        }
245    }
246}
247
248impl RetryConfig {
249    pub fn backoff(&self) -> backon::ExponentialBuilder {
250        let builder = backon::ExponentialBuilder::new()
251            .with_max_times(self.count.unwrap_or_default())
252            .with_factor(self.factor.unwrap_or(2.0))
253            .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
254            .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
255
256        if self.jitter.unwrap_or_default() {
257            builder.with_jitter()
258        } else {
259            builder
260        }
261    }
262}
263
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{time::Duration, vec};
    use test_case::test_case;

    /// Path of the sample configuration that lives one directory above this crate.
    fn sample_config_path() -> std::path::PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join("../tanu-sample.toml")
    }

    /// Load the sample configuration, including the environment-variable overlay.
    fn load_test_config() -> eyre::Result<Config> {
        let cfg = super::Config::load_from(&sample_config_path())?;
        Ok(cfg)
    }

    /// Load the sample configuration and take its first project.
    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        let mut cfg = load_test_config()?;
        Ok(cfg.projects.remove(0))
    }

    /// Export `key=value` to the process environment, then load the first sample project so
    /// that the variable is picked up by the environment overlay.
    fn project_with_env(key: &str, value: &str) -> eyre::Result<ProjectConfig> {
        std::env::set_var(key, value);
        load_test_project_config()
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let cfg = load_test_config()?;
        assert_eq!(cfg.projects.len(), 1);

        let ProjectConfig {
            name,
            test_ignore,
            retry,
            ..
        } = &cfg.projects[0];
        assert_eq!(name.as_str(), "default");
        assert!(test_ignore.is_empty());
        assert_eq!(retry.count, Some(0));
        assert_eq!(retry.factor, Some(2.0));
        assert_eq!(retry.jitter, Some(false));
        assert_eq!(retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(key: &str) -> eyre::Result<()> {
        let project = project_with_env(key, "example_string")?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(key: &str) -> eyre::Result<()> {
        let project = project_with_env(key, "42")?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        let key = format!("{prefix}_FLOAT_KEY");
        let project = project_with_env(&key, "5.5")?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(key: &str) -> eyre::Result<()> {
        let project = project_with_env(key, "true")?;
        assert!(project.get_bool("bool_key")?);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(key: &str) -> eyre::Result<()> {
        const DATETIME: &str = "2025-03-08T12:00:00Z";
        let project = project_with_env(key, DATETIME)?;
        let parsed = project.get_datetime("datetime_key")?;
        assert_eq!(
            parsed.to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            DATETIME
        );
        Ok(())
    }

    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(key: &str) -> eyre::Result<()> {
        let project = project_with_env(key, "[1, 2, 3]")?;
        let numbers: Vec<i64> = project.get_array("array_key")?;
        assert_eq!(numbers, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }

        let key = format!("{prefix}_OBJECT_KEY");
        let project = project_with_env(&key, "{\"foo\": [\"bar\", \"baz\"]}")?;
        let decoded: Foo = project.get_object("object_key")?;
        assert_eq!(decoded.foo, vec!["bar", "baz"]);
        Ok(())
    }
}