tanu_core/
config.rs

1use chrono::{DateTime, Utc};
2use once_cell::sync::Lazy;
3use serde::{de::DeserializeOwned, Deserialize};
4use std::{collections::HashMap, io::Read, path::Path, time::Duration};
5use toml::Value as TomlValue;
6use tracing::*;
7
8use crate::{Error, Result};
9
10static CONFIG: Lazy<Config> = Lazy::new(|| {
11    let _ = dotenv::dotenv();
12    Config::load().expect("failed to load tanu.toml")
13});
14
15tokio::task_local! {
16    pub static PROJECT: ProjectConfig;
17}
18
19pub fn get_tanu_config() -> &'static Config {
20    &CONFIG
21}
22
23/// Get configuration for the current project. This function has to be called in the tokio
24/// task created by tanu runner. Otherwise, calling this function will panic.
25pub fn get_config() -> ProjectConfig {
26    PROJECT.get()
27}
28
29/// tanu's configuration.
30#[derive(Debug, Clone, Default, Deserialize)]
31pub struct Config {
32    pub projects: Vec<ProjectConfig>,
33}
34
35impl Config {
36    /// Load tanu configuration from path.
37    fn load_from(path: &Path) -> Result<Config> {
38        let Ok(mut file) = std::fs::File::open(path) else {
39            return Ok(Config::default());
40        };
41
42        let mut buf = String::new();
43        file.read_to_string(&mut buf)
44            .map_err(|e| Error::LoadError(e.to_string()))?;
45        let mut cfg: Config = toml::from_str(&buf).map_err(|e| {
46            Error::LoadError(format!(
47                "failed to deserialize tanu.toml into tanu::Config: {e}"
48            ))
49        })?;
50
51        debug!("tanu.toml was successfully loaded: {cfg:#?}");
52
53        cfg.load_env();
54
55        Ok(cfg)
56    }
57
58    /// Load tanu configuration from tanu.toml in the current directory.
59    fn load() -> Result<Config> {
60        Config::load_from(Path::new("tanu.toml"))
61    }
62
63    /// Load tanu configuration from environment variables.
64    ///
65    /// Global environment variables: tanu automatically detects environment variables prefixed
66    /// with tanu_XXX and maps them to the corresponding configuration variable as "xxx". This
67    /// global configuration can be accessed in any project.
68    ///
69    /// Project environment variables: tanu automatically detects environment variables prefixed
70    /// with tanu_PROJECT_ZZZ_XXX and maps them to the corresponding configuration variable as
71    /// "xxx" for project "ZZZ". This configuration is isolated within the project.
72    fn load_env(&mut self) {
73        static PREFIX: &str = "TANU";
74
75        let global_prefix = format!("{PREFIX}_");
76        let project_prefixes: Vec<_> = self
77            .projects
78            .iter()
79            .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
80            .collect();
81        debug!("Loading global configuration from env");
82        let global_vars: HashMap<_, _> = std::env::vars()
83            .filter_map(|(k, v)| {
84                let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
85                if is_project_var {
86                    return None;
87                }
88
89                k.find(&global_prefix)?;
90                Some((
91                    k[global_prefix.len()..].to_string().to_lowercase(),
92                    TomlValue::String(v),
93                ))
94            })
95            .collect();
96
97        debug!("Loading project configuration from env");
98        for project in &mut self.projects {
99            let project_prefix = format!("{PREFIX}_{}_", project.name.to_uppercase());
100            let vars: HashMap<_, _> = std::env::vars()
101                .filter_map(|(k, v)| {
102                    k.find(&project_prefix)?;
103                    Some((
104                        k[project_prefix.len()..].to_string().to_lowercase(),
105                        TomlValue::String(v),
106                    ))
107                })
108                .collect();
109            project.data.extend(vars);
110            project.data.extend(global_vars.clone());
111        }
112
113        debug!("tanu configuration loaded from env: {self:#?}");
114    }
115}
116
117/// tanu's project configuration.
118#[derive(Debug, Clone, Default, Deserialize)]
119pub struct ProjectConfig {
120    /// Project name specified by user.
121    pub name: String,
122    /// Keys and values specified by user.
123    #[serde(flatten)]
124    pub data: HashMap<String, TomlValue>,
125    /// List of files to ignore in the project.
126    #[serde(default)]
127    pub test_ignore: Vec<String>,
128    #[serde(default)]
129    pub retry: RetryConfig,
130}
131
132impl ProjectConfig {
133    pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
134        let key = key.as_ref();
135        self.data
136            .get(key)
137            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
138    }
139
140    pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
141        let key = key.as_ref();
142        self.get(key)?
143            .as_str()
144            .ok_or_else(|| Error::ValueNotFound(key.to_string()))
145    }
146
147    pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
148        self.get_str(key)?
149            .parse()
150            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
151    }
152
153    pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
154        self.get_str(key)?
155            .parse()
156            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
157    }
158
159    pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
160        self.get_str(key)?
161            .parse()
162            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
163    }
164
165    pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
166        self.get_str(key)?
167            .parse::<DateTime<Utc>>()
168            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
169    }
170
171    pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
172        serde_json::from_str(self.get_str(key)?)
173            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
174    }
175
176    pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
177        serde_json::from_str(self.get_str(key)?)
178            .map_err(|e| Error::ValueError(eyre::Error::from(e)))
179    }
180}
181
182#[derive(Debug, Clone, Deserialize)]
183pub struct RetryConfig {
184    /// Number of retries.
185    pub count: Option<usize>,
186    /// Factor to multiply the delay between retries.
187    pub factor: Option<f32>,
188    /// Whether to add jitter to the delay between retries.
189    pub jitter: Option<bool>,
190    /// Minimum delay between retries.
191    #[serde(with = "humantime_serde")]
192    pub min_delay: Option<Duration>,
193    /// Maximum delay between retries.
194    #[serde(with = "humantime_serde")]
195    pub max_delay: Option<Duration>,
196}
197
198impl Default for RetryConfig {
199    fn default() -> Self {
200        RetryConfig {
201            count: Some(0),
202            factor: Some(2.0),
203            jitter: Some(false),
204            min_delay: Some(Duration::from_secs(1)),
205            max_delay: Some(Duration::from_secs(60)),
206        }
207    }
208}
209
210impl RetryConfig {
211    pub fn backoff(&self) -> backon::ExponentialBuilder {
212        let builder = backon::ExponentialBuilder::new()
213            .with_max_times(self.count.unwrap_or_default())
214            .with_factor(self.factor.unwrap_or(2.0))
215            .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
216            .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
217
218        if self.jitter.unwrap_or_default() {
219            builder.with_jitter()
220        } else {
221            builder
222        }
223    }
224}
225
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{time::Duration, vec};
    use test_case::test_case;

    /// Loads the workspace's `tanu-sample.toml`, including environment overrides.
    fn load_test_config() -> eyre::Result<Config> {
        let sample_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../tanu-sample.toml");
        let cfg = super::Config::load_from(&sample_path)?;
        Ok(cfg)
    }

    /// Returns the first project of the sample configuration.
    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        let mut cfg = load_test_config()?;
        Ok(cfg.projects.remove(0))
    }

    /// Exports `key=value` to the process environment, then reloads the first sample project.
    fn project_with_env(key: &str, value: &str) -> eyre::Result<ProjectConfig> {
        std::env::set_var(key, value);
        load_test_project_config()
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let cfg = load_test_config()?;
        assert_eq!(cfg.projects.len(), 1);

        let project = &cfg.projects[0];
        assert_eq!(project.name, "default");
        assert!(project.test_ignore.is_empty());

        let retry = &project.retry;
        assert_eq!(
            (retry.count, retry.factor, retry.jitter),
            (Some(0), Some(2.0), Some(false))
        );
        assert_eq!(
            (retry.min_delay, retry.max_delay),
            (Some(Duration::from_secs(1)), Some(Duration::from_secs(60)))
        );

        Ok(())
    }

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(key: &str) -> eyre::Result<()> {
        let project = project_with_env(key, "example_string")?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(key: &str) -> eyre::Result<()> {
        let project = project_with_env(key, "42")?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        let project = project_with_env(&format!("{prefix}_FLOAT_KEY"), "5.5")?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(key: &str) -> eyre::Result<()> {
        let project = project_with_env(key, "true")?;
        assert!(project.get_bool("bool_key")?);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(key: &str) -> eyre::Result<()> {
        const DATETIME: &str = "2025-03-08T12:00:00Z";
        let project = project_with_env(key, DATETIME)?;
        let parsed = project.get_datetime("datetime_key")?;
        assert_eq!(
            parsed.to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            DATETIME
        );
        Ok(())
    }

    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(key: &str) -> eyre::Result<()> {
        let array: Vec<i64> = project_with_env(key, "[1, 2, 3]")?.get_array("array_key")?;
        assert_eq!(array, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }

        let project = project_with_env(
            &format!("{prefix}_OBJECT_KEY"),
            "{\"foo\": [\"bar\", \"baz\"]}",
        )?;
        let obj: Foo = project.get_object("object_key")?;
        assert_eq!(obj.foo, vec!["bar", "baz"]);
        Ok(())
    }
}