Skip to main content

steer_core/config/
mod.rs

1use crate::auth::{
2    ApiKeyOrigin, AuthDirective, AuthMethod, AuthPluginRegistry, AuthSource, AuthStorage,
3    Credential,
4};
5use crate::config::provider::ProviderId;
6use crate::error::{Error, Result};
7use serde::{Deserialize, Serialize};
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12#[derive(Debug, Serialize, Deserialize, Default)]
13pub struct Config {
14    pub model: Option<String>,
15    pub history_size: Option<usize>,
16    pub system_prompt: Option<String>,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub notifications: Option<NotificationSettings>,
19}
20
21#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct NotificationSettings {
23    pub enable_sound: Option<bool>,
24}
25
26impl Default for NotificationSettings {
27    fn default() -> Self {
28        Self {
29            enable_sound: Some(true),
30        }
31    }
32}
33
34impl Config {
35    fn new() -> Self {
36        Self {
37            model: Some(crate::config::model::builtin::default_model().id),
38            history_size: Some(10),
39            system_prompt: None,
40            notifications: Some(NotificationSettings::default()),
41        }
42    }
43}
44
45/// Get the path to the config file
46pub fn get_config_path() -> Result<PathBuf> {
47    let config_dir = dirs::config_dir()
48        .ok_or_else(|| Error::Configuration("Could not find config directory".to_string()))?
49        .join("steer");
50
51    fs::create_dir_all(&config_dir)
52        .map_err(|e| Error::Configuration(format!("Failed to create config directory: {e}")))?;
53
54    Ok(config_dir.join("config.json"))
55}
56
57/// Load the configuration
58pub fn load_config() -> Result<Config> {
59    let config_path = get_config_path()?;
60
61    if !config_path.exists() {
62        return Ok(Config::new());
63    }
64
65    let config_str = fs::read_to_string(&config_path)
66        .map_err(|e| Error::Configuration(format!("Failed to read config file: {e}")))?;
67
68    let config: Config = serde_json::from_str(&config_str)
69        .map_err(|e| Error::Configuration(format!("Failed to parse config file: {e}")))?;
70
71    Ok(config)
72}
73
74/// Initialize or update the configuration
75pub fn init_config(force: bool) -> Result<()> {
76    let config_path = get_config_path()?;
77
78    if config_path.exists() && !force {
79        return Err(Error::Configuration(
80            "Config file already exists. Use --force to overwrite.".to_string(),
81        ));
82    }
83
84    let config = Config::new();
85    let config_json = serde_json::to_string_pretty(&config)
86        .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
87
88    fs::write(&config_path, config_json)
89        .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
90
91    Ok(())
92}
93
94/// Save the configuration
95pub fn save_config(config: &Config) -> Result<()> {
96    let config_path = get_config_path()?;
97    let config_json = serde_json::to_string_pretty(&config)
98        .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
99
100    fs::write(&config_path, config_json)
101        .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
102
103    Ok(())
104}
105
/// Simplified authentication result returned by the legacy
/// `LlmConfigProvider::get_auth_for_provider` API.
#[derive(Debug, Clone)]
pub enum ApiAuth {
    /// Authenticate with the contained API key.
    Key(String),
    /// Authenticate via OAuth; no key material is carried in this variant.
    OAuth,
}
111
112#[derive(Debug, Clone)]
113pub enum ResolvedAuth {
114    Plugin {
115        directive: AuthDirective,
116        source: AuthSource,
117    },
118    ApiKey {
119        credential: Credential,
120        source: AuthSource,
121    },
122    None,
123}
124
125impl ResolvedAuth {
126    pub fn source(&self) -> AuthSource {
127        match self {
128            ResolvedAuth::Plugin { source, .. } => source.clone(),
129            ResolvedAuth::ApiKey { source, .. } => source.clone(),
130            ResolvedAuth::None => AuthSource::None,
131        }
132    }
133
134    pub fn directive(&self) -> Option<&AuthDirective> {
135        match self {
136            ResolvedAuth::Plugin { directive, .. } => Some(directive),
137            _ => None,
138        }
139    }
140
141    pub fn credential(&self) -> Option<&Credential> {
142        match self {
143            ResolvedAuth::ApiKey { credential, .. } => Some(credential),
144            _ => None,
145        }
146    }
147}
148
149/// Provider for authentication credentials
150#[derive(Clone)]
151pub struct LlmConfigProvider {
152    storage: Arc<dyn AuthStorage>,
153    env_provider: Arc<dyn EnvProvider>,
154    plugin_registry: Arc<AuthPluginRegistry>,
155}
156
157impl std::fmt::Debug for LlmConfigProvider {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        f.debug_struct("LlmConfigProvider").finish_non_exhaustive()
160    }
161}
162
163impl LlmConfigProvider {
164    /// Create a new LlmConfigProvider with the given auth storage and default plugins.
165    pub fn new(storage: Arc<dyn AuthStorage>) -> Result<Self> {
166        let plugin_registry = Arc::new(AuthPluginRegistry::with_defaults()?);
167        Ok(Self::new_with_plugins(storage, plugin_registry))
168    }
169
170    /// Create a new LlmConfigProvider with an explicit plugin registry.
171    pub fn new_with_plugins(
172        storage: Arc<dyn AuthStorage>,
173        plugin_registry: Arc<AuthPluginRegistry>,
174    ) -> Self {
175        Self {
176            storage,
177            env_provider: Arc::new(StdEnvProvider),
178            plugin_registry,
179        }
180    }
181
182    /// Create a new LlmConfigProvider with a custom env provider (useful for tests).
183    #[cfg(test)]
184    fn with_env_provider(
185        storage: Arc<dyn AuthStorage>,
186        env_provider: Arc<dyn EnvProvider>,
187    ) -> Self {
188        let plugin_registry =
189            Arc::new(AuthPluginRegistry::with_defaults().expect("default plugins"));
190        Self {
191            storage,
192            env_provider,
193            plugin_registry,
194        }
195    }
196
197    /// Get authentication for a specific provider ID (legacy API).
198    pub async fn get_auth_for_provider(&self, provider_id: &ProviderId) -> Result<Option<ApiAuth>> {
199        let resolved = self.resolve_auth_for_provider(provider_id).await?;
200        match resolved {
201            ResolvedAuth::Plugin { .. } => Ok(Some(ApiAuth::OAuth)),
202            ResolvedAuth::ApiKey { credential, .. } => match credential {
203                Credential::ApiKey { value } => Ok(Some(ApiAuth::Key(value.clone()))),
204                Credential::OAuth2(_) => Ok(None),
205            },
206            ResolvedAuth::None => Ok(None),
207        }
208    }
209
210    /// Resolve authentication source for a provider, including API key origin.
211    pub async fn resolve_auth_source(&self, provider_id: &ProviderId) -> Result<AuthSource> {
212        Ok(self.resolve_auth_for_provider(provider_id).await?.source())
213    }
214
215    /// Resolve authentication for a provider using server-side auto-selection.
216    pub async fn resolve_auth_for_provider(
217        &self,
218        provider_id: &ProviderId,
219    ) -> Result<ResolvedAuth> {
220        if let Some(plugin) = self.plugin_registry.get(provider_id)
221            && let Some(directive) = plugin.resolve_auth(self.storage.clone()).await?
222        {
223            return Ok(ResolvedAuth::Plugin {
224                directive,
225                source: AuthSource::Plugin {
226                    method: AuthMethod::OAuth,
227                },
228            });
229        }
230
231        if let Some((key, origin)) = self.resolve_api_key_for_provider(provider_id).await? {
232            return Ok(ResolvedAuth::ApiKey {
233                credential: Credential::ApiKey { value: key },
234                source: AuthSource::ApiKey { origin },
235            });
236        }
237
238        Ok(ResolvedAuth::None)
239    }
240
241    pub async fn resolve_api_key_for_provider(
242        &self,
243        provider_id: &ProviderId,
244    ) -> Result<Option<(String, ApiKeyOrigin)>> {
245        if provider_id.as_str() == self::provider::ANTHROPIC_ID {
246            let anthropic_key = self
247                .env_provider
248                .var("CLAUDE_API_KEY")
249                .or_else(|| self.env_provider.var("ANTHROPIC_API_KEY"));
250            if let Some(key) = anthropic_key {
251                Ok(Some((key, ApiKeyOrigin::Env)))
252            } else if let Some(crate::auth::Credential::ApiKey { value }) = self
253                .storage
254                .get_credential(
255                    &provider_id.storage_key(),
256                    crate::auth::CredentialType::ApiKey,
257                )
258                .await?
259            {
260                Ok(Some((value, ApiKeyOrigin::Stored)))
261            } else {
262                Ok(None)
263            }
264        } else if provider_id.as_str() == self::provider::OPENAI_ID {
265            if let Some(key) = self.env_provider.var("OPENAI_API_KEY") {
266                Ok(Some((key, ApiKeyOrigin::Env)))
267            } else if let Some(crate::auth::Credential::ApiKey { value }) = self
268                .storage
269                .get_credential(
270                    &provider_id.storage_key(),
271                    crate::auth::CredentialType::ApiKey,
272                )
273                .await?
274            {
275                Ok(Some((value, ApiKeyOrigin::Stored)))
276            } else {
277                Ok(None)
278            }
279        } else if provider_id.as_str() == self::provider::GOOGLE_ID {
280            if let Some(key) = self
281                .env_provider
282                .var("GEMINI_API_KEY")
283                .or_else(|| self.env_provider.var("GOOGLE_API_KEY"))
284            {
285                Ok(Some((key, ApiKeyOrigin::Env)))
286            } else if let Some(crate::auth::Credential::ApiKey { value }) = self
287                .storage
288                .get_credential(
289                    &provider_id.storage_key(),
290                    crate::auth::CredentialType::ApiKey,
291                )
292                .await?
293            {
294                Ok(Some((value, ApiKeyOrigin::Stored)))
295            } else {
296                Ok(None)
297            }
298        } else if provider_id.as_str() == self::provider::XAI_ID {
299            if let Some(key) = self
300                .env_provider
301                .var("XAI_API_KEY")
302                .or_else(|| self.env_provider.var("GROK_API_KEY"))
303            {
304                Ok(Some((key, ApiKeyOrigin::Env)))
305            } else if let Some(crate::auth::Credential::ApiKey { value }) = self
306                .storage
307                .get_credential(
308                    &provider_id.storage_key(),
309                    crate::auth::CredentialType::ApiKey,
310                )
311                .await?
312            {
313                Ok(Some((value, ApiKeyOrigin::Stored)))
314            } else {
315                Ok(None)
316            }
317        } else if let Some(crate::auth::Credential::ApiKey { value }) = self
318            .storage
319            .get_credential(
320                &provider_id.storage_key(),
321                crate::auth::CredentialType::ApiKey,
322            )
323            .await?
324        {
325            Ok(Some((value, ApiKeyOrigin::Stored)))
326        } else {
327            Ok(None)
328        }
329    }
330
331    /// Get the auth storage
332    pub fn auth_storage(&self) -> &Arc<dyn AuthStorage> {
333        &self.storage
334    }
335
336    pub fn plugin_registry(&self) -> &Arc<AuthPluginRegistry> {
337        &self.plugin_registry
338    }
339}
340
341pub mod model;
342pub mod provider;
343pub mod toml_types;
344
/// Source of environment variables, abstracted so tests can substitute a
/// fixed set of values for the real process environment.
trait EnvProvider: Send + Sync {
    /// Return the value of `key`, or `None` when it is not available.
    fn var(&self, key: &str) -> Option<String>;
}
348
349#[derive(Clone)]
350struct StdEnvProvider;
351
352impl EnvProvider for StdEnvProvider {
353    fn var(&self, key: &str) -> Option<String> {
354        std::env::var(key).ok()
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::AuthTokens;
    use crate::test_utils::InMemoryAuthStorage;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::{Duration, SystemTime};

    /// Map-backed `EnvProvider` so the tests never read the real process
    /// environment.
    #[derive(Clone, Default)]
    struct TestEnvProvider {
        vars: HashMap<String, String>,
    }

    impl EnvProvider for TestEnvProvider {
        fn var(&self, key: &str) -> Option<String> {
            let found = self.vars.get(key)?;
            Some(found.clone())
        }
    }

    /// Environment fixture that only defines `OPENAI_API_KEY=env-key`.
    fn openai_env() -> TestEnvProvider {
        let vars = HashMap::from([("OPENAI_API_KEY".to_string(), "env-key".to_string())]);
        TestEnvProvider { vars }
    }

    /// The stored API-key credential used by both tests.
    fn stored_openai_key() -> crate::auth::Credential {
        crate::auth::Credential::ApiKey {
            value: "stored-key".to_string(),
        }
    }

    /// With a stored key, stored OAuth tokens, and an env key all present,
    /// the legacy API must report OAuth.
    #[tokio::test]
    async fn openai_oauth_takes_precedence() {
        let storage = Arc::new(InMemoryAuthStorage::new());
        storage
            .set_credential("openai", stored_openai_key())
            .await
            .unwrap();

        let tokens = AuthTokens {
            access_token: "token".to_string(),
            refresh_token: "refresh".to_string(),
            expires_at: SystemTime::now() + Duration::from_secs(3600),
            id_token: Some("id-token".to_string()),
        };
        storage
            .set_credential("openai", crate::auth::Credential::OAuth2(tokens))
            .await
            .unwrap();

        let config_provider =
            LlmConfigProvider::with_env_provider(storage, Arc::new(openai_env()));
        let auth = config_provider
            .get_auth_for_provider(&provider::openai())
            .await
            .unwrap();

        assert!(matches!(auth, Some(ApiAuth::OAuth)));
    }

    /// With only a stored key and an env key, the env key must win.
    #[tokio::test]
    async fn openai_env_takes_precedence_over_stored_key() {
        let storage = Arc::new(InMemoryAuthStorage::new());
        storage
            .set_credential("openai", stored_openai_key())
            .await
            .unwrap();

        let config_provider =
            LlmConfigProvider::with_env_provider(storage, Arc::new(openai_env()));
        let auth = config_provider
            .get_auth_for_provider(&provider::openai())
            .await
            .unwrap();

        let Some(ApiAuth::Key(key)) = auth else {
            panic!("Expected env API key");
        };
        assert_eq!(key, "env-key");
    }
}
442}