steer_core/config/
mod.rs

1use crate::api::ProviderKind;
2use crate::auth::AuthStorage;
3use crate::error::{Error, Result};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::PathBuf;
7use std::sync::Arc;
8
9#[derive(Debug, Serialize, Deserialize, Default)]
10pub struct Config {
11    pub model: Option<String>,
12    pub history_size: Option<usize>,
13    pub system_prompt: Option<String>,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub notifications: Option<NotificationSettings>,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct NotificationSettings {
20    pub enable_sound: Option<bool>,
21    pub enable_desktop: Option<bool>,
22}
23
24impl Default for NotificationSettings {
25    fn default() -> Self {
26        Self {
27            enable_sound: Some(true),
28            enable_desktop: Some(true),
29        }
30    }
31}
32
33impl Config {
34    fn new() -> Self {
35        Self {
36            model: Some("claude-3-7-sonnet-20250219".to_string()),
37            history_size: Some(10),
38            system_prompt: None,
39            notifications: Some(NotificationSettings::default()),
40        }
41    }
42}
43
44/// Get the path to the config file
45pub fn get_config_path() -> Result<PathBuf> {
46    let config_dir = dirs::config_dir()
47        .ok_or_else(|| Error::Configuration("Could not find config directory".to_string()))?
48        .join("steer");
49
50    fs::create_dir_all(&config_dir)
51        .map_err(|e| Error::Configuration(format!("Failed to create config directory: {e}")))?;
52
53    Ok(config_dir.join("config.json"))
54}
55
56/// Load the configuration
57pub fn load_config() -> Result<Config> {
58    let config_path = get_config_path()?;
59
60    if !config_path.exists() {
61        return Ok(Config::new());
62    }
63
64    let config_str = fs::read_to_string(&config_path)
65        .map_err(|e| Error::Configuration(format!("Failed to read config file: {e}")))?;
66
67    let config: Config = serde_json::from_str(&config_str)
68        .map_err(|e| Error::Configuration(format!("Failed to parse config file: {e}")))?;
69
70    Ok(config)
71}
72
73/// Initialize or update the configuration
74pub fn init_config(force: bool) -> Result<()> {
75    let config_path = get_config_path()?;
76
77    if config_path.exists() && !force {
78        return Err(Error::Configuration(
79            "Config file already exists. Use --force to overwrite.".to_string(),
80        ));
81    }
82
83    let config = Config::new();
84    let config_json = serde_json::to_string_pretty(&config)
85        .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
86
87    fs::write(&config_path, config_json)
88        .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
89
90    Ok(())
91}
92
93/// Save the configuration
94pub fn save_config(config: &Config) -> Result<()> {
95    let config_path = get_config_path()?;
96    let config_json = serde_json::to_string_pretty(&config)
97        .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
98
99    fs::write(&config_path, config_json)
100        .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
101
102    Ok(())
103}
104
/// How requests to an LLM provider should authenticate.
///
/// `Debug` is implemented by hand so that API keys never end up in logs or panic
/// messages; the key is rendered as `<redacted>`.
#[derive(Clone)]
pub enum ApiAuth {
    /// A raw API key (from an environment variable or the credential store).
    Key(String),
    /// OAuth tokens are present in the credential store; this variant carries no token.
    OAuth,
}

impl std::fmt::Debug for ApiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ApiAuth::Key(_) => f
                .debug_tuple("Key")
                .field(&format_args!("<redacted>"))
                .finish(),
            ApiAuth::OAuth => f.write_str("OAuth"),
        }
    }
}
110
111/// Provider for authentication credentials
112#[derive(Clone)]
113pub struct LlmConfigProvider {
114    storage: Arc<dyn AuthStorage>,
115}
116
117impl std::fmt::Debug for LlmConfigProvider {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("LlmConfigProvider").finish_non_exhaustive()
120    }
121}
122
123impl LlmConfigProvider {
124    /// Create a new LlmConfigProvider with the given auth storage
125    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
126        Self { storage }
127    }
128
129    /// Get authentication for a specific provider
130    pub async fn get_auth_for_provider(&self, provider: ProviderKind) -> Result<Option<ApiAuth>> {
131        // For Anthropic: Check OAuth tokens first, then env vars, then stored API key
132        match provider {
133            ProviderKind::Anthropic => {
134                // API key via env var > OAuth > stored API key
135                let anthropic_key = std::env::var("CLAUDE_API_KEY")
136                    .ok()
137                    .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok());
138                if let Some(key) = anthropic_key {
139                    Ok(Some(ApiAuth::Key(key)))
140                } else if self
141                    .storage
142                    .get_credential("anthropic", crate::auth::CredentialType::AuthTokens)
143                    .await?
144                    .is_some()
145                {
146                    Ok(Some(ApiAuth::OAuth))
147                } else {
148                    {
149                        // Fall back to stored API key in keyring
150                        if let Some(crate::auth::Credential::ApiKey { value }) = self
151                            .storage
152                            .get_credential("anthropic", crate::auth::CredentialType::ApiKey)
153                            .await?
154                        {
155                            Ok(Some(ApiAuth::Key(value)))
156                        } else {
157                            Ok(None)
158                        }
159                    }
160                }
161            }
162            ProviderKind::OpenAI => {
163                // API key via env var > stored API key
164                if let Ok(key) = std::env::var("OPENAI_API_KEY") {
165                    Ok(Some(ApiAuth::Key(key)))
166                } else if let Some(crate::auth::Credential::ApiKey { value }) = self
167                    .storage
168                    .get_credential("openai", crate::auth::CredentialType::ApiKey)
169                    .await?
170                {
171                    Ok(Some(ApiAuth::Key(value)))
172                } else {
173                    Ok(None)
174                }
175            }
176            ProviderKind::Google => {
177                // API key via env var > stored API key
178                if let Ok(key) =
179                    std::env::var("GEMINI_API_KEY").or_else(|_| std::env::var("GOOGLE_API_KEY"))
180                {
181                    Ok(Some(ApiAuth::Key(key)))
182                } else if let Some(crate::auth::Credential::ApiKey { value }) = self
183                    .storage
184                    .get_credential("google", crate::auth::CredentialType::ApiKey)
185                    .await?
186                {
187                    Ok(Some(ApiAuth::Key(value)))
188                } else {
189                    Ok(None)
190                }
191            }
192            ProviderKind::XAI => {
193                // API key via env var > stored API key
194                if let Ok(key) =
195                    std::env::var("XAI_API_KEY").or_else(|_| std::env::var("GROK_API_KEY"))
196                {
197                    Ok(Some(ApiAuth::Key(key)))
198                } else if let Some(crate::auth::Credential::ApiKey { value }) = self
199                    .storage
200                    .get_credential("xai", crate::auth::CredentialType::ApiKey)
201                    .await?
202                {
203                    Ok(Some(ApiAuth::Key(value)))
204                } else {
205                    Ok(None)
206                }
207            }
208        }
209    }
210
211    /// Get the auth storage
212    pub fn auth_storage(&self) -> &Arc<dyn AuthStorage> {
213        &self.storage
214    }
215
216    /// Return list of providers that have authentication configured
217    pub async fn available_providers(&self) -> Result<Vec<ProviderKind>> {
218        let mut providers = Vec::new();
219        if self
220            .get_auth_for_provider(ProviderKind::Anthropic)
221            .await?
222            .is_some()
223        {
224            providers.push(ProviderKind::Anthropic);
225        }
226        if self
227            .get_auth_for_provider(ProviderKind::OpenAI)
228            .await?
229            .is_some()
230        {
231            providers.push(ProviderKind::OpenAI);
232        }
233        if self
234            .get_auth_for_provider(ProviderKind::Google)
235            .await?
236            .is_some()
237        {
238            providers.push(ProviderKind::Google);
239        }
240        if self
241            .get_auth_for_provider(ProviderKind::XAI)
242            .await?
243            .is_some()
244        {
245            providers.push(ProviderKind::XAI);
246        }
247        Ok(providers)
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::InMemoryAuthStorage;

    /// Environment variables that supply an Anthropic key ahead of stored credentials.
    const ANTHROPIC_ENV_VARS: &[&str] = &["CLAUDE_API_KEY", "ANTHROPIC_API_KEY"];

    /// Every API-key environment variable consulted by `get_auth_for_provider`.
    const ALL_PROVIDER_ENV_VARS: &[&str] = &[
        "CLAUDE_API_KEY",
        "ANTHROPIC_API_KEY",
        "OPENAI_API_KEY",
        "GEMINI_API_KEY",
        "GOOGLE_API_KEY",
        "XAI_API_KEY",
        "GROK_API_KEY",
    ];

    /// These tests assert on credentials that come from storage only, but environment
    /// variables take precedence over storage. Returns `true` (and says why on stderr)
    /// when any of `names` is set in the developer's environment, so the test can bail
    /// out instead of failing spuriously.
    fn env_overrides_present(names: &[&str]) -> bool {
        let present: Vec<&str> = names
            .iter()
            .copied()
            .filter(|name| std::env::var_os(name).is_some())
            .collect();
        if present.is_empty() {
            return false;
        }
        eprintln!(
            "skipping: provider API-key env vars are set ({})",
            present.join(", ")
        );
        true
    }

    #[tokio::test]
    async fn test_auth_changes_immediately_reflected() {
        if env_overrides_present(ANTHROPIC_ENV_VARS) {
            return;
        }

        // Create a provider with in-memory storage
        let storage = Arc::new(InMemoryAuthStorage::new());
        let provider = LlmConfigProvider::new(storage.clone());

        // Initially no auth
        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(auth.is_none());

        // Add API key
        storage
            .set_credential(
                "anthropic",
                crate::auth::Credential::ApiKey {
                    value: "test-key".to_string(),
                },
            )
            .await
            .unwrap();

        // Should immediately see the new key
        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(matches!(auth, Some(ApiAuth::Key(key)) if key == "test-key"));

        // Add OAuth tokens
        storage
            .set_credential(
                "anthropic",
                crate::auth::Credential::AuthTokens(crate::auth::storage::AuthTokens {
                    access_token: "access".to_string(),
                    refresh_token: "refresh".to_string(),
                    expires_at: std::time::SystemTime::now() + std::time::Duration::from_secs(3600),
                }),
            )
            .await
            .unwrap();

        // Should immediately prefer OAuth over a stored API key
        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(matches!(auth, Some(ApiAuth::OAuth)));

        // Remove OAuth tokens
        storage
            .remove_credential("anthropic", crate::auth::CredentialType::AuthTokens)
            .await
            .unwrap();

        // Should immediately fall back to the stored API key
        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(matches!(auth, Some(ApiAuth::Key(key)) if key == "test-key"));
    }

    #[tokio::test]
    async fn test_available_providers_updates_immediately() {
        if env_overrides_present(ALL_PROVIDER_ENV_VARS) {
            return;
        }

        let storage = Arc::new(InMemoryAuthStorage::new());
        let provider = LlmConfigProvider::new(storage.clone());

        // Initially no providers
        let providers = provider.available_providers().await.unwrap();
        assert!(providers.is_empty());

        // Add Anthropic API key
        storage
            .set_credential(
                "anthropic",
                crate::auth::Credential::ApiKey {
                    value: "test-key".to_string(),
                },
            )
            .await
            .unwrap();

        // Should immediately show Anthropic
        let providers = provider.available_providers().await.unwrap();
        assert_eq!(providers, vec![ProviderKind::Anthropic]);

        // Add OpenAI key
        storage
            .set_credential(
                "openai",
                crate::auth::Credential::ApiKey {
                    value: "openai-key".to_string(),
                },
            )
            .await
            .unwrap();

        // Should immediately show both, in the fixed provider order
        let providers = provider.available_providers().await.unwrap();
        assert_eq!(
            providers,
            vec![ProviderKind::Anthropic, ProviderKind::OpenAI]
        );
    }
}