1use crate::hash_with_indifferent_access::HashWithIndifferentAccess;
2use serde::de::DeserializeOwned;
3use serde_json::{Map, Number, Value};
4use std::env;
5use std::fs;
6use std::path::Path;
7
8#[derive(Debug, thiserror::Error)]
10pub enum ConfigError {
11 #[error("file not found: {0}")]
13 FileNotFound(String),
14 #[error("parse error: {0}")]
16 ParseError(String),
17 #[error("missing key: {0}")]
19 MissingKey(String),
20}
21
22#[derive(Debug, Clone, Default)]
24pub struct Config {
25 values: HashWithIndifferentAccess,
26}
27
28impl Config {
29 #[must_use]
31 pub fn new() -> Self {
32 Self {
33 values: HashWithIndifferentAccess::new(),
34 }
35 }
36
37 pub fn from_toml(content: &str) -> Result<Self, ConfigError> {
39 let parsed: toml::Value =
40 toml::from_str(content).map_err(|error| ConfigError::ParseError(error.to_string()))?;
41 let json = serde_json::to_value(parsed)
42 .map_err(|error| ConfigError::ParseError(error.to_string()))?;
43
44 match json {
45 Value::Object(map) => Ok(Self {
46 values: HashWithIndifferentAccess::from(map),
47 }),
48 _ => Err(ConfigError::ParseError(
49 "top-level TOML value must be a table".to_owned(),
50 )),
51 }
52 }
53
54 pub fn from_file(path: &Path) -> Result<Self, ConfigError> {
56 let content = fs::read_to_string(path).map_err(|error| {
57 if error.kind() == std::io::ErrorKind::NotFound {
58 ConfigError::FileNotFound(path.display().to_string())
59 } else {
60 ConfigError::ParseError(error.to_string())
61 }
62 })?;
63
64 Self::from_toml(&content)
65 }
66
67 #[must_use]
69 pub fn get(&self, key: &str) -> Option<&Value> {
70 get_path(self.values.as_index_map(), key)
71 }
72
73 #[must_use]
75 pub fn get_str(&self, key: &str) -> Option<&str> {
76 self.get(key).and_then(Value::as_str)
77 }
78
79 #[must_use]
81 pub fn get_i64(&self, key: &str) -> Option<i64> {
82 self.get(key).and_then(Value::as_i64)
83 }
84
85 #[must_use]
87 pub fn get_bool(&self, key: &str) -> Option<bool> {
88 self.get(key).and_then(Value::as_bool)
89 }
90
91 pub fn set(&mut self, key: impl Into<String>, value: Value) {
93 set_path(self.values.as_index_map_mut(), &key.into(), value);
94 }
95
96 pub fn apply_env_overrides(&mut self, prefix: &str) {
100 let prefix = format!("{prefix}_");
101 for (env_key, env_value) in env::vars() {
102 if !env_key.starts_with(&prefix) {
103 continue;
104 }
105
106 let suffix = &env_key[prefix.len()..];
107 if suffix.is_empty() {
108 continue;
109 }
110
111 let dotted_key = suffix
112 .split('_')
113 .filter(|segment| !segment.is_empty())
114 .map(str::to_ascii_lowercase)
115 .collect::<Vec<_>>()
116 .join(".");
117
118 if dotted_key.is_empty() {
119 continue;
120 }
121
122 self.set(dotted_key, parse_env_value(&env_value));
123 }
124 }
125
126 pub fn extract<T: DeserializeOwned>(&self) -> Result<T, ConfigError> {
128 let root = Value::Object(
129 self.values
130 .as_index_map()
131 .iter()
132 .map(|(key, value)| (key.clone(), value.clone()))
133 .collect::<Map<String, Value>>(),
134 );
135
136 serde_json::from_value(root).map_err(|error| {
137 match parse_missing_field(&error.to_string()) {
138 Some(field) => ConfigError::MissingKey(field),
139 None => ConfigError::ParseError(error.to_string()),
140 }
141 })
142 }
143}
144
145fn get_path<'a>(root: &'a indexmap::IndexMap<String, Value>, key: &str) -> Option<&'a Value> {
146 let mut segments = key.split('.');
147 let first = segments.next()?;
148 let mut current = root.get(first)?;
149
150 for segment in segments {
151 current = current.as_object()?.get(segment)?;
152 }
153
154 Some(current)
155}
156
157fn set_path(root: &mut indexmap::IndexMap<String, Value>, key: &str, value: Value) {
158 let parts: Vec<&str> = key
159 .split('.')
160 .filter(|segment| !segment.is_empty())
161 .collect();
162 if parts.is_empty() {
163 return;
164 }
165
166 if parts.len() == 1 {
167 root.insert(parts[0].to_owned(), value);
168 return;
169 }
170
171 let mut current = root
172 .entry(parts[0].to_owned())
173 .or_insert_with(|| Value::Object(Map::new()));
174
175 for segment in &parts[1..parts.len() - 1] {
176 match current {
177 Value::Object(map) => {
178 current = map
179 .entry((*segment).to_owned())
180 .or_insert_with(|| Value::Object(Map::new()));
181 }
182 _ => {
183 *current = Value::Object(Map::new());
184 if let Value::Object(map) = current {
185 current = map
186 .entry((*segment).to_owned())
187 .or_insert_with(|| Value::Object(Map::new()));
188 }
189 }
190 }
191 }
192
193 if let Value::Object(map) = current {
194 map.insert(parts[parts.len() - 1].to_owned(), value);
195 }
196}
197
198fn parse_env_value(value: &str) -> Value {
199 if value.eq_ignore_ascii_case("true") {
200 return Value::Bool(true);
201 }
202 if value.eq_ignore_ascii_case("false") {
203 return Value::Bool(false);
204 }
205 if let Ok(integer) = value.parse::<i64>() {
206 return Value::Number(integer.into());
207 }
208 if let Ok(unsigned) = value.parse::<u64>() {
209 return Value::Number(unsigned.into());
210 }
211 if let Ok(float) = value.parse::<f64>()
212 && let Some(number) = Number::from_f64(float)
213 {
214 return Value::Number(number);
215 }
216 if ((value.starts_with('[') && value.ends_with(']'))
217 || (value.starts_with('{') && value.ends_with('}')))
218 && let Ok(json) = serde_json::from_str::<Value>(value)
219 {
220 return json;
221 }
222
223 Value::String(value.to_owned())
224}
225
/// Extracts the field name from a serde "missing field `name`" message.
///
/// Returns `None` unless `message` starts with "missing field `" and contains
/// a closing backtick. Anything after that backtick is ignored.
fn parse_missing_field(message: &str) -> Option<String> {
    message
        .strip_prefix("missing field `")
        .and_then(|rest| rest.split_once('`'))
        .map(|(field, _trailing)| field.to_owned())
}
233
#[cfg(test)]
mod tests {
    use super::{Config, ConfigError};
    use serde::Deserialize;
    use serde_json::json;
    use std::fs;
    use std::sync::{Mutex, PoisonError};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Serializes every test in this module that mutates the process environment.
    ///
    /// The test harness runs tests on several threads. `std::env::set_var`
    /// and `remove_var` are only sound when no other thread is concurrently
    /// using the environment.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Sets `vars`, runs `config.apply_env_overrides(prefix)`, then removes
    /// `vars` again, all while holding `ENV_LOCK`.
    fn apply_env_overrides_with(config: &mut Config, prefix: &str, vars: &[(&str, &str)]) {
        // A test that panicked while holding the lock must not cascade into
        // unrelated failures; the guarded data is `()`, so just recover it.
        let _guard = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);

        for &(key, value) in vars {
            // SAFETY: ENV_LOCK is held and every environment mutation in this
            // module goes through this helper, so no other test mutates the
            // environment concurrently. Other tests only read it via `std::env`
            // (e.g. `temp_dir`), which synchronizes with `set_var` internally.
            unsafe { std::env::set_var(key, value) };
        }

        config.apply_env_overrides(prefix);

        for &(key, _) in vars {
            // SAFETY: same invariant as above; ENV_LOCK is still held.
            unsafe { std::env::remove_var(key) };
        }
    }

    /// Returns a nanosecond timestamp used to make temp-file names practically unique.
    fn unique_suffix() -> u128 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos()
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct AppConfig {
        app_name: String,
        port: i64,
        debug: bool,
        database: DatabaseConfig,
    }

    #[derive(Debug, Deserialize, PartialEq)]
    struct DatabaseConfig {
        host: String,
        pool: i64,
    }

    #[test]
    fn config_from_toml_reads_top_level_and_nested_values() {
        let config = Config::from_toml(
            r#"
            app_name = "rustrails"
            port = 3000
            debug = true

            [database]
            host = "localhost"
            pool = 5
            "#,
        )
        .unwrap();

        assert_eq!(config.get_str("app_name"), Some("rustrails"));
        assert_eq!(config.get_i64("port"), Some(3000));
        assert_eq!(config.get_bool("debug"), Some(true));
        assert_eq!(config.get_str("database.host"), Some("localhost"));
        assert_eq!(config.get_i64("database.pool"), Some(5));
    }

    #[test]
    fn config_set_creates_nested_paths() {
        let mut config = Config::new();
        config.set("service.name", json!("api"));
        config.set("service.enabled", json!(true));

        assert_eq!(config.get_str("service.name"), Some("api"));
        assert_eq!(config.get_bool("service.enabled"), Some(true));
    }

    #[test]
    fn config_from_file_reads_toml() {
        let path = std::env::temp_dir().join(format!(
            "rustrails-support-config-{}.toml",
            unique_suffix()
        ));
        fs::write(&path, "name = \"support\"\n").unwrap();

        let config = Config::from_file(&path).unwrap();

        fs::remove_file(&path).unwrap();
        assert_eq!(config.get_str("name"), Some("support"));
    }

    #[test]
    fn config_from_file_returns_not_found_error_for_missing_file() {
        // A unique name guarantees a stray file from elsewhere cannot make
        // this path exist.
        let path = std::env::temp_dir().join(format!(
            "rustrails-support-missing-{}.toml",
            unique_suffix()
        ));
        let error = Config::from_file(&path).unwrap_err();

        assert!(matches!(error, ConfigError::FileNotFound(_)));
    }

    #[test]
    fn config_apply_env_overrides_replaces_existing_values() {
        let mut config = Config::from_toml(
            r#"
            app_name = "rustrails"
            port = 3000

            [database]
            host = "localhost"
            pool = 5
            "#,
        )
        .unwrap();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_REPLACE",
            &[
                ("RUSTRAILS_SUPPORT_CFG_REPLACE_PORT", "4000"),
                ("RUSTRAILS_SUPPORT_CFG_REPLACE_DATABASE_HOST", "db.internal"),
                ("RUSTRAILS_SUPPORT_CFG_REPLACE_DEBUG", "true"),
            ],
        );

        assert_eq!(config.get_i64("port"), Some(4000));
        assert_eq!(config.get_str("database.host"), Some("db.internal"));
        assert_eq!(config.get_bool("debug"), Some(true));
    }

    #[test]
    fn config_apply_env_overrides_parses_json_values() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_JSON",
            &[("RUSTRAILS_SUPPORT_CFG_JSON_FEATURES", "[\"cache\",\"jobs\"]")],
        );

        assert_eq!(config.get("features"), Some(&json!(["cache", "jobs"])));
    }

    #[test]
    fn config_extract_deserializes_into_typed_struct() {
        let config = Config::from_toml(
            r#"
            app_name = "rustrails"
            port = 3000
            debug = false

            [database]
            host = "db.internal"
            pool = 7
            "#,
        )
        .unwrap();

        let extracted: AppConfig = config.extract().unwrap();

        assert_eq!(
            extracted,
            AppConfig {
                app_name: String::from("rustrails"),
                port: 3000,
                debug: false,
                database: DatabaseConfig {
                    host: String::from("db.internal"),
                    pool: 7,
                },
            }
        );
    }

    #[test]
    fn config_extract_reports_missing_keys() {
        let config = Config::from_toml("app_name = \"rustrails\"\n").unwrap();

        let error = config.extract::<AppConfig>().unwrap_err();

        assert!(matches!(error, ConfigError::MissingKey(key) if key == "port"));
    }

    #[test]
    fn config_get_returns_none_for_missing_key() {
        let config = Config::new();

        assert_eq!(config.get("missing.key"), None);
        assert_eq!(config.get_str("missing.key"), None);
    }

    #[test]
    fn config_default_matches_new_for_missing_values() {
        let config = Config::default();

        assert_eq!(config.get("missing"), Config::new().get("missing"));
    }

    #[test]
    fn config_from_toml_reads_array_values() {
        let config = Config::from_toml(
            r#"
            features = ["cache", "jobs"]
            "#,
        )
        .unwrap();

        assert_eq!(config.get("features"), Some(&json!(["cache", "jobs"])));
    }

    #[test]
    fn config_from_toml_reads_three_level_nested_values() {
        let config = Config::from_toml(
            r#"
            [database]
            [database.primary]
            [database.primary.credentials]
            user = "postgres"
            "#,
        )
        .unwrap();

        assert_eq!(
            config.get_str("database.primary.credentials.user"),
            Some("postgres")
        );
    }

    #[test]
    fn config_from_toml_reports_parse_error_for_invalid_syntax() {
        let error = Config::from_toml("[database\nhost = \"localhost\"").unwrap_err();

        assert!(matches!(error, ConfigError::ParseError(_)));
    }

    #[test]
    fn config_from_toml_reports_parse_error_for_non_table_root() {
        let error = Config::from_toml("\"support\"").unwrap_err();

        assert!(matches!(error, ConfigError::ParseError(_)));
    }

    #[test]
    fn config_from_file_reports_parse_error_for_invalid_toml() {
        let path = std::env::temp_dir().join(format!(
            "rustrails-support-config-invalid-{}.toml",
            unique_suffix()
        ));
        fs::write(&path, "[database\nhost = \"localhost\"").unwrap();

        let error = Config::from_file(&path).unwrap_err();

        fs::remove_file(&path).unwrap();
        assert!(matches!(error, ConfigError::ParseError(_)));
    }

    #[test]
    fn config_from_file_reads_nested_values() {
        let path = std::env::temp_dir().join(format!(
            "rustrails-support-config-nested-{}.toml",
            unique_suffix()
        ));
        fs::write(
            &path,
            "[database]\nhost = \"localhost\"\n[database.credentials]\nuser = \"postgres\"\n",
        )
        .unwrap();

        let config = Config::from_file(&path).unwrap();

        fs::remove_file(&path).unwrap();
        assert_eq!(config.get_str("database.host"), Some("localhost"));
        assert_eq!(
            config.get_str("database.credentials.user"),
            Some("postgres")
        );
    }

    #[test]
    fn config_set_preserves_existing_nested_siblings() {
        let mut config = Config::new();
        config.set("database.host", json!("localhost"));
        config.set("database.pool", json!(5));

        assert_eq!(config.get_str("database.host"), Some("localhost"));
        assert_eq!(config.get_i64("database.pool"), Some(5));
    }

    #[test]
    fn config_set_deeper_path_does_not_replace_existing_scalar() {
        let mut config = Config::new();
        config.set("database", json!("localhost"));

        config.set("database.host", json!("db.internal"));

        assert_eq!(config.get_str("database"), Some("localhost"));
        assert_eq!(config.get_str("database.host"), None);
    }

    #[test]
    fn config_set_ignores_empty_path() {
        let mut config = Config::new();
        config.set("", json!("ignored"));

        assert_eq!(config.get(""), None);
        assert!(config.get("anything").is_none());
    }

    #[test]
    fn config_get_returns_none_when_descending_into_scalar() {
        let config = Config::from_toml("port = 3000\n").unwrap();

        assert_eq!(config.get("port.value"), None);
    }

    #[test]
    fn config_get_str_returns_none_for_non_string_value() {
        let config = Config::from_toml("port = 3000\n").unwrap();

        assert_eq!(config.get_str("port"), None);
    }

    #[test]
    fn config_get_i64_returns_none_for_non_integer_value() {
        let config = Config::from_toml("debug = true\n").unwrap();

        assert_eq!(config.get_i64("debug"), None);
    }

    #[test]
    fn config_get_bool_returns_none_for_non_boolean_value() {
        let config = Config::from_toml("app_name = \"rustrails\"\n").unwrap();

        assert_eq!(config.get_bool("app_name"), None);
    }

    #[test]
    fn config_apply_env_overrides_creates_nested_values() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_CREATE",
            &[("RUSTRAILS_SUPPORT_CFG_CREATE_DATABASE_HOST", "db.internal")],
        );

        assert_eq!(config.get_str("database.host"), Some("db.internal"));
    }

    #[test]
    fn config_apply_env_overrides_preserves_unrelated_values() {
        let mut config = Config::new();
        config.set("database.pool", json!(5));

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_KEEP",
            &[("RUSTRAILS_SUPPORT_CFG_KEEP_DATABASE_HOST", "db.internal")],
        );

        assert_eq!(config.get_i64("database.pool"), Some(5));
        assert_eq!(config.get_str("database.host"), Some("db.internal"));
    }

    #[test]
    fn config_apply_env_overrides_parses_false_values() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_FALSE",
            &[("RUSTRAILS_SUPPORT_CFG_FALSE_DEBUG", "false")],
        );

        assert_eq!(config.get_bool("debug"), Some(false));
    }

    #[test]
    fn config_apply_env_overrides_parses_unsigned_integer_values() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_UNSIGNED",
            &[("RUSTRAILS_SUPPORT_CFG_UNSIGNED_LIMIT", "18446744073709551615")],
        );

        assert_eq!(config.get("limit"), Some(&json!(18446744073709551615_u64)));
    }

    #[test]
    fn config_apply_env_overrides_parses_float_values() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_FLOAT",
            &[("RUSTRAILS_SUPPORT_CFG_FLOAT_RATIO", "3.5")],
        );

        assert_eq!(config.get("ratio"), Some(&json!(3.5)));
    }

    #[test]
    fn config_apply_env_overrides_parses_object_json() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_OBJECT",
            &[(
                "RUSTRAILS_SUPPORT_CFG_OBJECT_DATABASE",
                r#"{"host":"db.internal"}"#,
            )],
        );

        assert_eq!(config.get_str("database.host"), Some("db.internal"));
    }

    #[test]
    fn config_apply_env_overrides_leaves_invalid_json_as_string() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_INVALID",
            &[("RUSTRAILS_SUPPORT_CFG_INVALID_PAYLOAD", "{not-json}")],
        );

        assert_eq!(config.get_str("payload"), Some("{not-json}"));
    }

    #[test]
    fn config_apply_env_overrides_ignores_empty_suffix_segments() {
        let mut config = Config::new();

        apply_env_overrides_with(
            &mut config,
            "RUSTRAILS_SUPPORT_CFG_SEGMENTS",
            &[("RUSTRAILS_SUPPORT_CFG_SEGMENTS__DATABASE__HOST", "db.internal")],
        );

        assert_eq!(config.get_str("database.host"), Some("db.internal"));
    }

    #[test]
    fn config_clone_is_independent_of_original() {
        let mut config = Config::new();
        config.set("database.host", json!("localhost"));

        let cloned = config.clone();
        config.set("database.host", json!("db.internal"));

        assert_eq!(cloned.get_str("database.host"), Some("localhost"));
        assert_eq!(config.get_str("database.host"), Some("db.internal"));
    }

    #[test]
    fn config_extract_supports_array_fields() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct FeatureConfig {
            features: Vec<String>,
        }

        let config = Config::from_toml(
            r#"
            features = ["cache", "jobs"]
            "#,
        )
        .unwrap();

        let extracted: FeatureConfig = config.extract().unwrap();

        assert_eq!(
            extracted,
            FeatureConfig {
                features: vec![String::from("cache"), String::from("jobs")],
            }
        );
    }

    #[test]
    fn config_extract_supports_three_level_nested_values() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct RootConfig {
            database: LevelOne,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        struct LevelOne {
            primary: LevelTwo,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        struct LevelTwo {
            credentials: Credentials,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        struct Credentials {
            username: String,
        }

        let config = Config::from_toml(
            r#"
            [database.primary.credentials]
            username = "postgres"
            "#,
        )
        .unwrap();

        let extracted: RootConfig = config.extract().unwrap();

        assert_eq!(
            extracted,
            RootConfig {
                database: LevelOne {
                    primary: LevelTwo {
                        credentials: Credentials {
                            username: String::from("postgres"),
                        },
                    },
                },
            }
        );
    }

    #[test]
    fn config_extract_reports_parse_error_for_type_mismatch() {
        let config = Config::from_toml(
            r#"
            app_name = "rustrails"
            port = "not-a-number"
            debug = false

            [database]
            host = "db.internal"
            pool = 7
            "#,
        )
        .unwrap();

        let error = config.extract::<AppConfig>().unwrap_err();

        assert!(matches!(error, ConfigError::ParseError(_)));
    }

    #[test]
    fn config_extract_reports_missing_nested_field() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct NestedConfig {
            database: NestedDatabase,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        struct NestedDatabase {
            credentials: NestedCredentials,
        }

        #[derive(Debug, Deserialize, PartialEq)]
        struct NestedCredentials {
            username: String,
        }

        let config = Config::from_toml(
            r#"
            [database.credentials]
            "#,
        )
        .unwrap();

        let error = config.extract::<NestedConfig>().unwrap_err();

        assert!(matches!(error, ConfigError::MissingKey(key) if key == "username"));
    }
}