Skip to main content

sqlite_graphrag/
config.rs

//! XDG-based API key management for OpenRouter and other providers.
//!
//! Stores keys in `$XDG_CONFIG_HOME/sqlite-graphrag/config.toml` (on Linux; other
//! platforms use the native configuration directory chosen by `directories::ProjectDirs`)
//! with atomic writes, symlink-attack defense and Unix permission hardening.
5
6use crate::errors::AppError;
7use directories::ProjectDirs;
8use secrecy::SecretBox;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct AppConfig {
14    pub schema_version: u32,
15    #[serde(default)]
16    pub keys: Vec<ApiKeyEntry>,
17}
18
19#[derive(Clone, Serialize, Deserialize)]
20pub struct ApiKeyEntry {
21    pub provider: String,
22    pub value: String,
23    pub added_at: String,
24    pub fingerprint: String,
25}
26
27impl std::fmt::Debug for ApiKeyEntry {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        f.debug_struct("ApiKeyEntry")
30            .field("provider", &self.provider)
31            .field("value", &mask_key(&self.value))
32            .field("added_at", &self.added_at)
33            .field("fingerprint", &self.fingerprint)
34            .finish()
35    }
36}
37
38impl Default for AppConfig {
39    fn default() -> Self {
40        Self {
41            schema_version: 1,
42            keys: vec![],
43        }
44    }
45}
46
47pub struct ResolvedKey {
48    pub value: SecretBox<String>,
49    pub source: &'static str,
50}
51
52pub fn config_file_path() -> Result<PathBuf, AppError> {
53    let proj = ProjectDirs::from("", "", "sqlite-graphrag").ok_or_else(|| {
54        AppError::Io(std::io::Error::other(
55            "could not determine home directory for config",
56        ))
57    })?;
58    Ok(proj.config_dir().join("config.toml"))
59}
60
61pub fn load_config() -> Result<AppConfig, AppError> {
62    let path = config_file_path()?;
63
64    if !path.exists() {
65        return Ok(AppConfig::default());
66    }
67
68    let meta = std::fs::symlink_metadata(&path)?;
69    if meta.file_type().is_symlink() {
70        return Err(AppError::Validation(format!(
71            "config file is a symlink (potential attack): {}",
72            path.display()
73        )));
74    }
75
76    #[cfg(unix)]
77    {
78        use std::os::unix::fs::PermissionsExt;
79        let mode = meta.permissions().mode() & 0o777;
80        if mode > 0o600 {
81            tracing::warn!(
82                path = %path.display(),
83                mode = format!("{mode:o}"),
84                "config file permissions are too open; recommend chmod 600"
85            );
86        }
87    }
88
89    let content = std::fs::read_to_string(&path)?;
90    toml::from_str(&content)
91        .map_err(|e| AppError::Validation(format!("config parse error in {}: {e}", path.display())))
92}
93
94pub fn save_config(config: &AppConfig) -> Result<(), AppError> {
95    let path = config_file_path()?;
96    let dir = path.parent().ok_or_else(|| {
97        AppError::Validation(format!("config path has no parent: {}", path.display()))
98    })?;
99
100    std::fs::create_dir_all(dir)?;
101
102    #[cfg(unix)]
103    {
104        use std::os::unix::fs::PermissionsExt;
105        std::fs::set_permissions(dir, std::fs::Permissions::from_mode(0o700))?;
106    }
107
108    #[cfg(unix)]
109    if path.exists() {
110        use std::os::unix::fs::MetadataExt;
111        let meta = std::fs::metadata(&path)?;
112        let file_uid = meta.uid();
113        let my_uid = unsafe { libc::getuid() };
114        if file_uid != my_uid {
115            return Err(AppError::Validation(format!(
116                "config file {} owned by uid {file_uid}, not current uid {my_uid}; refusing to overwrite",
117                path.display()
118            )));
119        }
120    }
121
122    let serialized =
123        toml::to_string_pretty(config).map_err(|e| AppError::Validation(e.to_string()))?;
124
125    #[cfg(unix)]
126    let old_umask = unsafe { libc::umask(0o077) };
127
128    use std::io::Write;
129    let mut tmp = tempfile::NamedTempFile::new_in(dir)?;
130    tmp.write_all(serialized.as_bytes())?;
131    tmp.as_file().sync_all()?;
132
133    #[cfg(unix)]
134    {
135        use std::os::unix::fs::PermissionsExt;
136        std::fs::set_permissions(tmp.path(), std::fs::Permissions::from_mode(0o600))?;
137    }
138
139    tmp.persist(&path)
140        .map_err(|e| AppError::Io(std::io::Error::other(format!("atomic persist failed: {e}"))))?;
141
142    #[cfg(unix)]
143    unsafe {
144        libc::umask(old_umask);
145    }
146
147    // fsync parent dir for crash consistency
148    #[cfg(unix)]
149    {
150        let dir_file = std::fs::File::open(dir)?;
151        dir_file.sync_all()?;
152    }
153
154    Ok(())
155}
156
157pub fn resolve_api_key(provider: &str, cli_key: Option<&str>) -> Option<ResolvedKey> {
158    let env_name = match provider {
159        "openrouter" => "OPENROUTER_API_KEY",
160        other => {
161            let upper = other.to_uppercase().replace('-', "_");
162            let owned = format!("{upper}_API_KEY");
163            return resolve_api_key_inner(provider, cli_key, &owned);
164        }
165    };
166    resolve_api_key_inner(provider, cli_key, env_name)
167}
168
169fn resolve_api_key_inner(
170    provider: &str,
171    cli_key: Option<&str>,
172    env_name: &str,
173) -> Option<ResolvedKey> {
174    if let Ok(val) = std::env::var(env_name) {
175        if !val.is_empty() {
176            return Some(ResolvedKey {
177                value: SecretBox::new(Box::new(val)),
178                source: "env",
179            });
180        }
181    }
182
183    if let Ok(cfg) = load_config() {
184        if let Some(entry) = cfg.keys.iter().find(|k| k.provider == provider) {
185            return Some(ResolvedKey {
186                value: SecretBox::new(Box::new(entry.value.clone())),
187                source: "config",
188            });
189        }
190    }
191
192    cli_key.map(|k| ResolvedKey {
193        value: SecretBox::new(Box::new(k.to_owned())),
194        source: "cli",
195    })
196}
197
198pub fn compute_fingerprint(key: &str) -> String {
199    let hash = blake3::hash(key.as_bytes());
200    hash.to_hex()[..16].to_string()
201}
202
/// Masks a secret for display.
///
/// Keys of at most 8 characters become `"****"`. Longer keys show only their first and
/// last 4 characters, e.g. `"sk-or-v1-abcdef1234"` becomes `"sk-o...1234"`.
///
/// Lengths and cut points are counted in `char`s, not bytes. Byte slicing would panic
/// when a cut point falls inside a multi-byte UTF-8 character, and this function is
/// called from `Debug` on values read from the config file.
pub fn mask_key(key: &str) -> String {
    const VISIBLE: usize = 4;
    let char_count = key.chars().count();
    if char_count <= 2 * VISIBLE {
        return "****".to_string();
    }
    let head: String = key.chars().take(VISIBLE).collect();
    let tail: String = key.chars().skip(char_count - VISIBLE).collect();
    format!("{head}...{tail}")
}
209
#[cfg(test)]
mod tests {
    use super::*;
    use secrecy::ExposeSecret;
    use serial_test::serial;
    use tempfile::TempDir;

    /// Sets an environment variable and removes it when dropped, so a failing
    /// assertion cannot leak the variable into later tests.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: only used from `#[serial]` tests, which keeps this from racing the
            // other env-reading resolver tests in this module.
            unsafe { std::env::set_var(name, value) };
            Self(name)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: see `EnvVarGuard::set`.
            unsafe { std::env::remove_var(self.0) };
        }
    }

    #[test]
    fn compute_fingerprint_deterministic() {
        let fp1 = compute_fingerprint("sk-or-v1-test-key-12345");
        let fp2 = compute_fingerprint("sk-or-v1-test-key-12345");
        assert_eq!(fp1, fp2);
        assert_eq!(fp1.len(), 16);
    }

    #[test]
    fn compute_fingerprint_differs_for_different_keys() {
        let fp1 = compute_fingerprint("key-a");
        let fp2 = compute_fingerprint("key-b");
        assert_ne!(fp1, fp2);
    }

    #[test]
    fn mask_key_short() {
        assert_eq!(mask_key("abcd"), "****");
        assert_eq!(mask_key("12345678"), "****");
        assert_eq!(mask_key(""), "****");
    }

    #[test]
    fn mask_key_normal() {
        assert_eq!(mask_key("sk-or-v1-abcdef1234"), "sk-o...1234");
    }

    // `load_config` reads the real per-user config path, so the missing-file branch
    // cannot be exercised here without redirecting the platform config dir. This test
    // only pins the value that branch returns.
    #[test]
    fn app_config_default_is_empty_v1() {
        let cfg = AppConfig::default();
        assert_eq!(cfg.schema_version, 1);
        assert!(cfg.keys.is_empty());
    }

    // Round-trips the TOML format only; `save_config`/`load_config` are bound to the
    // real per-user config path and are not called here.
    #[test]
    fn save_and_load_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let config_path = tmp.path().join("config.toml");

        let mut cfg = AppConfig::default();
        cfg.keys.push(ApiKeyEntry {
            provider: "openrouter".to_string(),
            value: "sk-test-key".to_string(),
            added_at: "2026-01-01T00:00:00Z".to_string(),
            fingerprint: compute_fingerprint("sk-test-key"),
        });

        let serialized = toml::to_string_pretty(&cfg).unwrap();
        std::fs::write(&config_path, &serialized).unwrap();

        let content = std::fs::read_to_string(&config_path).unwrap();
        let loaded: AppConfig = toml::from_str(&content).unwrap();

        assert_eq!(loaded.schema_version, 1);
        assert_eq!(loaded.keys.len(), 1);
        assert_eq!(loaded.keys[0].provider, "openrouter");
        assert_eq!(loaded.keys[0].value, "sk-test-key");
    }

    #[test]
    #[serial]
    fn resolve_api_key_env_takes_precedence() {
        let _env = EnvVarGuard::set("OPENROUTER_API_KEY", "env-key-value");

        let resolved = resolve_api_key("openrouter", Some("cli-key-value"));
        assert!(resolved.is_some());
        let r = resolved.unwrap();
        assert_eq!(r.source, "env");
        assert_eq!(r.value.expose_secret(), "env-key-value");
    }

    #[test]
    #[serial]
    fn resolve_api_key_cli_fallback() {
        // The env var actually consulted for "nonexistent-provider".
        // SAFETY: `#[serial]` keeps this from racing the other env-touching tests here.
        unsafe {
            std::env::remove_var("NONEXISTENT_PROVIDER_API_KEY");
        }

        let resolved = resolve_api_key("nonexistent-provider", Some("cli-key"));
        assert!(resolved.is_some());
        let r = resolved.unwrap();
        assert_eq!(r.source, "cli");
        assert_eq!(r.value.expose_secret(), "cli-key");
    }

    #[test]
    #[serial]
    fn resolve_api_key_none_when_nothing_available() {
        let resolved = resolve_api_key("totally-unknown-provider-xyz", None);
        // May return None or a config match depending on the user's environment;
        // this test only verifies that resolution does not panic.
        let _ = resolved;
    }
}
318}