Skip to main content

saorsa_agent/config/
auth.rs

1//! Authentication configuration for LLM providers.
2
3use std::collections::HashMap;
4use std::path::Path;
5
6use serde::{Deserialize, Serialize};
7
8use crate::error::{Result, SaorsaAgentError};
9
10/// A single authentication entry describing how to obtain an API key.
11#[derive(Clone, Debug, Serialize, Deserialize)]
12#[serde(tag = "type", rename_all = "snake_case")]
13pub enum AuthEntry {
14    /// A raw API key stored directly in the config.
15    ApiKey {
16        /// The API key value.
17        key: String,
18    },
19    /// An environment variable containing the API key.
20    EnvVar {
21        /// The environment variable name.
22        name: String,
23    },
24    /// A shell command whose stdout is the API key.
25    Command {
26        /// The shell command to execute.
27        command: String,
28    },
29}
30
31/// Authentication configuration mapping provider names to auth entries.
32#[derive(Clone, Debug, Default, Serialize, Deserialize)]
33pub struct AuthConfig {
34    /// Provider name to authentication entry mapping.
35    #[serde(flatten)]
36    pub providers: HashMap<String, AuthEntry>,
37}
38
39/// Load authentication configuration from a JSON file.
40///
41/// Returns a default (empty) [`AuthConfig`] if the file does not exist.
42///
43/// # Errors
44///
45/// Returns [`SaorsaAgentError::ConfigIo`] on I/O failures or
46/// [`SaorsaAgentError::ConfigParse`] on JSON parse failures.
47pub fn load(path: &Path) -> Result<AuthConfig> {
48    if !path.exists() {
49        return Ok(AuthConfig::default());
50    }
51    let data = std::fs::read_to_string(path).map_err(SaorsaAgentError::ConfigIo)?;
52    let config: AuthConfig = serde_json::from_str(&data).map_err(SaorsaAgentError::ConfigParse)?;
53    Ok(config)
54}
55
56/// Save authentication configuration to a JSON file.
57///
58/// Creates parent directories if they do not exist.
59///
60/// # Errors
61///
62/// Returns [`SaorsaAgentError::ConfigIo`] on I/O failures or
63/// [`SaorsaAgentError::ConfigParse`] on serialization failures.
64pub fn save(config: &AuthConfig, path: &Path) -> Result<()> {
65    if let Some(parent) = path.parent() {
66        std::fs::create_dir_all(parent).map_err(SaorsaAgentError::ConfigIo)?;
67    }
68    let data = serde_json::to_string_pretty(config).map_err(SaorsaAgentError::ConfigParse)?;
69    std::fs::write(path, data).map_err(SaorsaAgentError::ConfigIo)?;
70    Ok(())
71}
72
73/// Resolve an [`AuthEntry`] to a concrete API key string.
74///
75/// - `ApiKey` returns the key directly.
76/// - `EnvVar` reads the named environment variable.
77/// - `Command` executes the shell command and returns trimmed stdout.
78///
79/// # Errors
80///
81/// Returns [`SaorsaAgentError::EnvVarNotFound`] when an environment variable
82/// is missing, or [`SaorsaAgentError::CommandFailed`] when a shell command
83/// exits with a non-zero status or fails to execute.
84pub fn resolve(entry: &AuthEntry) -> Result<String> {
85    match entry {
86        AuthEntry::ApiKey { key } => Ok(key.clone()),
87        AuthEntry::EnvVar { name } => {
88            std::env::var(name).map_err(|_| SaorsaAgentError::EnvVarNotFound { name: name.clone() })
89        }
90        AuthEntry::Command { command } => {
91            let output = std::process::Command::new("sh")
92                .arg("-c")
93                .arg(command)
94                .output()
95                .map_err(|e| SaorsaAgentError::CommandFailed(e.to_string()))?;
96            if !output.status.success() {
97                let stderr = String::from_utf8_lossy(&output.stderr);
98                return Err(SaorsaAgentError::CommandFailed(format!(
99                    "command exited with {}: {}",
100                    output.status,
101                    stderr.trim()
102                )));
103            }
104            Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
105        }
106    }
107}
108
109/// Look up and resolve the API key for a named provider.
110///
111/// # Errors
112///
113/// Returns [`SaorsaAgentError::EnvVarNotFound`] if the provider has no entry
114/// in the config (with the provider name as `name`), or any resolution error
115/// from [`resolve`].
116pub fn get_key(config: &AuthConfig, provider: &str) -> Result<String> {
117    let entry = config
118        .providers
119        .get(provider)
120        .ok_or_else(|| SaorsaAgentError::EnvVarNotFound {
121            name: provider.to_string(),
122        })?;
123    resolve(entry)
124}
125
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Environment variable that only `resolve_env_var` sets and clears.
    const TEST_ENV_VAR: &str = "SAORSA_TEST_AUTH_KEY";

    /// Shorthand for building an inline API-key entry.
    fn api_key(value: &str) -> AuthEntry {
        AuthEntry::ApiKey {
            key: value.to_string(),
        }
    }

    /// A saved config can be loaded back with the same set of providers.
    #[test]
    fn roundtrip_auth_config() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("auth.json");

        let original = AuthConfig {
            providers: HashMap::from([
                ("anthropic".to_string(), api_key("sk-test-123")),
                (
                    "openai".to_string(),
                    AuthEntry::EnvVar {
                        name: "OPENAI_API_KEY".to_string(),
                    },
                ),
            ]),
        };

        save(&original, &file).unwrap();
        let reloaded = load(&file).unwrap();

        assert_eq!(reloaded.providers.len(), 2);
        for provider in ["anthropic", "openai"] {
            assert!(reloaded.providers.contains_key(provider));
        }
    }

    /// Loading a path that does not exist yields an empty config.
    #[test]
    fn load_missing_file_returns_default() {
        let dir = tempfile::tempdir().unwrap();
        let loaded = load(&dir.path().join("nonexistent.json")).unwrap();
        assert!(loaded.providers.is_empty());
    }

    /// An inline key resolves to itself.
    #[test]
    fn resolve_api_key() {
        assert_eq!(resolve(&api_key("sk-direct")).unwrap(), "sk-direct");
    }

    /// An `EnvVar` entry resolves to the variable's current value.
    #[test]
    fn resolve_env_var() {
        // SAFETY: the variable name is used only by this test, so no other
        // test in this module reads or writes it.
        // NOTE(review): this also relies on no other thread touching the
        // process environment concurrently — confirm for the test runner in use.
        unsafe {
            std::env::set_var(TEST_ENV_VAR, "sk-from-env");
        }

        let entry = AuthEntry::EnvVar {
            name: TEST_ENV_VAR.to_string(),
        };
        let value = resolve(&entry).unwrap();
        assert_eq!(value, "sk-from-env");

        // SAFETY: same reasoning as for `set_var` above.
        unsafe {
            std::env::remove_var(TEST_ENV_VAR);
        }
    }

    /// An unset variable is reported as `EnvVarNotFound`.
    #[test]
    fn resolve_env_var_missing() {
        let entry = AuthEntry::EnvVar {
            name: "SAORSA_NONEXISTENT_VAR_12345".to_string(),
        };
        let failure = resolve(&entry).unwrap_err();
        assert!(matches!(failure, SaorsaAgentError::EnvVarNotFound { .. }));
    }

    /// A successful command resolves to its trimmed stdout.
    #[test]
    fn resolve_command() {
        let entry = AuthEntry::Command {
            command: "echo sk-from-cmd".to_string(),
        };
        assert_eq!(resolve(&entry).unwrap(), "sk-from-cmd");
    }

    /// A command with a non-zero exit status is reported as `CommandFailed`.
    #[test]
    fn resolve_command_failure() {
        let entry = AuthEntry::Command {
            command: "exit 1".to_string(),
        };
        let failure = resolve(&entry).unwrap_err();
        assert!(matches!(failure, SaorsaAgentError::CommandFailed(_)));
    }

    /// A configured provider resolves to its key.
    #[test]
    fn get_key_found() {
        let config = AuthConfig {
            providers: HashMap::from([("test".to_string(), api_key("sk-test"))]),
        };
        assert_eq!(get_key(&config, "test").unwrap(), "sk-test");
    }

    /// An unknown provider is reported as `EnvVarNotFound`.
    #[test]
    fn get_key_missing_provider() {
        let failure = get_key(&AuthConfig::default(), "missing").unwrap_err();
        assert!(matches!(failure, SaorsaAgentError::EnvVarNotFound { .. }));
    }

    /// Saving into a directory tree that does not exist yet creates it.
    #[test]
    fn save_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("nested").join("deep").join("auth.json");
        save(&AuthConfig::default(), &file).unwrap();
        assert!(file.exists());
    }
}