1use chrono::{DateTime, Utc};
40use once_cell::sync::Lazy;
41use serde::{de::DeserializeOwned, Deserialize};
42use std::{collections::HashMap, io::Read, path::Path, sync::Arc, time::Duration};
43use toml::Value as TomlValue;
44use tracing::*;
45
46use crate::{Error, Result};
47
48static CONFIG: Lazy<Config> = Lazy::new(|| {
49 let _ = dotenv::dotenv();
50 Config::load().unwrap_or_default()
51});
52
53tokio::task_local! {
54 pub static PROJECT: Arc<ProjectConfig>;
55}
56
57#[doc(hidden)]
58pub fn get_tanu_config() -> &'static Config {
59 &CONFIG
60}
61
62pub fn get_config() -> Arc<ProjectConfig> {
65 PROJECT.get()
66}
67
68#[derive(Debug, Clone)]
70pub struct Config {
71 pub projects: Vec<Arc<ProjectConfig>>,
72 pub tui: Tui,
74}
75
76impl Default for Config {
77 fn default() -> Self {
78 Config {
79 projects: vec![Arc::new(ProjectConfig {
80 name: "default".to_string(),
81 ..Default::default()
82 })],
83 tui: Tui::default(),
84 }
85 }
86}
87
88#[derive(Debug, Clone, Default, Deserialize)]
90pub struct Tui {
91 #[serde(default)]
92 pub payload: Payload,
93}
94
95#[derive(Debug, Clone, Default, Deserialize)]
96pub struct Payload {
97 pub color_theme: Option<String>,
99}
100
101impl Config {
102 fn load_from(path: &Path) -> Result<Config> {
104 let Ok(mut file) = std::fs::File::open(path) else {
105 return Ok(Config::default());
106 };
107
108 let mut buf = String::new();
109 file.read_to_string(&mut buf)
110 .map_err(|e| Error::LoadError(e.to_string()))?;
111
112 #[derive(Deserialize)]
113 struct ConfigHelper {
114 #[serde(default)]
115 projects: Vec<ProjectConfig>,
116 #[serde(default)]
117 tui: Tui,
118 }
119
120 let helper: ConfigHelper = toml::from_str(&buf).map_err(|e| {
121 Error::LoadError(format!(
122 "failed to deserialize tanu.toml into tanu::Config: {e}"
123 ))
124 })?;
125
126 let mut cfg = Config {
127 projects: helper.projects.into_iter().map(Arc::new).collect(),
128 tui: helper.tui,
129 };
130
131 debug!("tanu.toml was successfully loaded: {cfg:#?}");
132
133 cfg.load_env();
134
135 Ok(cfg)
136 }
137
138 fn load() -> Result<Config> {
140 Config::load_from(Path::new("tanu.toml"))
141 }
142
143 fn load_env(&mut self) {
153 static PREFIX: &str = "TANU";
154
155 let global_prefix = format!("{PREFIX}_");
156 let project_prefixes: Vec<_> = self
157 .projects
158 .iter()
159 .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
160 .collect();
161 debug!("Loading global configuration from env");
162 let global_vars: HashMap<_, _> = std::env::vars()
163 .filter_map(|(k, v)| {
164 let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
165 if is_project_var {
166 return None;
167 }
168
169 k.find(&global_prefix)?;
170 Some((
171 k[global_prefix.len()..].to_string().to_lowercase(),
172 TomlValue::String(v),
173 ))
174 })
175 .collect();
176
177 debug!("Loading project configuration from env");
178 for project_arc in &mut self.projects {
179 let project_prefix = format!("{PREFIX}_{}_", project_arc.name.to_uppercase());
180 let vars: HashMap<_, _> = std::env::vars()
181 .filter_map(|(k, v)| {
182 k.find(&project_prefix)?;
183 Some((
184 k[project_prefix.len()..].to_string().to_lowercase(),
185 TomlValue::String(v),
186 ))
187 })
188 .collect();
189 let project = Arc::make_mut(project_arc);
190 project.data.extend(vars);
191 project.data.extend(global_vars.clone());
192 }
193
194 debug!("tanu configuration loaded from env: {self:#?}");
195 }
196
197 pub fn color_theme(&self) -> Option<&str> {
199 self.tui.payload.color_theme.as_deref()
200 }
201}
202
203#[derive(Debug, Clone, Default, Deserialize)]
205pub struct ProjectConfig {
206 pub name: String,
208 #[serde(flatten)]
210 pub data: HashMap<String, TomlValue>,
211 #[serde(default)]
213 pub test_ignore: Vec<String>,
214 #[serde(default)]
215 pub retry: RetryConfig,
216}
217
218impl ProjectConfig {
219 pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
220 let key = key.as_ref();
221 self.data
222 .get(key)
223 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
224 }
225
226 pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
227 let key = key.as_ref();
228 self.get(key)?
229 .as_str()
230 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
231 }
232
233 pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
234 self.get_str(key)?
235 .parse()
236 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
237 }
238
239 pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
240 self.get_str(key)?
241 .parse()
242 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
243 }
244
245 pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
246 self.get_str(key)?
247 .parse()
248 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
249 }
250
251 pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
252 self.get_str(key)?
253 .parse::<DateTime<Utc>>()
254 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
255 }
256
257 pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
258 serde_json::from_str(self.get_str(key)?)
259 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
260 }
261
262 pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
263 serde_json::from_str(self.get_str(key)?)
264 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
265 }
266}
267
268#[derive(Debug, Clone, Deserialize)]
269pub struct RetryConfig {
270 #[serde(default)]
272 pub count: Option<usize>,
273 #[serde(default)]
275 pub factor: Option<f32>,
276 #[serde(default)]
278 pub jitter: Option<bool>,
279 #[serde(default)]
281 #[serde(with = "humantime_serde")]
282 pub min_delay: Option<Duration>,
283 #[serde(default)]
285 #[serde(with = "humantime_serde")]
286 pub max_delay: Option<Duration>,
287}
288
289impl Default for RetryConfig {
290 fn default() -> Self {
291 RetryConfig {
292 count: Some(0),
293 factor: Some(2.0),
294 jitter: Some(false),
295 min_delay: Some(Duration::from_secs(1)),
296 max_delay: Some(Duration::from_secs(60)),
297 }
298 }
299}
300
301impl RetryConfig {
302 pub fn backoff(&self) -> backon::ExponentialBuilder {
303 let builder = backon::ExponentialBuilder::new()
304 .with_max_times(self.count.unwrap_or_default())
305 .with_factor(self.factor.unwrap_or(2.0))
306 .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
307 .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
308
309 if self.jitter.unwrap_or_default() {
310 builder.with_jitter()
311 } else {
312 builder
313 }
314 }
315}
316
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{time::Duration, vec};
    use test_case::test_case;

    /// Loads `tanu-sample.toml`, which sits one directory above this crate.
    fn load_test_config() -> eyre::Result<Config> {
        let sample_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("../tanu-sample.toml");
        let config = super::Config::load_from(&sample_path)?;
        Ok(config)
    }

    /// Takes ownership of the first project of the sample configuration.
    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        let mut config = load_test_config()?;
        let first = config.projects.remove(0);
        Ok(Arc::try_unwrap(first).unwrap())
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let config = load_test_config()?;
        assert_eq!(config.projects.len(), 1);

        let default_project = &config.projects[0];
        let retry = &default_project.retry;
        assert_eq!(default_project.name, "default");
        assert_eq!(default_project.test_ignore, Vec::<String>::new());
        assert_eq!(retry.count, Some(0));
        assert_eq!(retry.factor, Some(2.0));
        assert_eq!(retry.jitter, Some(false));
        assert_eq!(retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "example_string");
        let cfg = load_test_project_config()?;
        let actual = cfg.get_str("str_key")?;
        assert_eq!(actual, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "42");
        let actual = load_test_project_config()?.get_int("int_key")?;
        assert_eq!(actual, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        let var_name = format!("{prefix}_FLOAT_KEY");
        std::env::set_var(var_name, "5.5");
        let actual = load_test_project_config()?.get_float("float_key")?;
        assert_eq!(actual, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "true");
        let actual = load_test_project_config()?.get_bool("bool_key")?;
        assert!(actual);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(key: &str) -> eyre::Result<()> {
        const EXPECTED: &str = "2025-03-08T12:00:00Z";
        std::env::set_var(key, EXPECTED);
        let parsed = load_test_project_config()?.get_datetime("datetime_key")?;
        let rendered = parsed.to_rfc3339_opts(chrono::SecondsFormat::Secs, true);
        assert_eq!(rendered, EXPECTED);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(key: &str) -> eyre::Result<()> {
        std::env::set_var(key, "[1, 2, 3]");
        let values = load_test_project_config()?.get_array::<i64>("array_key")?;
        assert_eq!(values, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }

        let var_name = format!("{prefix}_OBJECT_KEY");
        std::env::set_var(var_name, "{\"foo\": [\"bar\", \"baz\"]}");
        let parsed = load_test_project_config()?.get_object::<Foo>("object_key")?;
        assert_eq!(parsed.foo, vec!["bar", "baz"]);
        Ok(())
    }
}