1use crate::api::ProviderKind;
2use crate::auth::AuthStorage;
3use crate::error::{Error, Result};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::PathBuf;
7use std::sync::Arc;
8
9#[derive(Debug, Serialize, Deserialize, Default)]
10pub struct Config {
11 pub model: Option<String>,
12 pub history_size: Option<usize>,
13 pub system_prompt: Option<String>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub notifications: Option<NotificationSettings>,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct NotificationSettings {
20 pub enable_sound: Option<bool>,
21 pub enable_desktop: Option<bool>,
22}
23
24impl Default for NotificationSettings {
25 fn default() -> Self {
26 Self {
27 enable_sound: Some(true),
28 enable_desktop: Some(true),
29 }
30 }
31}
32
33impl Config {
34 fn new() -> Self {
35 Self {
36 model: Some("claude-3-7-sonnet-20250219".to_string()),
37 history_size: Some(10),
38 system_prompt: None,
39 notifications: Some(NotificationSettings::default()),
40 }
41 }
42}
43
44pub fn get_config_path() -> Result<PathBuf> {
46 let config_dir = dirs::config_dir()
47 .ok_or_else(|| Error::Configuration("Could not find config directory".to_string()))?
48 .join("steer");
49
50 fs::create_dir_all(&config_dir)
51 .map_err(|e| Error::Configuration(format!("Failed to create config directory: {e}")))?;
52
53 Ok(config_dir.join("config.json"))
54}
55
56pub fn load_config() -> Result<Config> {
58 let config_path = get_config_path()?;
59
60 if !config_path.exists() {
61 return Ok(Config::new());
62 }
63
64 let config_str = fs::read_to_string(&config_path)
65 .map_err(|e| Error::Configuration(format!("Failed to read config file: {e}")))?;
66
67 let config: Config = serde_json::from_str(&config_str)
68 .map_err(|e| Error::Configuration(format!("Failed to parse config file: {e}")))?;
69
70 Ok(config)
71}
72
73pub fn init_config(force: bool) -> Result<()> {
75 let config_path = get_config_path()?;
76
77 if config_path.exists() && !force {
78 return Err(Error::Configuration(
79 "Config file already exists. Use --force to overwrite.".to_string(),
80 ));
81 }
82
83 let config = Config::new();
84 let config_json = serde_json::to_string_pretty(&config)
85 .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
86
87 fs::write(&config_path, config_json)
88 .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
89
90 Ok(())
91}
92
93pub fn save_config(config: &Config) -> Result<()> {
95 let config_path = get_config_path()?;
96 let config_json = serde_json::to_string_pretty(&config)
97 .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
98
99 fs::write(&config_path, config_json)
100 .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
101
102 Ok(())
103}
104
/// How requests to an LLM provider are authenticated.
#[derive(Clone)]
pub enum ApiAuth {
    /// A raw API key.
    Key(String),
    /// OAuth tokens held in the auth storage.
    OAuth,
}

/// Hand-written so that API keys never end up in debug output or logs; a derived
/// `Debug` would print the secret verbatim.
impl std::fmt::Debug for ApiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Key(_) => f
                .debug_tuple("Key")
                .field(&format_args!("<redacted>"))
                .finish(),
            Self::OAuth => f.write_str("OAuth"),
        }
    }
}
110
111#[derive(Clone)]
113pub struct LlmConfigProvider {
114 storage: Arc<dyn AuthStorage>,
115}
116
117impl std::fmt::Debug for LlmConfigProvider {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("LlmConfigProvider").finish_non_exhaustive()
120 }
121}
122
123impl LlmConfigProvider {
124 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
126 Self { storage }
127 }
128
129 pub async fn get_auth_for_provider(&self, provider: ProviderKind) -> Result<Option<ApiAuth>> {
131 match provider {
133 ProviderKind::Anthropic => {
134 let anthropic_key = std::env::var("CLAUDE_API_KEY")
136 .ok()
137 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok());
138 if let Some(key) = anthropic_key {
139 Ok(Some(ApiAuth::Key(key)))
140 } else if self
141 .storage
142 .get_credential("anthropic", crate::auth::CredentialType::AuthTokens)
143 .await?
144 .is_some()
145 {
146 Ok(Some(ApiAuth::OAuth))
147 } else {
148 {
149 if let Some(crate::auth::Credential::ApiKey { value }) = self
151 .storage
152 .get_credential("anthropic", crate::auth::CredentialType::ApiKey)
153 .await?
154 {
155 Ok(Some(ApiAuth::Key(value)))
156 } else {
157 Ok(None)
158 }
159 }
160 }
161 }
162 ProviderKind::OpenAI => {
163 if let Ok(key) = std::env::var("OPENAI_API_KEY") {
165 Ok(Some(ApiAuth::Key(key)))
166 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
167 .storage
168 .get_credential("openai", crate::auth::CredentialType::ApiKey)
169 .await?
170 {
171 Ok(Some(ApiAuth::Key(value)))
172 } else {
173 Ok(None)
174 }
175 }
176 ProviderKind::Google => {
177 if let Ok(key) =
179 std::env::var("GEMINI_API_KEY").or_else(|_| std::env::var("GOOGLE_API_KEY"))
180 {
181 Ok(Some(ApiAuth::Key(key)))
182 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
183 .storage
184 .get_credential("google", crate::auth::CredentialType::ApiKey)
185 .await?
186 {
187 Ok(Some(ApiAuth::Key(value)))
188 } else {
189 Ok(None)
190 }
191 }
192 ProviderKind::XAI => {
193 if let Ok(key) =
195 std::env::var("XAI_API_KEY").or_else(|_| std::env::var("GROK_API_KEY"))
196 {
197 Ok(Some(ApiAuth::Key(key)))
198 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
199 .storage
200 .get_credential("xai", crate::auth::CredentialType::ApiKey)
201 .await?
202 {
203 Ok(Some(ApiAuth::Key(value)))
204 } else {
205 Ok(None)
206 }
207 }
208 }
209 }
210
211 pub fn auth_storage(&self) -> &Arc<dyn AuthStorage> {
213 &self.storage
214 }
215
216 pub async fn available_providers(&self) -> Result<Vec<ProviderKind>> {
218 let mut providers = Vec::new();
219 if self
220 .get_auth_for_provider(ProviderKind::Anthropic)
221 .await?
222 .is_some()
223 {
224 providers.push(ProviderKind::Anthropic);
225 }
226 if self
227 .get_auth_for_provider(ProviderKind::OpenAI)
228 .await?
229 .is_some()
230 {
231 providers.push(ProviderKind::OpenAI);
232 }
233 if self
234 .get_auth_for_provider(ProviderKind::Google)
235 .await?
236 .is_some()
237 {
238 providers.push(ProviderKind::Google);
239 }
240 if self
241 .get_auth_for_provider(ProviderKind::XAI)
242 .await?
243 .is_some()
244 {
245 providers.push(ProviderKind::XAI);
246 }
247 Ok(providers)
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::InMemoryAuthStorage;

    /// Environment variables consulted for Anthropic before storage.
    const ANTHROPIC_ENV_VARS: &[&str] = &["CLAUDE_API_KEY", "ANTHROPIC_API_KEY"];

    /// Every environment variable consulted before storage, across all providers.
    const ALL_PROVIDER_ENV_VARS: &[&str] = &[
        "CLAUDE_API_KEY",
        "ANTHROPIC_API_KEY",
        "OPENAI_API_KEY",
        "GEMINI_API_KEY",
        "GOOGLE_API_KEY",
        "XAI_API_KEY",
        "GROK_API_KEY",
    ];

    /// Returns the first of `names` present in the process environment, if any.
    ///
    /// Environment keys take precedence over stored credentials, so these storage-only
    /// assertions cannot hold on a machine that exports a provider key; such runs are
    /// skipped rather than failing spuriously. The environment is deliberately not
    /// mutated, because tests run concurrently in the same process.
    fn first_set_env_var(names: &[&'static str]) -> Option<&'static str> {
        names
            .iter()
            .copied()
            .find(|name| std::env::var_os(name).is_some())
    }

    #[tokio::test]
    async fn test_auth_changes_immediately_reflected() {
        if let Some(var) = first_set_env_var(ANTHROPIC_ENV_VARS) {
            eprintln!("skipping test_auth_changes_immediately_reflected: {var} is set");
            return;
        }

        let storage = Arc::new(InMemoryAuthStorage::new());
        let provider = LlmConfigProvider::new(storage.clone());

        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(auth.is_none());

        storage
            .set_credential(
                "anthropic",
                crate::auth::Credential::ApiKey {
                    value: "test-key".to_string(),
                },
            )
            .await
            .unwrap();

        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(matches!(auth, Some(ApiAuth::Key(key)) if key == "test-key"));

        storage
            .set_credential(
                "anthropic",
                crate::auth::Credential::AuthTokens(crate::auth::storage::AuthTokens {
                    access_token: "access".to_string(),
                    refresh_token: "refresh".to_string(),
                    expires_at: std::time::SystemTime::now() + std::time::Duration::from_secs(3600),
                }),
            )
            .await
            .unwrap();

        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(matches!(auth, Some(ApiAuth::OAuth)));

        storage
            .remove_credential("anthropic", crate::auth::CredentialType::AuthTokens)
            .await
            .unwrap();

        let auth = provider
            .get_auth_for_provider(ProviderKind::Anthropic)
            .await
            .unwrap();
        assert!(matches!(auth, Some(ApiAuth::Key(key)) if key == "test-key"));
    }

    #[tokio::test]
    async fn test_available_providers_updates_immediately() {
        if let Some(var) = first_set_env_var(ALL_PROVIDER_ENV_VARS) {
            eprintln!("skipping test_available_providers_updates_immediately: {var} is set");
            return;
        }

        let storage = Arc::new(InMemoryAuthStorage::new());
        let provider = LlmConfigProvider::new(storage.clone());

        let providers = provider.available_providers().await.unwrap();
        assert!(providers.is_empty());

        storage
            .set_credential(
                "anthropic",
                crate::auth::Credential::ApiKey {
                    value: "test-key".to_string(),
                },
            )
            .await
            .unwrap();

        let providers = provider.available_providers().await.unwrap();
        assert_eq!(providers, vec![ProviderKind::Anthropic]);

        storage
            .set_credential(
                "openai",
                crate::auth::Credential::ApiKey {
                    value: "openai-key".to_string(),
                },
            )
            .await
            .unwrap();

        let providers = provider.available_providers().await.unwrap();
        assert_eq!(
            providers,
            vec![ProviderKind::Anthropic, ProviderKind::OpenAI]
        );
    }
}