// steer_core/config/mod.rs

1use crate::auth::AuthStorage;
2use crate::config::provider::ProviderId;
3use crate::error::{Error, Result};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::PathBuf;
7use std::sync::Arc;
8
9#[derive(Debug, Serialize, Deserialize, Default)]
10pub struct Config {
11    pub model: Option<String>,
12    pub history_size: Option<usize>,
13    pub system_prompt: Option<String>,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub notifications: Option<NotificationSettings>,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct NotificationSettings {
20    pub enable_sound: Option<bool>,
21    pub enable_desktop: Option<bool>,
22}
23
24impl Default for NotificationSettings {
25    fn default() -> Self {
26        Self {
27            enable_sound: Some(true),
28            enable_desktop: Some(true),
29        }
30    }
31}
32
33impl Config {
34    fn new() -> Self {
35        Self {
36            model: Some(crate::config::model::builtin::opus().1),
37            history_size: Some(10),
38            system_prompt: None,
39            notifications: Some(NotificationSettings::default()),
40        }
41    }
42}
43
44/// Get the path to the config file
45pub fn get_config_path() -> Result<PathBuf> {
46    let config_dir = dirs::config_dir()
47        .ok_or_else(|| Error::Configuration("Could not find config directory".to_string()))?
48        .join("steer");
49
50    fs::create_dir_all(&config_dir)
51        .map_err(|e| Error::Configuration(format!("Failed to create config directory: {e}")))?;
52
53    Ok(config_dir.join("config.json"))
54}
55
56/// Load the configuration
57pub fn load_config() -> Result<Config> {
58    let config_path = get_config_path()?;
59
60    if !config_path.exists() {
61        return Ok(Config::new());
62    }
63
64    let config_str = fs::read_to_string(&config_path)
65        .map_err(|e| Error::Configuration(format!("Failed to read config file: {e}")))?;
66
67    let config: Config = serde_json::from_str(&config_str)
68        .map_err(|e| Error::Configuration(format!("Failed to parse config file: {e}")))?;
69
70    Ok(config)
71}
72
73/// Initialize or update the configuration
74pub fn init_config(force: bool) -> Result<()> {
75    let config_path = get_config_path()?;
76
77    if config_path.exists() && !force {
78        return Err(Error::Configuration(
79            "Config file already exists. Use --force to overwrite.".to_string(),
80        ));
81    }
82
83    let config = Config::new();
84    let config_json = serde_json::to_string_pretty(&config)
85        .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
86
87    fs::write(&config_path, config_json)
88        .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
89
90    Ok(())
91}
92
93/// Save the configuration
94pub fn save_config(config: &Config) -> Result<()> {
95    let config_path = get_config_path()?;
96    let config_json = serde_json::to_string_pretty(&config)
97        .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
98
99    fs::write(&config_path, config_json)
100        .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
101
102    Ok(())
103}
104
/// How requests to an LLM provider should be authenticated.
#[derive(Clone, Debug)]
pub enum ApiAuth {
    /// A raw API key, taken from an environment variable or the credential store.
    Key(String),
    /// Use the OAuth2 credential held in auth storage.
    OAuth,
}
110
111/// Provider for authentication credentials
112#[derive(Clone)]
113pub struct LlmConfigProvider {
114    storage: Arc<dyn AuthStorage>,
115}
116
117impl std::fmt::Debug for LlmConfigProvider {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("LlmConfigProvider").finish_non_exhaustive()
120    }
121}
122
123impl LlmConfigProvider {
124    /// Create a new LlmConfigProvider with the given auth storage
125    pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
126        Self { storage }
127    }
128
129    /// Get authentication for a specific provider ID
130    pub async fn get_auth_for_provider(&self, provider_id: &ProviderId) -> Result<Option<ApiAuth>> {
131        if provider_id.as_str() == self::provider::ANTHROPIC_ID {
132            // API key via env var > OAuth > stored API key
133            let anthropic_key = std::env::var("CLAUDE_API_KEY")
134                .ok()
135                .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok());
136            if let Some(key) = anthropic_key {
137                Ok(Some(ApiAuth::Key(key)))
138            } else if self
139                .storage
140                .get_credential(
141                    &provider_id.storage_key(),
142                    crate::auth::CredentialType::OAuth2,
143                )
144                .await?
145                .is_some()
146            {
147                Ok(Some(ApiAuth::OAuth))
148            } else {
149                // Fall back to stored API key in keyring
150                if let Some(crate::auth::Credential::ApiKey { value }) = self
151                    .storage
152                    .get_credential(
153                        &provider_id.storage_key(),
154                        crate::auth::CredentialType::ApiKey,
155                    )
156                    .await?
157                {
158                    Ok(Some(ApiAuth::Key(value)))
159                } else {
160                    Ok(None)
161                }
162            }
163        } else if provider_id.as_str() == self::provider::OPENAI_ID {
164            // API key via env var > stored API key
165            if let Ok(key) = std::env::var("OPENAI_API_KEY") {
166                Ok(Some(ApiAuth::Key(key)))
167            } else if let Some(crate::auth::Credential::ApiKey { value }) = self
168                .storage
169                .get_credential(
170                    &provider_id.storage_key(),
171                    crate::auth::CredentialType::ApiKey,
172                )
173                .await?
174            {
175                Ok(Some(ApiAuth::Key(value)))
176            } else {
177                Ok(None)
178            }
179        } else if provider_id.as_str() == self::provider::GOOGLE_ID {
180            // API key via env var > stored API key
181            if let Ok(key) =
182                std::env::var("GEMINI_API_KEY").or_else(|_| std::env::var("GOOGLE_API_KEY"))
183            {
184                Ok(Some(ApiAuth::Key(key)))
185            } else if let Some(crate::auth::Credential::ApiKey { value }) = self
186                .storage
187                .get_credential(
188                    &provider_id.storage_key(),
189                    crate::auth::CredentialType::ApiKey,
190                )
191                .await?
192            {
193                Ok(Some(ApiAuth::Key(value)))
194            } else {
195                Ok(None)
196            }
197        } else if provider_id.as_str() == self::provider::XAI_ID {
198            // API key via env var > stored API key
199            if let Ok(key) = std::env::var("XAI_API_KEY").or_else(|_| std::env::var("GROK_API_KEY"))
200            {
201                Ok(Some(ApiAuth::Key(key)))
202            } else if let Some(crate::auth::Credential::ApiKey { value }) = self
203                .storage
204                .get_credential(
205                    &provider_id.storage_key(),
206                    crate::auth::CredentialType::ApiKey,
207                )
208                .await?
209            {
210                Ok(Some(ApiAuth::Key(value)))
211            } else {
212                Ok(None)
213            }
214        } else {
215            // Custom providers - check for stored API key
216            if let Some(crate::auth::Credential::ApiKey { value }) = self
217                .storage
218                .get_credential(
219                    &provider_id.storage_key(),
220                    crate::auth::CredentialType::ApiKey,
221                )
222                .await?
223            {
224                Ok(Some(ApiAuth::Key(value)))
225            } else {
226                Ok(None)
227            }
228        }
229    }
230
231    /// Get the auth storage
232    pub fn auth_storage(&self) -> &Arc<dyn AuthStorage> {
233        &self.storage
234    }
235}
236
237pub mod model;
238pub mod provider;
239pub mod toml_types;