1use chrono::{DateTime, Utc};
40use once_cell::sync::Lazy;
41use serde::{de::DeserializeOwned, Deserialize};
42use std::{collections::HashMap, io::Read, path::Path, time::Duration};
43use toml::Value as TomlValue;
44use tracing::*;
45
46use crate::{Error, Result};
47
48static CONFIG: Lazy<Config> = Lazy::new(|| {
49 let _ = dotenv::dotenv();
50 Config::load().unwrap_or_default()
51});
52
53tokio::task_local! {
54 pub static PROJECT: ProjectConfig;
55}
56
57#[doc(hidden)]
58pub fn get_tanu_config() -> &'static Config {
59 &CONFIG
60}
61
62pub fn get_config() -> ProjectConfig {
65 PROJECT.get()
66}
67
68#[derive(Debug, Clone, Deserialize)]
70pub struct Config {
71 pub projects: Vec<ProjectConfig>,
72 #[serde(default)]
74 pub tui: Tui,
75}
76
77impl Default for Config {
78 fn default() -> Self {
79 Config {
80 projects: vec![ProjectConfig {
81 name: "default".to_string(),
82 ..Default::default()
83 }],
84 tui: Tui::default(),
85 }
86 }
87}
88
89#[derive(Debug, Clone, Default, Deserialize)]
91pub struct Tui {
92 #[serde(default)]
93 pub payload: Payload,
94}
95
96#[derive(Debug, Clone, Default, Deserialize)]
97pub struct Payload {
98 pub color_theme: Option<String>,
100}
101
102impl Config {
103 fn load_from(path: &Path) -> Result<Config> {
105 let Ok(mut file) = std::fs::File::open(path) else {
106 return Ok(Config::default());
107 };
108
109 let mut buf = String::new();
110 file.read_to_string(&mut buf)
111 .map_err(|e| Error::LoadError(e.to_string()))?;
112 let mut cfg: Config = toml::from_str(&buf).map_err(|e| {
113 Error::LoadError(format!(
114 "failed to deserialize tanu.toml into tanu::Config: {e}"
115 ))
116 })?;
117
118 debug!("tanu.toml was successfully loaded: {cfg:#?}");
119
120 cfg.load_env();
121
122 Ok(cfg)
123 }
124
125 fn load() -> Result<Config> {
127 Config::load_from(Path::new("tanu.toml"))
128 }
129
130 fn load_env(&mut self) {
140 static PREFIX: &str = "TANU";
141
142 let global_prefix = format!("{PREFIX}_");
143 let project_prefixes: Vec<_> = self
144 .projects
145 .iter()
146 .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
147 .collect();
148 debug!("Loading global configuration from env");
149 let global_vars: HashMap<_, _> = std::env::vars()
150 .filter_map(|(k, v)| {
151 let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
152 if is_project_var {
153 return None;
154 }
155
156 k.find(&global_prefix)?;
157 Some((
158 k[global_prefix.len()..].to_string().to_lowercase(),
159 TomlValue::String(v),
160 ))
161 })
162 .collect();
163
164 debug!("Loading project configuration from env");
165 for project in &mut self.projects {
166 let project_prefix = format!("{PREFIX}_{}_", project.name.to_uppercase());
167 let vars: HashMap<_, _> = std::env::vars()
168 .filter_map(|(k, v)| {
169 k.find(&project_prefix)?;
170 Some((
171 k[project_prefix.len()..].to_string().to_lowercase(),
172 TomlValue::String(v),
173 ))
174 })
175 .collect();
176 project.data.extend(vars);
177 project.data.extend(global_vars.clone());
178 }
179
180 debug!("tanu configuration loaded from env: {self:#?}");
181 }
182
183 pub fn color_theme(&self) -> Option<&str> {
185 self.tui.payload.color_theme.as_deref()
186 }
187}
188
189#[derive(Debug, Clone, Default, Deserialize)]
191pub struct ProjectConfig {
192 pub name: String,
194 #[serde(flatten)]
196 pub data: HashMap<String, TomlValue>,
197 #[serde(default)]
199 pub test_ignore: Vec<String>,
200 #[serde(default)]
201 pub retry: RetryConfig,
202}
203
204impl ProjectConfig {
205 pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
206 let key = key.as_ref();
207 self.data
208 .get(key)
209 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
210 }
211
212 pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
213 let key = key.as_ref();
214 self.get(key)?
215 .as_str()
216 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
217 }
218
219 pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
220 self.get_str(key)?
221 .parse()
222 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
223 }
224
225 pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
226 self.get_str(key)?
227 .parse()
228 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
229 }
230
231 pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
232 self.get_str(key)?
233 .parse()
234 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
235 }
236
237 pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
238 self.get_str(key)?
239 .parse::<DateTime<Utc>>()
240 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
241 }
242
243 pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
244 serde_json::from_str(self.get_str(key)?)
245 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
246 }
247
248 pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
249 serde_json::from_str(self.get_str(key)?)
250 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
251 }
252}
253
254#[derive(Debug, Clone, Deserialize)]
255pub struct RetryConfig {
256 #[serde(default)]
258 pub count: Option<usize>,
259 #[serde(default)]
261 pub factor: Option<f32>,
262 #[serde(default)]
264 pub jitter: Option<bool>,
265 #[serde(default)]
267 #[serde(with = "humantime_serde")]
268 pub min_delay: Option<Duration>,
269 #[serde(default)]
271 #[serde(with = "humantime_serde")]
272 pub max_delay: Option<Duration>,
273}
274
275impl Default for RetryConfig {
276 fn default() -> Self {
277 RetryConfig {
278 count: Some(0),
279 factor: Some(2.0),
280 jitter: Some(false),
281 min_delay: Some(Duration::from_secs(1)),
282 max_delay: Some(Duration::from_secs(60)),
283 }
284 }
285}
286
287impl RetryConfig {
288 pub fn backoff(&self) -> backon::ExponentialBuilder {
289 let builder = backon::ExponentialBuilder::new()
290 .with_max_times(self.count.unwrap_or_default())
291 .with_factor(self.factor.unwrap_or(2.0))
292 .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
293 .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
294
295 if self.jitter.unwrap_or_default() {
296 builder.with_jitter()
297 } else {
298 builder
299 }
300 }
301}
302
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use std::time::Duration;
    use test_case::test_case;

    /// Serializes every test in this module that sets or reads environment
    /// variables.
    ///
    /// The test harness runs tests on parallel threads. `std::env::set_var`
    /// racing with `std::env::vars` (called by `Config::load_env`) is not
    /// thread-safe on every platform.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`]. A poisoned lock is recovered because it guards no
    /// data: a failed test must not cascade into unrelated failures.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Loads the sample configuration shipped next to the crate.
    /// Callers must hold [`ENV_LOCK`].
    fn load_test_config() -> eyre::Result<Config> {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let config_path = Path::new(manifest_dir).join("../tanu-sample.toml");
        Ok(Config::load_from(&config_path)?)
    }

    /// Sets `var` to `value`, then loads the first project of the sample
    /// configuration, holding [`ENV_LOCK`] across both steps.
    fn load_project_with_env(var: &str, value: &str) -> eyre::Result<ProjectConfig> {
        let _env = lock_env();
        std::env::set_var(var, value);
        Ok(load_test_config()?.projects.remove(0))
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let cfg = {
            let _env = lock_env();
            load_test_config()?
        };
        assert_eq!(cfg.projects.len(), 1);

        let project = &cfg.projects[0];
        assert_eq!(project.name, "default");
        assert_eq!(project.test_ignore, Vec::<String>::new());
        assert_eq!(project.retry.count, Some(0));
        assert_eq!(project.retry.factor, Some(2.0));
        assert_eq!(project.retry.jitter, Some(false));
        assert_eq!(project.retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(project.retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(key: &str) -> eyre::Result<()> {
        let project = load_project_with_env(key, "example_string")?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(key: &str) -> eyre::Result<()> {
        let project = load_project_with_env(key, "42")?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        let project = load_project_with_env(&format!("{prefix}_FLOAT_KEY"), "5.5")?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(key: &str) -> eyre::Result<()> {
        let project = load_project_with_env(key, "true")?;
        assert!(project.get_bool("bool_key")?);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(key: &str) -> eyre::Result<()> {
        let datetime_str = "2025-03-08T12:00:00Z";
        let project = load_project_with_env(key, datetime_str)?;
        assert_eq!(
            project
                .get_datetime("datetime_key")?
                .to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            datetime_str
        );
        Ok(())
    }

    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(key: &str) -> eyre::Result<()> {
        let project = load_project_with_env(key, "[1, 2, 3]")?;
        let array: Vec<i64> = project.get_array("array_key")?;
        assert_eq!(array, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }
        let project = load_project_with_env(
            &format!("{prefix}_OBJECT_KEY"),
            "{\"foo\": [\"bar\", \"baz\"]}",
        )?;
        let obj: Foo = project.get_object("object_key")?;
        assert_eq!(obj.foo, vec!["bar", "baz"]);
        Ok(())
    }
}