rust_config_tree/
cli_overrides.rs1use figment::{
4 Metadata, Profile, Provider,
5 value::{Dict, Map, Value},
6};
7use serde::Serialize;
8
9use crate::config::ConfigResult;
10
11#[derive(Debug, Clone, Default)]
13pub struct ConfigOverrideProvider {
14 values: Dict,
15}
16
17impl ConfigOverrideProvider {
18 pub fn new() -> Self {
20 Self::default()
21 }
22
23 pub fn is_empty(&self) -> bool {
25 self.values.is_empty()
26 }
27
28 pub fn insert<T>(&mut self, path: &str, value: &T) -> ConfigResult<()>
30 where
31 T: Serialize + ?Sized,
32 {
33 if path.is_empty() || path.split('.').any(str::is_empty) {
34 return Err(figment::Error::from(format!(
35 "config override path `{path}` must not be empty"
36 ))
37 .into());
38 }
39
40 let value = Value::serialize(value)?;
41 let nested = figment::util::nest(path, value)
42 .into_dict()
43 .ok_or_else(|| {
44 figment::Error::from(format!(
45 "config override path `{path}` must produce a dictionary"
46 ))
47 })?;
48 merge_dict(&mut self.values, nested);
49 Ok(())
50 }
51}
52
53impl Provider for ConfigOverrideProvider {
54 fn metadata(&self) -> Metadata {
55 Metadata::named("CLI overrides")
56 }
57
58 fn data(&self) -> Result<Map<Profile, Dict>, figment::Error> {
59 Ok(Profile::Default.collect(self.values.clone()))
60 }
61}
62
63pub trait ConfigOverrides {
65 fn config_overrides(&self) -> ConfigResult<ConfigOverrideProvider>;
67}
68
69fn merge_dict(target: &mut Dict, source: Dict) {
70 for (key, value) in source {
71 match (target.remove(&key), value) {
72 (Some(Value::Dict(tag, mut target_child)), Value::Dict(_, source_child)) => {
73 merge_dict(&mut target_child, source_child);
74 target.insert(key, Value::Dict(tag, target_child));
75 }
76 (_, value) => {
77 target.insert(key, value);
78 }
79 }
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the recorded overrides under the default profile.
    fn default_values(provider: &ConfigOverrideProvider) -> Dict {
        provider
            .data()
            .unwrap()
            .remove(&Profile::Default)
            .unwrap()
    }

    #[test]
    fn provider_nests_override_values_by_dot_path() {
        let mut provider = ConfigOverrideProvider::new();
        provider.insert("server.port", &9000u16).unwrap();
        provider.insert("log.level", "debug").unwrap();

        let values = default_values(&provider);

        assert_eq!(
            values
                .get("server")
                .unwrap()
                .find_ref("port")
                .unwrap()
                .to_u128(),
            Some(9000)
        );
        assert_eq!(
            values
                .get("log")
                .unwrap()
                .find_ref("level")
                .unwrap()
                .as_str(),
            Some("debug")
        );
    }

    #[test]
    fn provider_merges_sibling_values_under_same_parent() {
        let mut provider = ConfigOverrideProvider::new();
        provider.insert("server.bind", "0.0.0.0").unwrap();
        provider.insert("server.port", &9000u16).unwrap();

        let values = default_values(&provider);
        let server = values.get("server").unwrap();

        assert_eq!(server.find_ref("bind").unwrap().as_str(), Some("0.0.0.0"));
        assert_eq!(server.find_ref("port").unwrap().to_u128(), Some(9000));
    }

    #[test]
    fn provider_later_insert_overrides_earlier_value() {
        let mut provider = ConfigOverrideProvider::new();
        provider.insert("server.port", &9000u16).unwrap();
        provider.insert("server.port", &9001u16).unwrap();

        let values = default_values(&provider);
        let server = values.get("server").unwrap();

        assert_eq!(server.find_ref("port").unwrap().to_u128(), Some(9001));
    }

    #[test]
    fn provider_dictionary_replaces_scalar_and_vice_versa() {
        let mut provider = ConfigOverrideProvider::new();
        provider.insert("server", "local").unwrap();
        provider.insert("server.port", &9000u16).unwrap();

        let values = default_values(&provider);
        let server = values.get("server").unwrap();
        assert_eq!(server.find_ref("port").unwrap().to_u128(), Some(9000));

        provider.insert("server", "remote").unwrap();
        let values = default_values(&provider);
        assert_eq!(values.get("server").unwrap().as_str(), Some("remote"));
    }

    #[test]
    fn provider_rejects_empty_override_path_segments() {
        let mut provider = ConfigOverrideProvider::new();

        assert!(provider.insert("", "debug").is_err());
        assert!(provider.insert("log..level", "debug").is_err());
        assert!(provider.insert("log.", "debug").is_err());
        assert!(provider.insert(".level", "debug").is_err());
    }

    #[test]
    fn provider_is_empty_until_a_valid_insert() {
        let mut provider = ConfigOverrideProvider::new();
        assert!(provider.is_empty());

        // Rejected inserts must not leave partial state behind.
        assert!(provider.insert("log..level", "debug").is_err());
        assert!(provider.is_empty());

        provider.insert("log.level", "debug").unwrap();
        assert!(!provider.is_empty());
    }
}