1use chrono::{DateTime, Utc};
2use once_cell::sync::Lazy;
3use serde::{de::DeserializeOwned, Deserialize};
4use std::{collections::HashMap, io::Read, path::Path, time::Duration};
5use toml::Value as TomlValue;
6use tracing::*;
7
8use crate::{Error, Result};
9
10static CONFIG: Lazy<Config> = Lazy::new(|| {
11 let _ = dotenv::dotenv();
12 Config::load().expect("failed to load tanu.toml")
13});
14
tokio::task_local! {
    /// Configuration of the project that the current tokio task runs for.
    ///
    /// Read through `get_config`, which clones the value; accessing it in a task where it
    /// has not been set panics (tokio `LocalKey` semantics).
    pub static PROJECT: ProjectConfig;
}
18
19pub fn get_tanu_config() -> &'static Config {
20 &CONFIG
21}
22
23pub fn get_config() -> ProjectConfig {
26 PROJECT.get()
27}
28
/// Top-level contents of `tanu.toml`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct Config {
    /// One entry per `[[projects]]` table. The key is required when the file exists:
    /// without `#[serde(default)]`, deserialization fails if it is missing.
    pub projects: Vec<ProjectConfig>,
}
34
35impl Config {
36 fn load_from(path: &Path) -> Result<Config> {
38 let Ok(mut file) = std::fs::File::open(path) else {
39 return Ok(Config::default());
40 };
41
42 let mut buf = String::new();
43 file.read_to_string(&mut buf)
44 .map_err(|e| Error::LoadError(e.to_string()))?;
45 let mut cfg: Config = toml::from_str(&buf).map_err(|e| {
46 Error::LoadError(format!(
47 "failed to deserialize tanu.toml into tanu::Config: {e}"
48 ))
49 })?;
50
51 debug!("tanu.toml was successfully loaded: {cfg:#?}");
52
53 cfg.load_env();
54
55 Ok(cfg)
56 }
57
58 fn load() -> Result<Config> {
60 Config::load_from(Path::new("tanu.toml"))
61 }
62
63 fn load_env(&mut self) {
73 static PREFIX: &str = "TANU";
74
75 let global_prefix = format!("{PREFIX}_");
76 let project_prefixes: Vec<_> = self
77 .projects
78 .iter()
79 .map(|p| format!("{PREFIX}_{}_", p.name.to_uppercase()))
80 .collect();
81 debug!("Loading global configuration from env");
82 let global_vars: HashMap<_, _> = std::env::vars()
83 .filter_map(|(k, v)| {
84 let is_project_var = project_prefixes.iter().any(|pp| k.contains(pp));
85 if is_project_var {
86 return None;
87 }
88
89 k.find(&global_prefix)?;
90 Some((
91 k[global_prefix.len()..].to_string().to_lowercase(),
92 TomlValue::String(v),
93 ))
94 })
95 .collect();
96
97 debug!("Loading project configuration from env");
98 for project in &mut self.projects {
99 let project_prefix = format!("{PREFIX}_{}_", project.name.to_uppercase());
100 let vars: HashMap<_, _> = std::env::vars()
101 .filter_map(|(k, v)| {
102 k.find(&project_prefix)?;
103 Some((
104 k[project_prefix.len()..].to_string().to_lowercase(),
105 TomlValue::String(v),
106 ))
107 })
108 .collect();
109 project.data.extend(vars);
110 project.data.extend(global_vars.clone());
111 }
112
113 debug!("tanu configuration loaded from env: {self:#?}");
114 }
115}
116
/// Configuration of a single project (one `[[projects]]` entry in tanu.toml).
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ProjectConfig {
    /// Project name. The loader upper-cases it to build the `TANU_<NAME>_` prefix for
    /// project-specific environment variables.
    pub name: String,
    /// Every key of the project table that is not one of the other named fields, plus any
    /// values the loader overlays from `TANU_*` environment variables.
    #[serde(flatten)]
    pub data: HashMap<String, TomlValue>,
    /// Empty when absent from the file. NOTE(review): presumably names/patterns of tests
    /// to skip; the consumer is not in this file.
    #[serde(default)]
    pub test_ignore: Vec<String>,
    /// Retry settings; `RetryConfig::default()` when the table is absent.
    #[serde(default)]
    pub retry: RetryConfig,
}
131
132impl ProjectConfig {
133 pub fn get(&self, key: impl AsRef<str>) -> Result<&TomlValue> {
134 let key = key.as_ref();
135 self.data
136 .get(key)
137 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
138 }
139
140 pub fn get_str(&self, key: impl AsRef<str>) -> Result<&str> {
141 let key = key.as_ref();
142 self.get(key)?
143 .as_str()
144 .ok_or_else(|| Error::ValueNotFound(key.to_string()))
145 }
146
147 pub fn get_int(&self, key: impl AsRef<str>) -> Result<i64> {
148 self.get_str(key)?
149 .parse()
150 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
151 }
152
153 pub fn get_float(&self, key: impl AsRef<str>) -> Result<f64> {
154 self.get_str(key)?
155 .parse()
156 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
157 }
158
159 pub fn get_bool(&self, key: impl AsRef<str>) -> Result<bool> {
160 self.get_str(key)?
161 .parse()
162 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
163 }
164
165 pub fn get_datetime(&self, key: impl AsRef<str>) -> Result<DateTime<Utc>> {
166 self.get_str(key)?
167 .parse::<DateTime<Utc>>()
168 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
169 }
170
171 pub fn get_array<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<Vec<T>> {
172 serde_json::from_str(self.get_str(key)?)
173 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
174 }
175
176 pub fn get_object<T: DeserializeOwned>(&self, key: impl AsRef<str>) -> Result<T> {
177 serde_json::from_str(self.get_str(key)?)
178 .map_err(|e| Error::ValueError(eyre::Error::from(e)))
179 }
180}
181
182#[derive(Debug, Clone, Deserialize)]
183pub struct RetryConfig {
184 pub count: Option<usize>,
186 pub factor: Option<f32>,
188 pub jitter: Option<bool>,
190 #[serde(with = "humantime_serde")]
192 pub min_delay: Option<Duration>,
193 #[serde(with = "humantime_serde")]
195 pub max_delay: Option<Duration>,
196}
197
198impl Default for RetryConfig {
199 fn default() -> Self {
200 RetryConfig {
201 count: Some(0),
202 factor: Some(2.0),
203 jitter: Some(false),
204 min_delay: Some(Duration::from_secs(1)),
205 max_delay: Some(Duration::from_secs(60)),
206 }
207 }
208}
209
210impl RetryConfig {
211 pub fn backoff(&self) -> backon::ExponentialBuilder {
212 let builder = backon::ExponentialBuilder::new()
213 .with_max_times(self.count.unwrap_or_default())
214 .with_factor(self.factor.unwrap_or(2.0))
215 .with_min_delay(self.min_delay.unwrap_or(Duration::from_secs(1)))
216 .with_max_delay(self.max_delay.unwrap_or(Duration::from_secs(60)));
217
218 if self.jitter.unwrap_or_default() {
219 builder.with_jitter()
220 } else {
221 builder
222 }
223 }
224}
225
#[cfg(test)]
mod test {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::{
        sync::{Mutex, MutexGuard},
        time::Duration,
    };
    use test_case::test_case;

    /// Serializes every test that touches the process environment.
    ///
    /// libtest runs tests on several threads, and each test here either mutates the
    /// environment (`std::env::set_var`) or reads it (`Config::load_from` -> `load_env`).
    /// Mutating the environment while another thread reads it is unsound on many
    /// platforms, so each test holds this lock for its whole duration.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`]. A failing test poisons the lock; the `()` payload cannot be
    /// left inconsistent, so recover the guard instead of failing every later test too.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Loads `tanu-sample.toml` from the directory above this crate's manifest.
    fn load_test_config() -> eyre::Result<Config> {
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let config_path = Path::new(manifest_dir).join("../tanu-sample.toml");
        Ok(super::Config::load_from(&config_path)?)
    }

    /// Returns the first project of the sample config.
    fn load_test_project_config() -> eyre::Result<ProjectConfig> {
        Ok(load_test_config()?.projects.remove(0))
    }

    #[test]
    fn load_config() -> eyre::Result<()> {
        let _env = lock_env();
        let cfg = load_test_config()?;
        assert_eq!(cfg.projects.len(), 1);

        let project = &cfg.projects[0];
        assert_eq!(project.name, "default");
        assert_eq!(project.test_ignore, Vec::<String>::new());
        assert_eq!(project.retry.count, Some(0));
        assert_eq!(project.retry.factor, Some(2.0));
        assert_eq!(project.retry.jitter, Some(false));
        assert_eq!(project.retry.min_delay, Some(Duration::from_secs(1)));
        assert_eq!(project.retry.max_delay, Some(Duration::from_secs(60)));

        Ok(())
    }

    #[test_case("TANU_DEFAULT_STR_KEY"; "project config")]
    #[test_case("TANU_STR_KEY"; "global config")]
    fn get_str(key: &str) -> eyre::Result<()> {
        let _env = lock_env();
        std::env::set_var(key, "example_string");
        let project = load_test_project_config()?;
        assert_eq!(project.get_str("str_key")?, "example_string");
        Ok(())
    }

    #[test_case("TANU_DEFAULT_INT_KEY"; "project config")]
    #[test_case("TANU_INT_KEY"; "global config")]
    fn get_int(key: &str) -> eyre::Result<()> {
        let _env = lock_env();
        std::env::set_var(key, "42");
        let project = load_test_project_config()?;
        assert_eq!(project.get_int("int_key")?, 42);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_float(prefix: &str) -> eyre::Result<()> {
        let _env = lock_env();
        std::env::set_var(format!("{prefix}_FLOAT_KEY"), "5.5");
        let project = load_test_project_config()?;
        assert_eq!(project.get_float("float_key")?, 5.5);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_BOOL_KEY"; "project config")]
    #[test_case("TANU_BOOL_KEY"; "global config")]
    fn get_bool(key: &str) -> eyre::Result<()> {
        let _env = lock_env();
        std::env::set_var(key, "true");
        let project = load_test_project_config()?;
        assert!(project.get_bool("bool_key")?);
        Ok(())
    }

    #[test_case("TANU_DEFAULT_DATETIME_KEY"; "project config")]
    #[test_case("TANU_DATETIME_KEY"; "global config")]
    fn get_datetime(key: &str) -> eyre::Result<()> {
        let _env = lock_env();
        let datetime_str = "2025-03-08T12:00:00Z";
        std::env::set_var(key, datetime_str);
        let project = load_test_project_config()?;
        assert_eq!(
            project
                .get_datetime("datetime_key")?
                .to_rfc3339_opts(chrono::SecondsFormat::Secs, true),
            datetime_str
        );
        Ok(())
    }

    #[test_case("TANU_DEFAULT_ARRAY_KEY"; "project config")]
    #[test_case("TANU_ARRAY_KEY"; "global config")]
    fn get_array(key: &str) -> eyre::Result<()> {
        let _env = lock_env();
        std::env::set_var(key, "[1, 2, 3]");
        let project = load_test_project_config()?;
        let array: Vec<i64> = project.get_array("array_key")?;
        assert_eq!(array, vec![1, 2, 3]);
        Ok(())
    }

    #[test_case("TANU_DEFAULT"; "project config")]
    #[test_case("TANU"; "global config")]
    fn get_object(prefix: &str) -> eyre::Result<()> {
        #[derive(Debug, Deserialize, PartialEq)]
        struct Foo {
            foo: Vec<String>,
        }

        let _env = lock_env();
        std::env::set_var(
            format!("{prefix}_OBJECT_KEY"),
            "{\"foo\": [\"bar\", \"baz\"]}",
        );
        let project = load_test_project_config()?;
        let obj: Foo = project.get_object("object_key")?;
        assert_eq!(obj.foo, vec!["bar", "baz"]);
        Ok(())
    }
}