1use crate::auth::AuthStorage;
2use crate::config::provider::ProviderId;
3use crate::error::{Error, Result};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::PathBuf;
7use std::sync::Arc;
8
9#[derive(Debug, Serialize, Deserialize, Default)]
10pub struct Config {
11 pub model: Option<String>,
12 pub history_size: Option<usize>,
13 pub system_prompt: Option<String>,
14 #[serde(skip_serializing_if = "Option::is_none")]
15 pub notifications: Option<NotificationSettings>,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct NotificationSettings {
20 pub enable_sound: Option<bool>,
21 pub enable_desktop: Option<bool>,
22}
23
24impl Default for NotificationSettings {
25 fn default() -> Self {
26 Self {
27 enable_sound: Some(true),
28 enable_desktop: Some(true),
29 }
30 }
31}
32
33impl Config {
34 fn new() -> Self {
35 Self {
36 model: Some(crate::config::model::builtin::opus().1),
37 history_size: Some(10),
38 system_prompt: None,
39 notifications: Some(NotificationSettings::default()),
40 }
41 }
42}
43
44pub fn get_config_path() -> Result<PathBuf> {
46 let config_dir = dirs::config_dir()
47 .ok_or_else(|| Error::Configuration("Could not find config directory".to_string()))?
48 .join("steer");
49
50 fs::create_dir_all(&config_dir)
51 .map_err(|e| Error::Configuration(format!("Failed to create config directory: {e}")))?;
52
53 Ok(config_dir.join("config.json"))
54}
55
56pub fn load_config() -> Result<Config> {
58 let config_path = get_config_path()?;
59
60 if !config_path.exists() {
61 return Ok(Config::new());
62 }
63
64 let config_str = fs::read_to_string(&config_path)
65 .map_err(|e| Error::Configuration(format!("Failed to read config file: {e}")))?;
66
67 let config: Config = serde_json::from_str(&config_str)
68 .map_err(|e| Error::Configuration(format!("Failed to parse config file: {e}")))?;
69
70 Ok(config)
71}
72
73pub fn init_config(force: bool) -> Result<()> {
75 let config_path = get_config_path()?;
76
77 if config_path.exists() && !force {
78 return Err(Error::Configuration(
79 "Config file already exists. Use --force to overwrite.".to_string(),
80 ));
81 }
82
83 let config = Config::new();
84 let config_json = serde_json::to_string_pretty(&config)
85 .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
86
87 fs::write(&config_path, config_json)
88 .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
89
90 Ok(())
91}
92
93pub fn save_config(config: &Config) -> Result<()> {
95 let config_path = get_config_path()?;
96 let config_json = serde_json::to_string_pretty(&config)
97 .map_err(|e| Error::Configuration(format!("Failed to serialize config: {e}")))?;
98
99 fs::write(&config_path, config_json)
100 .map_err(|e| Error::Configuration(format!("Failed to write config file: {e}")))?;
101
102 Ok(())
103}
104
/// How requests to a provider are authenticated.
#[derive(Clone)]
pub enum ApiAuth {
    /// Authenticate with the contained API key.
    Key(String),
    /// Authenticate via OAuth; the variant itself carries no token.
    OAuth,
}

/// `Debug` is written by hand so that an API key can never end up in logs or
/// panic messages through `{:?}` formatting: the key is replaced by a
/// placeholder while the variant name stays visible.
impl std::fmt::Debug for ApiAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `format_args!` keeps the placeholder unquoted in the output.
            Self::Key(_) => f
                .debug_tuple("Key")
                .field(&format_args!("<redacted>"))
                .finish(),
            Self::OAuth => f.write_str("OAuth"),
        }
    }
}
110
111#[derive(Clone)]
113pub struct LlmConfigProvider {
114 storage: Arc<dyn AuthStorage>,
115}
116
117impl std::fmt::Debug for LlmConfigProvider {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("LlmConfigProvider").finish_non_exhaustive()
120 }
121}
122
123impl LlmConfigProvider {
124 pub fn new(storage: Arc<dyn AuthStorage>) -> Self {
126 Self { storage }
127 }
128
129 pub async fn get_auth_for_provider(&self, provider_id: &ProviderId) -> Result<Option<ApiAuth>> {
131 if provider_id.as_str() == self::provider::ANTHROPIC_ID {
132 let anthropic_key = std::env::var("CLAUDE_API_KEY")
134 .ok()
135 .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok());
136 if let Some(key) = anthropic_key {
137 Ok(Some(ApiAuth::Key(key)))
138 } else if self
139 .storage
140 .get_credential(
141 &provider_id.storage_key(),
142 crate::auth::CredentialType::OAuth2,
143 )
144 .await?
145 .is_some()
146 {
147 Ok(Some(ApiAuth::OAuth))
148 } else {
149 if let Some(crate::auth::Credential::ApiKey { value }) = self
151 .storage
152 .get_credential(
153 &provider_id.storage_key(),
154 crate::auth::CredentialType::ApiKey,
155 )
156 .await?
157 {
158 Ok(Some(ApiAuth::Key(value)))
159 } else {
160 Ok(None)
161 }
162 }
163 } else if provider_id.as_str() == self::provider::OPENAI_ID {
164 if let Ok(key) = std::env::var("OPENAI_API_KEY") {
166 Ok(Some(ApiAuth::Key(key)))
167 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
168 .storage
169 .get_credential(
170 &provider_id.storage_key(),
171 crate::auth::CredentialType::ApiKey,
172 )
173 .await?
174 {
175 Ok(Some(ApiAuth::Key(value)))
176 } else {
177 Ok(None)
178 }
179 } else if provider_id.as_str() == self::provider::GOOGLE_ID {
180 if let Ok(key) =
182 std::env::var("GEMINI_API_KEY").or_else(|_| std::env::var("GOOGLE_API_KEY"))
183 {
184 Ok(Some(ApiAuth::Key(key)))
185 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
186 .storage
187 .get_credential(
188 &provider_id.storage_key(),
189 crate::auth::CredentialType::ApiKey,
190 )
191 .await?
192 {
193 Ok(Some(ApiAuth::Key(value)))
194 } else {
195 Ok(None)
196 }
197 } else if provider_id.as_str() == self::provider::XAI_ID {
198 if let Ok(key) = std::env::var("XAI_API_KEY").or_else(|_| std::env::var("GROK_API_KEY"))
200 {
201 Ok(Some(ApiAuth::Key(key)))
202 } else if let Some(crate::auth::Credential::ApiKey { value }) = self
203 .storage
204 .get_credential(
205 &provider_id.storage_key(),
206 crate::auth::CredentialType::ApiKey,
207 )
208 .await?
209 {
210 Ok(Some(ApiAuth::Key(value)))
211 } else {
212 Ok(None)
213 }
214 } else {
215 if let Some(crate::auth::Credential::ApiKey { value }) = self
217 .storage
218 .get_credential(
219 &provider_id.storage_key(),
220 crate::auth::CredentialType::ApiKey,
221 )
222 .await?
223 {
224 Ok(Some(ApiAuth::Key(value)))
225 } else {
226 Ok(None)
227 }
228 }
229 }
230
231 pub fn auth_storage(&self) -> &Arc<dyn AuthStorage> {
233 &self.storage
234 }
235}
236
237pub mod model;
238pub mod provider;
239pub mod toml_types;