1use crate::auth::{
2 ApiKeyOrigin, AuthDirective, AuthMethod, AuthPluginRegistry, AuthSource, AuthStorage,
3 Credential,
4};
5use crate::config::provider::ProviderId;
6use crate::error::{Error, Result};
7use serde::{Deserialize, Serialize};
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11
12#[derive(Debug, Serialize, Deserialize, Default)]
13pub struct Config {
14 pub model: Option<String>,
15 pub history_size: Option<usize>,
16 pub system_prompt: Option<String>,
17 #[serde(skip_serializing_if = "Option::is_none")]
18 pub notifications: Option<NotificationSettings>,
19}
20
21#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct NotificationSettings {
23 pub enable_sound: Option<bool>,
24}
25
26impl Default for NotificationSettings {
27 fn default() -> Self {
28 Self {
29 enable_sound: Some(true),
30 }
31 }
32}
33
34impl Config {
35 fn new() -> Self {
36 Self {
37 model: Some(crate::config::model::builtin::default_model().id),
38 history_size: Some(10),
39 system_prompt: None,
40 notifications: Some(NotificationSettings::default()),
41 }
42 }
43}
44
45pub fn get_config_path() -> Result<PathBuf> {
47 let config_dir = dirs::config_dir()
48 .ok_or_else(|| Error::Configuration("Could not find config directory".to_string()))?
49 .join("steer");
50
51 fs::create_dir_all(&config_dir)
52 .map_err(|e| Error::Configuration(format!("Failed to create config directory: {e}")))?;
53
54 Ok(config_dir.join("config.json"))
55}
56
57pub fn load_config() -> Result<Config> {
59 let config_path = get_config_path()?;
60
61 if !config_path.exists() {
62 return Ok(Config::new());
63 }
64
65 let config_str = fs::read_to_string(&config_path)
66 .map_err(|e| Error::Configuration(format!("Failed to read config file: {e}")))?;
67
68 let config: Config = serde_json::from_str(&config_str)
69 .map_err(|e| Error::Configuration(format!("Failed to parse config file: {e}")))?;
70
71 Ok(config)
72}
73
74pub fn init_config(force: bool) -> Result<()> {
76 let config_path = get_config_path()?;
77
78 if config_path.exists() && !force {
79 return Err(Error::Configuration(
80 "Config file already exists. Use --force to overwrite.".to_string(),
81 ));
82 }
83
84 let config = Config::new();
85 let config_json = serde_json::to_string_pretty(&config)
86 .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
87
88 fs::write(&config_path, config_json)
89 .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
90
91 Ok(())
92}
93
94pub fn save_config(config: &Config) -> Result<()> {
96 let config_path = get_config_path()?;
97 let config_json = serde_json::to_string_pretty(&config)
98 .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
99
100 fs::write(&config_path, config_json)
101 .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
102
103 Ok(())
104}
105
/// Simplified answer to "how is this provider authenticated?".
///
/// `LlmConfigProvider::get_auth_for_provider` produces this from a [`ResolvedAuth`].
#[derive(Clone, Debug)]
pub enum ApiAuth {
    /// A raw API key value.
    Key(String),
    /// Auth resolved by an auth plugin; no key is carried here.
    OAuth,
}
111
112#[derive(Debug, Clone)]
113pub enum ResolvedAuth {
114 Plugin {
115 directive: AuthDirective,
116 source: AuthSource,
117 },
118 ApiKey {
119 credential: Credential,
120 source: AuthSource,
121 },
122 None,
123}
124
125impl ResolvedAuth {
126 pub fn source(&self) -> AuthSource {
127 match self {
128 ResolvedAuth::Plugin { source, .. } => source.clone(),
129 ResolvedAuth::ApiKey { source, .. } => source.clone(),
130 ResolvedAuth::None => AuthSource::None,
131 }
132 }
133
134 pub fn directive(&self) -> Option<&AuthDirective> {
135 match self {
136 ResolvedAuth::Plugin { directive, .. } => Some(directive),
137 _ => None,
138 }
139 }
140
141 pub fn credential(&self) -> Option<&Credential> {
142 match self {
143 ResolvedAuth::ApiKey { credential, .. } => Some(credential),
144 _ => None,
145 }
146 }
147}
148
149#[derive(Clone)]
151pub struct LlmConfigProvider {
152 storage: Arc<dyn AuthStorage>,
153 env_provider: Arc<dyn EnvProvider>,
154 plugin_registry: Arc<AuthPluginRegistry>,
155}
156
157impl std::fmt::Debug for LlmConfigProvider {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("LlmConfigProvider").finish_non_exhaustive()
160 }
161}
162
163impl LlmConfigProvider {
164 pub fn new(storage: Arc<dyn AuthStorage>) -> Result<Self> {
166 let plugin_registry = Arc::new(AuthPluginRegistry::with_defaults()?);
167 Ok(Self::new_with_plugins(storage, plugin_registry))
168 }
169
170 pub fn new_with_plugins(
172 storage: Arc<dyn AuthStorage>,
173 plugin_registry: Arc<AuthPluginRegistry>,
174 ) -> Self {
175 Self {
176 storage,
177 env_provider: Arc::new(StdEnvProvider),
178 plugin_registry,
179 }
180 }
181
182 #[cfg(test)]
184 fn with_env_provider(
185 storage: Arc<dyn AuthStorage>,
186 env_provider: Arc<dyn EnvProvider>,
187 ) -> Self {
188 let plugin_registry =
189 Arc::new(AuthPluginRegistry::with_defaults().expect("default plugins"));
190 Self {
191 storage,
192 env_provider,
193 plugin_registry,
194 }
195 }
196
197 pub async fn get_auth_for_provider(&self, provider_id: &ProviderId) -> Result<Option<ApiAuth>> {
199 let resolved = self.resolve_auth_for_provider(provider_id).await?;
200 match resolved {
201 ResolvedAuth::Plugin { .. } => Ok(Some(ApiAuth::OAuth)),
202 ResolvedAuth::ApiKey { credential, .. } => match credential {
203 Credential::ApiKey { value } => Ok(Some(ApiAuth::Key(value.clone()))),
204 Credential::OAuth2(_) => Ok(None),
205 },
206 ResolvedAuth::None => Ok(None),
207 }
208 }
209
210 pub async fn resolve_auth_source(&self, provider_id: &ProviderId) -> Result<AuthSource> {
212 Ok(self.resolve_auth_for_provider(provider_id).await?.source())
213 }
214
215 pub async fn resolve_auth_for_provider(
217 &self,
218 provider_id: &ProviderId,
219 ) -> Result<ResolvedAuth> {
220 if let Some(plugin) = self.plugin_registry.get(provider_id)
221 && let Some(directive) = plugin.resolve_auth(self.storage.clone()).await?
222 {
223 return Ok(ResolvedAuth::Plugin {
224 directive,
225 source: AuthSource::Plugin {
226 method: AuthMethod::OAuth,
227 },
228 });
229 }
230
231 if let Some((key, origin)) = self.resolve_api_key_for_provider(provider_id).await? {
232 return Ok(ResolvedAuth::ApiKey {
233 credential: Credential::ApiKey { value: key },
234 source: AuthSource::ApiKey { origin },
235 });
236 }
237
238 Ok(ResolvedAuth::None)
239 }
240
241 pub async fn resolve_api_key_for_provider(
242 &self,
243 provider_id: &ProviderId,
244 ) -> Result<Option<(String, ApiKeyOrigin)>> {
245 if provider_id.as_str() == self::provider::ANTHROPIC_ID {
246 let anthropic_key = self
247 .env_provider
248 .var("CLAUDE_API_KEY")
249 .or_else(|| self.env_provider.var("ANTHROPIC_API_KEY"));
250 if let Some(key) = anthropic_key {
251 Ok(Some((key, ApiKeyOrigin::Env)))
252 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
253 .storage
254 .get_credential(
255 &provider_id.storage_key(),
256 crate::auth::CredentialType::ApiKey,
257 )
258 .await?
259 {
260 Ok(Some((value, ApiKeyOrigin::Stored)))
261 } else {
262 Ok(None)
263 }
264 } else if provider_id.as_str() == self::provider::OPENAI_ID {
265 if let Some(key) = self.env_provider.var("OPENAI_API_KEY") {
266 Ok(Some((key, ApiKeyOrigin::Env)))
267 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
268 .storage
269 .get_credential(
270 &provider_id.storage_key(),
271 crate::auth::CredentialType::ApiKey,
272 )
273 .await?
274 {
275 Ok(Some((value, ApiKeyOrigin::Stored)))
276 } else {
277 Ok(None)
278 }
279 } else if provider_id.as_str() == self::provider::GOOGLE_ID {
280 if let Some(key) = self
281 .env_provider
282 .var("GEMINI_API_KEY")
283 .or_else(|| self.env_provider.var("GOOGLE_API_KEY"))
284 {
285 Ok(Some((key, ApiKeyOrigin::Env)))
286 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
287 .storage
288 .get_credential(
289 &provider_id.storage_key(),
290 crate::auth::CredentialType::ApiKey,
291 )
292 .await?
293 {
294 Ok(Some((value, ApiKeyOrigin::Stored)))
295 } else {
296 Ok(None)
297 }
298 } else if provider_id.as_str() == self::provider::XAI_ID {
299 if let Some(key) = self
300 .env_provider
301 .var("XAI_API_KEY")
302 .or_else(|| self.env_provider.var("GROK_API_KEY"))
303 {
304 Ok(Some((key, ApiKeyOrigin::Env)))
305 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
306 .storage
307 .get_credential(
308 &provider_id.storage_key(),
309 crate::auth::CredentialType::ApiKey,
310 )
311 .await?
312 {
313 Ok(Some((value, ApiKeyOrigin::Stored)))
314 } else {
315 Ok(None)
316 }
317 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
318 .storage
319 .get_credential(
320 &provider_id.storage_key(),
321 crate::auth::CredentialType::ApiKey,
322 )
323 .await?
324 {
325 Ok(Some((value, ApiKeyOrigin::Stored)))
326 } else {
327 Ok(None)
328 }
329 }
330
331 pub fn auth_storage(&self) -> &Arc<dyn AuthStorage> {
333 &self.storage
334 }
335
336 pub fn plugin_registry(&self) -> &Arc<AuthPluginRegistry> {
337 &self.plugin_registry
338 }
339}
340
341pub mod model;
342pub mod provider;
343pub mod toml_types;
344
/// Source of environment variables, abstracted so tests can inject values.
trait EnvProvider: Send + Sync {
    /// Returns the value of `key`, or `None` if it is not available.
    fn var(&self, key: &str) -> Option<String>;
}

/// [`EnvProvider`] backed by the real process environment.
#[derive(Clone)]
struct StdEnvProvider;

impl EnvProvider for StdEnvProvider {
    fn var(&self, key: &str) -> Option<String> {
        match std::env::var(key) {
            Ok(value) => Some(value),
            // Both "not set" and "not valid Unicode" are reported as absent.
            Err(_) => None,
        }
    }
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::AuthTokens;
    use crate::test_utils::InMemoryAuthStorage;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::{Duration, SystemTime};

    /// In-memory stand-in for the process environment.
    #[derive(Clone, Default)]
    struct TestEnvProvider {
        vars: HashMap<String, String>,
    }

    impl TestEnvProvider {
        /// Builds an environment containing exactly the given pairs.
        fn with_vars(pairs: &[(&str, &str)]) -> Self {
            let vars = pairs
                .iter()
                .map(|(name, value)| (name.to_string(), value.to_string()))
                .collect();
            Self { vars }
        }
    }

    impl EnvProvider for TestEnvProvider {
        fn var(&self, key: &str) -> Option<String> {
            self.vars.get(key).cloned()
        }
    }

    /// Storage pre-seeded with an `openai` API key of `"stored-key"`.
    async fn storage_with_openai_key() -> Arc<InMemoryAuthStorage> {
        let storage = Arc::new(InMemoryAuthStorage::new());
        storage
            .set_credential(
                "openai",
                crate::auth::Credential::ApiKey {
                    value: "stored-key".to_string(),
                },
            )
            .await
            .unwrap();
        storage
    }

    #[tokio::test]
    async fn openai_oauth_takes_precedence() {
        let storage = storage_with_openai_key().await;
        let tokens = AuthTokens {
            access_token: "token".to_string(),
            refresh_token: "refresh".to_string(),
            expires_at: SystemTime::now() + Duration::from_secs(3600),
            id_token: Some("id-token".to_string()),
        };
        storage
            .set_credential("openai", crate::auth::Credential::OAuth2(tokens))
            .await
            .unwrap();

        let env = TestEnvProvider::with_vars(&[("OPENAI_API_KEY", "env-key")]);
        let config_provider = LlmConfigProvider::with_env_provider(storage, Arc::new(env));
        let auth = config_provider
            .get_auth_for_provider(&provider::openai())
            .await
            .unwrap();

        assert!(matches!(auth, Some(ApiAuth::OAuth)));
    }

    #[tokio::test]
    async fn openai_env_takes_precedence_over_stored_key() {
        let storage = storage_with_openai_key().await;
        let env = TestEnvProvider::with_vars(&[("OPENAI_API_KEY", "env-key")]);
        let config_provider = LlmConfigProvider::with_env_provider(storage, Arc::new(env));

        let auth = config_provider
            .get_auth_for_provider(&provider::openai())
            .await
            .unwrap();

        let Some(ApiAuth::Key(key)) = auth else {
            panic!("Expected env API key");
        };
        assert_eq!(key, "env-key");
    }
}