Skip to main content

rust_config_tree/
cli_overrides.rs

//! Runtime figment provider for configuration override values taken from CLI fields.
2
3use figment::{
4    Metadata, Profile, Provider,
5    value::{Dict, Map, Value},
6};
7use serde::Serialize;
8
9use crate::config::ConfigResult;
10
11/// Sparse override provider built from CLI fields.
12#[derive(Debug, Clone, Default)]
13pub struct ConfigOverrideProvider {
14    values: Dict,
15}
16
17impl ConfigOverrideProvider {
18    /// Creates an empty override provider.
19    pub fn new() -> Self {
20        Self::default()
21    }
22
23    /// Returns whether this provider has no override values.
24    pub fn is_empty(&self) -> bool {
25        self.values.is_empty()
26    }
27
28    /// Inserts one override value at a dotted config path.
29    pub fn insert<T>(&mut self, path: &str, value: &T) -> ConfigResult<()>
30    where
31        T: Serialize + ?Sized,
32    {
33        if path.is_empty() || path.split('.').any(str::is_empty) {
34            return Err(figment::Error::from(format!(
35                "config override path `{path}` must not be empty"
36            ))
37            .into());
38        }
39
40        let value = Value::serialize(value)?;
41        let nested = figment::util::nest(path, value)
42            .into_dict()
43            .ok_or_else(|| {
44                figment::Error::from(format!(
45                    "config override path `{path}` must produce a dictionary"
46                ))
47            })?;
48        merge_dict(&mut self.values, nested);
49        Ok(())
50    }
51}
52
53impl Provider for ConfigOverrideProvider {
54    fn metadata(&self) -> Metadata {
55        Metadata::named("CLI overrides")
56    }
57
58    fn data(&self) -> Result<Map<Profile, Dict>, figment::Error> {
59        Ok(Profile::Default.collect(self.values.clone()))
60    }
61}
62
63/// Builds config override values from parsed CLI input.
64pub trait ConfigOverrides {
65    /// Builds an override provider that can be merged into Figment.
66    fn config_overrides(&self) -> ConfigResult<ConfigOverrideProvider>;
67}
68
69fn merge_dict(target: &mut Dict, source: Dict) {
70    for (key, value) in source {
71        match (target.remove(&key), value) {
72            (Some(Value::Dict(tag, mut target_child)), Value::Dict(_, source_child)) => {
73                merge_dict(&mut target_child, source_child);
74                target.insert(key, Value::Dict(tag, target_child));
75            }
76            (_, value) => {
77                target.insert(key, value);
78            }
79        }
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the default-profile dictionary produced by `provider`.
    fn default_values(provider: &ConfigOverrideProvider) -> Dict {
        provider
            .data()
            .unwrap()
            .remove(&Profile::Default)
            .unwrap()
    }

    #[test]
    fn provider_nests_override_values_by_dot_path() {
        let mut provider = ConfigOverrideProvider::new();
        provider.insert("server.port", &9000u16).unwrap();
        provider.insert("log.level", "debug").unwrap();

        let values = default_values(&provider);

        assert_eq!(
            values
                .get("server")
                .unwrap()
                .find_ref("port")
                .unwrap()
                .to_u128(),
            Some(9000)
        );
        assert_eq!(
            values
                .get("log")
                .unwrap()
                .find_ref("level")
                .unwrap()
                .as_str(),
            Some("debug")
        );
    }

    #[test]
    fn provider_merges_sibling_values_under_same_parent() {
        let mut provider = ConfigOverrideProvider::new();
        provider.insert("server.bind", "0.0.0.0").unwrap();
        provider.insert("server.port", &9000u16).unwrap();

        let values = default_values(&provider);
        let server = values.get("server").unwrap();

        assert_eq!(server.find_ref("bind").unwrap().as_str(), Some("0.0.0.0"));
        assert_eq!(server.find_ref("port").unwrap().to_u128(), Some(9000));
    }

    #[test]
    fn provider_replaces_non_table_values_with_later_inserts() {
        let mut provider = ConfigOverrideProvider::new();
        provider.insert("server", "legacy").unwrap();
        provider.insert("server.port", &9000u16).unwrap();

        let values = default_values(&provider);
        let server = values.get("server").unwrap();
        assert_eq!(server.find_ref("port").unwrap().to_u128(), Some(9000));

        provider.insert("server", "legacy").unwrap();

        let values = default_values(&provider);
        assert_eq!(values.get("server").unwrap().as_str(), Some("legacy"));
    }

    #[test]
    fn provider_rejects_empty_override_path_segments() {
        let mut provider = ConfigOverrideProvider::new();

        assert!(provider.insert("", "debug").is_err());
        assert!(provider.insert("log..level", "debug").is_err());
        assert!(provider.insert(".log", "debug").is_err());
        assert!(provider.insert("log.", "debug").is_err());
        assert!(provider.insert(".", "debug").is_err());

        // Rejected inserts must not leave partial values behind.
        assert!(provider.is_empty());
    }

    #[test]
    fn provider_reports_emptiness() {
        let mut provider = ConfigOverrideProvider::new();
        assert!(provider.is_empty());
        assert!(default_values(&provider).is_empty());

        provider.insert("log.level", "debug").unwrap();
        assert!(!provider.is_empty());
    }

    #[test]
    fn provider_metadata_names_cli_overrides() {
        let provider = ConfigOverrideProvider::new();
        assert_eq!(provider.metadata().name, "CLI overrides");
    }
}