Skip to main content

upstream_rs/services/storage/
config_storage.rs

1use anyhow::{Context, Result, anyhow};
2use serde::de::DeserializeOwned;
3use std::collections::HashMap;
4use std::fs;
5use std::path::{Path, PathBuf};
6use toml;
7
8use crate::models::upstream::AppConfig;
9
10pub struct ConfigStorage {
11    config: AppConfig,
12    config_file: PathBuf,
13}
14
15impl ConfigStorage {
16    pub fn new(config_file: &Path) -> Result<Self> {
17        let mut storage = Self {
18            config: AppConfig::default(),
19            config_file: config_file.to_path_buf(),
20        };
21
22        storage.load_config()?;
23        Ok(storage)
24    }
25
26    /// Loads configuration from config.toml, or creates default if it doesn't exist.
27    pub fn load_config(&mut self) -> Result<()> {
28        if !self.config_file.exists() {
29            return self.save_config();
30        }
31
32        let toml_str =
33            fs::read_to_string(&self.config_file).context("Failed to load config file")?;
34
35        self.config = toml::from_str(&toml_str).context("Tried to parse an invalid config")?;
36        Ok(())
37    }
38
39    /// Saves the current configuration to config.toml.
40    pub fn save_config(&self) -> Result<()> {
41        let toml = toml::to_string_pretty(&self.config).context("Failed to serialize config")?;
42
43        fs::write(&self.config_file, toml).with_context(|| {
44            format!("Failed to save config to '{}'", self.config_file.display())
45        })?;
46
47        #[cfg(unix)]
48        {
49            use std::os::unix::fs::PermissionsExt;
50            fs::set_permissions(&self.config_file, fs::Permissions::from_mode(0o600))?;
51        }
52
53        Ok(())
54    }
55
56    pub fn get_config(&self) -> &AppConfig {
57        &self.config
58    }
59
60    /// Sets a configuration value at the given key path (e.g., "github.api_token").
61    pub fn try_set_value(&mut self, key_path: &str, value: &str) -> Result<()> {
62        if key_path.trim().is_empty() {
63            return Err(anyhow!("Key path cannot be empty"));
64        }
65
66        let mut root = toml::Value::try_from(&self.config).context("Failed to serialize config")?;
67
68        let keys: Vec<&str> = key_path.split('.').collect();
69        let (path, final_key) = keys.split_at(keys.len() - 1);
70
71        let mut current = root
72            .as_table_mut()
73            .ok_or_else(|| anyhow!("Config root is not a table"))?;
74
75        for key in path {
76            current = current
77                .get_mut(*key)
78                .and_then(toml::Value::as_table_mut)
79                .ok_or_else(|| anyhow!("Key path not found: {}", key_path))?;
80        }
81
82        let parsed_value = self.convert_value(value)?;
83        current.insert(final_key[0].to_string(), parsed_value);
84
85        self.config = root.try_into().context("Failed to update config")?;
86
87        self.save_config().context("Failed to save config")
88    }
89
90    /// Gets a configuration value at the given key path.
91    pub fn try_get_value<T>(&self, key_path: &str) -> Result<T>
92    where
93        T: DeserializeOwned,
94    {
95        let value = self.get_value(key_path)?;
96        value
97            .clone()
98            .try_into()
99            .with_context(|| format!("Failed to deserialize '{}'", key_path))
100    }
101
102    fn get_value(&self, key_path: &str) -> Result<toml::Value> {
103        let root = toml::Value::try_from(&self.config).context("Failed to serialize config")?;
104
105        let mut current = &root;
106        for key in key_path.split('.') {
107            current = current
108                .get(key)
109                .ok_or_else(|| anyhow!("Key path not found: {}", key_path))?;
110        }
111
112        Ok(current.clone())
113    }
114
115    /// Gets all configuration keys and values as flattened dot-notation paths.
116    pub fn get_flattened_config(&self) -> HashMap<String, String> {
117        let root =
118            toml::Value::try_from(&self.config).unwrap_or(toml::Value::Table(Default::default()));
119        Self::flatten_value(&root, "", 10, 0)
120    }
121
122    /// Resets all configuration to defaults.
123    pub fn reset_to_defaults(&mut self) -> Result<()> {
124        self.config = AppConfig::default();
125        self.save_config()
126    }
127
128    fn flatten_value(
129        value: &toml::Value,
130        prefix: &str,
131        max_depth: usize,
132        current_depth: usize,
133    ) -> HashMap<String, String> {
134        let mut result = HashMap::new();
135
136        if current_depth >= max_depth {
137            return result;
138        }
139
140        match value {
141            toml::Value::String(s) => {
142                result.insert(prefix.to_string(), s.clone());
143            }
144            toml::Value::Integer(i) => {
145                result.insert(prefix.to_string(), i.to_string());
146            }
147            toml::Value::Float(f) => {
148                result.insert(prefix.to_string(), f.to_string());
149            }
150            toml::Value::Boolean(b) => {
151                result.insert(prefix.to_string(), b.to_string());
152            }
153            toml::Value::Table(table) => {
154                for (key, val) in table {
155                    let new_prefix = if prefix.is_empty() {
156                        key.clone()
157                    } else {
158                        format!("{}.{}", prefix, key)
159                    };
160                    result.extend(Self::flatten_value(
161                        val,
162                        &new_prefix,
163                        max_depth,
164                        current_depth + 1,
165                    ));
166                }
167            }
168            _ => {}
169        }
170
171        result
172    }
173
174    fn convert_value(&self, value: &str) -> Result<toml::Value> {
175        // Try TOML literal first
176        if let Ok(parsed) = value.parse::<toml::Value>() {
177            return Ok(parsed);
178        }
179
180        // Fallback to string
181        Ok(toml::Value::String(value.to_string()))
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::ConfigStorage;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// A unique temporary directory holding a `config.toml` path.
    ///
    /// The directory is created up front and removed on drop. Cleanup therefore also
    /// happens when an assertion panics; the previous explicit `cleanup` call was
    /// skipped in that case and leaked the directory.
    struct TempConfig {
        dir: PathBuf,
        file: PathBuf,
    }

    impl TempConfig {
        fn new(name: &str) -> Self {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0);
            let dir = std::env::temp_dir().join(format!(
                "upstream-config-test-{name}-{}-{nanos}",
                std::process::id()
            ));
            fs::create_dir_all(&dir).expect("create parent");
            let file = dir.join("config.toml");
            Self { dir, file }
        }

        fn path(&self) -> &Path {
            &self.file
        }
    }

    impl Drop for TempConfig {
        fn drop(&mut self) {
            // Best effort: a failed removal must not cause a second panic mid-unwind.
            let _ = fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn new_creates_default_config_file_when_missing() {
        let temp = TempConfig::new("new-default");

        let storage = ConfigStorage::new(temp.path()).expect("create storage");
        assert!(temp.path().exists());
        assert_eq!(storage.get_config().github.rate_limit, 5000);
        assert_eq!(storage.get_config().gitlab.rate_limit, 5000);
    }

    #[test]
    fn set_and_get_nested_values_updates_config() {
        let temp = TempConfig::new("set-get");
        let mut storage = ConfigStorage::new(temp.path()).expect("create storage");

        storage
            .try_set_value("github.rate_limit", "1234")
            .expect("set integer");
        storage
            .try_set_value("gitlab.api_token", "\"abc\"")
            .expect("set string literal");

        let rate_limit: u32 = storage
            .try_get_value("github.rate_limit")
            .expect("read rate limit");
        let token: Option<String> = storage
            .try_get_value("gitlab.api_token")
            .expect("read token");

        assert_eq!(rate_limit, 1234);
        assert_eq!(token.as_deref(), Some("abc"));
    }

    #[test]
    fn flattened_config_contains_dot_notation_keys() {
        let temp = TempConfig::new("flatten");
        let storage = ConfigStorage::new(temp.path()).expect("create storage");
        let flat = storage.get_flattened_config();

        assert_eq!(flat.get("github.rate_limit"), Some(&"5000".to_string()));
        assert_eq!(flat.get("gitlab.rate_limit"), Some(&"5000".to_string()));
    }

    #[test]
    fn set_value_rejects_unknown_paths() {
        let temp = TempConfig::new("bad-path");
        let mut storage = ConfigStorage::new(temp.path()).expect("create storage");

        let err = storage
            .try_set_value("github.missing.field", "1")
            .expect_err("must reject unknown path");
        assert!(err.to_string().contains("Key path not found"));
    }

    #[test]
    fn reset_to_defaults_restores_default_values() {
        let temp = TempConfig::new("reset");
        let mut storage = ConfigStorage::new(temp.path()).expect("create storage");

        storage
            .try_set_value("github.rate_limit", "99")
            .expect("set override");
        storage.reset_to_defaults().expect("reset defaults");

        let rate_limit: u32 = storage
            .try_get_value("github.rate_limit")
            .expect("read reset value");
        assert_eq!(rate_limit, 5000);
    }
}