1use chrono::{DateTime, Utc};
2use once_cell::sync::Lazy;
3use serde::{de::DeserializeOwned, Deserialize};
4use std::{collections::HashMap, io::Read, path::Path, time::Duration};
5use toml::Value as TomlValue;
6use tracing::*;
7
8use crate::{Error, Result};
9
10static CONFIG: Lazy<Config> = Lazy::new(|| {
11 let _ = dotenv::dotenv();
12 Config::load().unwrap_or_default()
13});
14
tokio::task_local! {
    /// Configuration of the project the current tokio task is working on.
    ///
    /// It has to be provided with `PROJECT.scope(..)` / `PROJECT.sync_scope(..)`
    /// before [`get_config`] is called; tokio panics when a task-local is read
    /// outside such a scope.
    pub static PROJECT: ProjectConfig;
}
18
19pub fn get_tanu_config() -> &'static Config {
20 &CONFIG
21}
22
23pub fn get_config() -> ProjectConfig {
26 PROJECT.get()
27}
28
/// Top-level contents of `tanu.toml`.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Per-project settings, read from the `projects` key.
    pub projects: Vec<ProjectConfig>,
    /// Terminal UI settings (`tui` key); `Tui::default()` when the key is absent.
    #[serde(default)]
    pub tui: Tui,
}
37
38impl Default for Config {
39 fn default() -> Self {
40 Config {
41 projects: vec![ProjectConfig {
42 name: "default".to_string(),
43 ..Default::default()
44 }],
45 tui: Tui::default(),
46 }
47 }
48}
49
/// Settings found under the `tui` key of `tanu.toml`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Tui {
    /// Settings under `tui.payload`; defaulted when absent.
    #[serde(default)]
    pub payload: Payload,
}
56
/// Settings found under `tui.payload` in `tanu.toml`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Payload {
    /// Name of the color theme for payload rendering; `None` when not configured.
    /// NOTE(review): the set of accepted theme names is defined by the TUI, not here.
    pub color_theme: Option<String>,
}
62
63impl Config {
64 fn load_from(path: &Path) -> Result<Config> {
66 let Ok(mut file) = std::fs::File::open(path) else {
67 return Ok(Config::default());
68 };
69
70 let mut buf = String::new();
71 file.read_to_string(&mut buf)
72 .map_err(|e| Error::LoadError(e.to_string()))?;
73 let mut cfg: Config = toml::from_str(&buf).map_err(|e| {
74 Error::LoadError(format!(
75 "failed to deserialize tanu.toml into tanu::Config: {e}"
76 ))
77 })?;
78
79 debug!("tanu.toml was successfully loaded: {cfg:#?}");
80
81 cfg.load_env();
82
83 Ok(cfg)
84 }
85
86 fn load() -> Result<Config> {
88 Config::load_from(Path::new("tanu.toml"))
89 }
90
91 fn load_env(&mut self) {
101 static PREFIX: &str = "TANU";
102
103 let global_prefix = format!("{PREFIX}_");
104 let project_prefixes: Vec<_> = self
105 .projects
106 .iter()
107 .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
108 .collect();
109 debug!("Loading global configuration from env");
110 let global_vars: HashMap<_, _> = std::env::vars()
111 .filter_map(|(k, v)| {
112 let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
113 if is_project_var {
114 return None;
115 }
116
117 k.find(&global_prefix)?;
118 Some((
119 k[global_prefix.len()..].to_string().to_lowercase(),
120 TomlValue::String(v),
121 ))
122 })
123 .collect();
124
125 debug!("Loading project configuration from env");
126 for project in &mut self.projects {
127 let project_prefix = format!("{PREFIX}_{}_", project.name.to_uppercase());
128 let vars: HashMap<_, _> = std::env::vars()
129 .filter_map(|(k, v)| {
130 k.find(&project_prefix)?;
131 Some((
132 k[project_prefix.len()..].to_string().to_lowercase(),
133 TomlValue::String(v),
134 ))
135 })
136 .collect();
137 project.data.extend(vars);
138 project.data.extend(global_vars.clone());
139 }
140
141 debug!("tanu configuration loaded from env: {self:#?}");
142 }
143
144 pub fn color_theme(&self) -> Option<&str> {
146 self.tui.payload.color_theme.as_deref()
147 }
148}
149
/// Settings for a single project.
///
/// Besides the named fields, every other key of the project's entry is collected into
/// [`ProjectConfig::data`]. `Config::load_env` also merges `TANU_*` environment values
/// into that map.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ProjectConfig {
    /// Project name; its upper-cased form selects `TANU_<NAME>_*` environment variables.
    pub name: String,
    /// All keys that are not one of the named fields (`#[serde(flatten)]`), plus values
    /// merged from the environment.
    #[serde(flatten)]
    pub data: HashMap<String, TomlValue>,
    /// Tests to ignore; empty when not configured.
    /// NOTE(review): how entries are matched against tests is decided by the runner.
    #[serde(default)]
    pub test_ignore: Vec<String>,
    /// Retry policy; `RetryConfig::default()` when the key is absent.
    #[serde(default)]
    pub retry: RetryConfig,
}
164
165impl ProjectConfig {
166 pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
167 let key = key.as_ref();
168 self.data
169 .get(key)
170 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
171 }
172
173 pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
174 let key = key.as_ref();
175 self.get(key)?
176 .as_str()
177 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
178 }
179
180 pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
181 self.get_str(key)?
182 .parse()
183 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
184 }
185
186 pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
187 self.get_str(key)?
188 .parse()
189 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
190 }
191
192 pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
193 self.get_str(key)?
194 .parse()
195 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
196 }
197
198 pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
199 self.get_str(key)?
200 .parse::<DateTime<Utc>>()
201 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
202 }
203
204 pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
205 serde_json::from_str(self.get_str(key)?)
206 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
207 }
208
209 pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
210 serde_json::from_str(self.get_str(key)?)
211 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
212 }
213}
214
/// Retry policy for a project, converted into a backoff strategy by
/// [`RetryConfig::backoff`].
///
/// Every field is optional in `tanu.toml`. An omitted field deserializes to `None`
/// (field-level `#[serde(default)]`), not to the value from `RetryConfig::default()`.
/// `backoff` substitutes its own fallback for each `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retries, passed to `with_max_times` (`None` → 0).
    #[serde(default)]
    pub count: Option<usize>,
    /// Exponential factor, passed to `with_factor` (`None` → 2.0).
    #[serde(default)]
    pub factor: Option<f32>,
    /// Whether `with_jitter` is applied (`None` → no jitter).
    #[serde(default)]
    pub jitter: Option<bool>,
    /// Minimum delay in humantime notation such as `"1s"` (`None` → 1 second).
    #[serde(default)]
    #[serde(with = "humantime_serde")]
    pub min_delay: Option<Duration>,
    /// Maximum delay in humantime notation such as `"60s"` (`None` → 60 seconds).
    #[serde(default)]
    #[serde(with = "humantime_serde")]
    pub max_delay: Option<Duration>,
}
235
236impl Default for RetryConfig {
237 fn default() -> Self {
238 RetryConfig {
239 count: Some(0),
240 factor: Some(2.0),
241 jitter: Some(false),
242 min_delay: Some(Duration::from_secs(1)),
243 max_delay: Some(Duration::from_secs(60)),
244 }
245 }
246}
247
248impl RetryConfig {
249 pub fn backoff(&self) -> backon::ExponentialBuilder {
250 let builder = backon::ExponentialBuilder::new()
251 .with_max_times(self.count.unwrap_or_default())
252 .with_factor(self.factor.unwrap_or(2.0))
253 .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
254 .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
255
256 if self.jitter.unwrap_or_default() {
257 builder.with_jitter()
258 } else {
259 builder
260 }
261 }
262}
263
// Unit tests for configuration loading and the typed `ProjectConfig` getters.
//
// NOTE(review): these tests call `std::env::set_var` and may run in parallel in one
// process. They stay consistent because both cases of each test write the same value
// (under a project-scoped and a global variable name), so a race cannot change the
// value a getter observes.
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{time::Duration, vec};
    use test_case::test_case;

    /// Loads `tanu-sample.toml` from the parent of this crate's manifest directory via
    /// `Config::load_from`. The global `CONFIG` is bypassed, so environment variables
    /// set by the calling test are picked up on every call.
    fn load_test_config() -> eyre::Result<Config> {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let config_path = Path::new(manifest_dir).join("../tanu-sample.toml");
        Ok(super::Config::load_from(&config_path)?)
    }

    /// Returns the first project of the sample configuration.
    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        Ok(load_test_config()?.projects.remove(0))
    }

    /// The sample file yields one project named "default" with default retry settings.
    #[test]
    fn load_config() -> eyre::Result<()> {
        let cfg = load_test_config()?;
        assert_eq!(cfg.projects.len(), 1);

        let project = &cfg.projects[0];
        assert_eq!(project.name, "default");
        assert_eq!(project.test_ignore, Vec::<String>::new());
        assert_eq!(project.retry.count, Some(0));
        assert_eq!(project.retry.factor, Some(2.0));
        assert_eq!(project.retry.jitter, Some(false));
        assert_eq!(project.retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(project.retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    // Each getter test runs twice: once with the project-scoped variable
    // (`TANU_DEFAULT_*`) and once with the global one (`TANU_*`).

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "example_string");
        let project = load_test_project_config()?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "42");
        let project = load_test_project_config()?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    // Takes the prefix only; the `_FLOAT_KEY` suffix is appended below.
    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        std::env::set_var(format!("{prefix}_FLOAT_KEY"), "5.5");
        let project = load_test_project_config()?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "true");
        let project = load_test_project_config()?;
        assert_eq!(project.get_bool("bool_key")?, true);
        Ok(())
    }

    // Round-trips the timestamp through RFC 3339 formatting (seconds precision,
    // `Z` suffix) to compare against the original string.
    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(key: &str) -> eyre::Result<()> {
        let datetime_str = "2025-03-08T12:00:00Z";
        std::env::set_var(key, datetime_str);
        let project = load_test_project_config()?;
        assert_eq!(
            project
                .get_datetime("datetime_key")?
                .to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            datetime_str
        );
        Ok(())
    }

    // The environment value is a JSON array.
    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "[1, 2, 3]");
        let project = load_test_project_config()?;
        let array: Vec<i64> = project.get_array("array_key")?;
        assert_eq!(array, vec![1, 2, 3]);
        Ok(())
    }

    // The environment value is a JSON object deserialized into a local struct.
    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }
        std::env::set_var(
            format!("{prefix}_OBJECT_KEY"),
            "{\"foo\": [\"bar\", \"baz\"]}",
        );
        let project = load_test_project_config()?;
        let obj: Foo = project.get_object("object_key")?;
        assert_eq!(obj.foo, vec!["bar", "baz"]);
        Ok(())
    }
}