trae_agent_rs_core/config/
loader.rs1use super::{ApiProvider, ApiProviderConfig, ConfigCache};
4use crate::error::{ConfigError, Result};
5use serde_json;
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use tokio::fs;
9use tracing::{debug, info, warn};
10
11pub struct ConfigLoader {
13 config_dir: PathBuf,
15
16 cache: ConfigCache,
18
19 cache_path: PathBuf,
21}
22
23impl ConfigLoader {
24 pub fn new<P: AsRef<Path>>(config_dir: P) -> Self {
26 let config_dir = config_dir.as_ref().to_path_buf();
27 let cache_path = ConfigCache::default_cache_path();
28
29 Self {
30 config_dir,
31 cache: ConfigCache::new(),
32 cache_path,
33 }
34 }
35
36 pub async fn init(&mut self) -> Result<()> {
38 self.cache = ConfigCache::load(&self.cache_path).await?;
39 debug!("Loaded configuration cache from: {}", self.cache_path.display());
40 Ok(())
41 }
42
43 pub async fn discover_configs(&self) -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
45 let mut configs = HashMap::new();
46
47 let json_configs = self.load_json_configs().await?;
49 configs.extend(json_configs);
50
51 let env_configs = self.load_env_configs(&configs).await?;
53 configs.extend(env_configs);
54
55 debug!("Discovered {} provider configurations", configs.len());
56 Ok(configs)
57 }
58
59 async fn load_json_configs(&self) -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
61 let mut configs = HashMap::new();
62
63 let providers = [
64 ApiProvider::OpenAI,
65 ApiProvider::Anthropic,
66 ApiProvider::Google,
67 ];
68
69 for provider in providers {
70 let config_path = self.config_dir.join(provider.config_filename());
71
72 if config_path.exists() {
73 match self.load_json_config(&config_path).await {
74 Ok(config) => {
75 info!("Loaded {} configuration from: {}", provider, config_path.display());
76 configs.insert(provider, config);
77 }
78 Err(e) => {
79 warn!("Failed to load {} configuration from {}: {}",
80 provider, config_path.display(), e);
81 }
82 }
83 }
84 }
85
86 Ok(configs)
87 }
88
89 async fn load_json_config(&self, path: &Path) -> Result<ApiProviderConfig> {
91 let content = fs::read_to_string(path).await?;
92 let config: ApiProviderConfig = serde_json::from_str(&content)
93 .map_err(|e| ConfigError::InvalidFormat)?;
94 Ok(config)
95 }
96
97 async fn load_env_configs(&self, existing_configs: &HashMap<ApiProvider, ApiProviderConfig>)
99 -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
100 let mut configs = HashMap::new();
101
102 let providers = [
103 ApiProvider::OpenAI,
104 ApiProvider::Anthropic,
105 ApiProvider::Google,
106 ];
107
108 for provider in providers {
109 if existing_configs.contains_key(&provider) {
111 continue;
112 }
113
114 if let Some(config) = self.load_env_config(&provider) {
115 info!("Loaded {} configuration from environment variables", provider);
116 configs.insert(provider, config);
117 }
118 }
119
120 Ok(configs)
121 }
122
123 fn load_env_config(&self, provider: &ApiProvider) -> Option<ApiProviderConfig> {
125 let prefix = provider.env_prefix();
126
127 let base_url = std::env::var(format!("{}_BASE_URL", prefix)).ok();
128 let api_key = std::env::var(format!("{}_API_KEY", prefix)).ok();
129 let model = std::env::var(format!("{}_MODEL", prefix)).ok();
130
131 if base_url.is_some() || api_key.is_some() || model.is_some() {
133 Some(ApiProviderConfig {
134 base_url,
135 api_key,
136 model,
137 extra: HashMap::new(),
138 })
139 } else {
140 None
141 }
142 }
143
144 pub async fn select_config(&mut self, configs: HashMap<ApiProvider, ApiProviderConfig>)
146 -> Result<(ApiProvider, ApiProviderConfig)> {
147
148 if configs.is_empty() {
149 return Err(ConfigError::NoConfigFound.into());
150 }
151
152 if configs.len() == 1 {
154 let (provider, config) = configs.into_iter().next().unwrap();
155 info!("Using single available configuration: {}", provider);
156 return Ok((provider, config));
157 }
158
159 if let Some(cached_provider) = self.cache.get_selected_provider() {
161 if !self.cache.is_expired() {
162 if let Ok(provider) = cached_provider.parse::<ApiProvider>() {
163 if let Some(config) = configs.get(&provider) {
164 info!("Using cached provider selection: {}", provider);
165 return Ok((provider, config.clone()));
166 }
167 }
168 }
169 }
170
171 self.prompt_user_selection(configs).await
173 }
174
175 async fn prompt_user_selection(&mut self, configs: HashMap<ApiProvider, ApiProviderConfig>)
177 -> Result<(ApiProvider, ApiProviderConfig)> {
178
179 println!("Multiple API provider configurations found:");
180
181 let providers: Vec<_> = configs.keys().collect();
182 for (i, provider) in providers.iter().enumerate() {
183 let config = configs.get(provider).unwrap();
184 let source = if config.api_key.is_some() { "configured" } else { "env vars" };
185 println!(" {}. {} ({})", i + 1, provider, source);
186 }
187
188 println!("Please select a provider (1-{}): ", providers.len());
189
190 let mut input = String::new();
192 std::io::stdin().read_line(&mut input)
193 .map_err(|_| ConfigError::InvalidValue {
194 field: "user_selection".to_string(),
195 value: "failed to read input".to_string(),
196 })?;
197
198 let selection: usize = input.trim().parse()
199 .map_err(|_| ConfigError::InvalidValue {
200 field: "user_selection".to_string(),
201 value: input.trim().to_string(),
202 })?;
203
204 if selection == 0 || selection > providers.len() {
205 return Err(ConfigError::InvalidValue {
206 field: "user_selection".to_string(),
207 value: selection.to_string(),
208 }.into());
209 }
210
211 let selected_provider = providers[selection - 1].clone();
212 let selected_config = configs.get(&selected_provider).unwrap().clone();
213
214 self.cache.set_selected_provider(selected_provider.to_string());
216 if let Err(e) = self.cache.save(&self.cache_path).await {
217 warn!("Failed to save configuration cache: {}", e);
218 }
219
220 info!("Selected provider: {}", selected_provider);
221 Ok((selected_provider, selected_config))
222 }
223
224 pub async fn load_config(&mut self) -> Result<(ApiProvider, ApiProviderConfig)> {
226 self.init().await?;
227 let configs = self.discover_configs().await?;
228 self.select_config(configs).await
229 }
230}
231
#[cfg(test)]
mod tests {
    // The env-var lock below is a std `MutexGuard` held across `.await`. This is fine
    // here: each `#[tokio::test]` runs on its own current-thread runtime, and the lock
    // only serializes environment access between tests.
    #![allow(clippy::await_holding_lock)]

    use super::*;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;

    /// Environment variables are process-global and the test harness runs tests
    /// on parallel threads, so every test that reads or writes them takes this lock.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so one failed test does not
    /// cascade into every other env-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Removes the listed environment variables on drop, so cleanup also happens
    /// when an assertion fails. Bind it *after* the `lock_env()` guard so it is
    /// dropped first, while the lock is still held.
    struct EnvCleanup(&'static [&'static str]);

    impl Drop for EnvCleanup {
        fn drop(&mut self) {
            for key in self.0 {
                std::env::remove_var(key);
            }
        }
    }

    const OPENAI_JSON: &str = r#"{
        "base_url": "https://api.openai.com/v1",
        "api_key": "json-key",
        "model": "gpt-4"
    }"#;

    #[tokio::test]
    async fn test_load_json_config() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("openai.json");

        let config_content = r#"{
            "base_url": "https://api.openai.com/v1",
            "api_key": "test-key",
            "model": "gpt-4"
        }"#;

        fs::write(&config_path, config_content).await.unwrap();

        let loader = ConfigLoader::new(temp_dir.path());
        let config = loader.load_json_config(&config_path).await.unwrap();

        assert_eq!(config.base_url, Some("https://api.openai.com/v1".to_string()));
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert_eq!(config.model, Some("gpt-4".to_string()));
    }

    #[tokio::test]
    async fn test_load_json_config_rejects_invalid_json() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("openai.json");
        fs::write(&config_path, "{ not json").await.unwrap();

        let loader = ConfigLoader::new(temp_dir.path());
        assert!(loader.load_json_config(&config_path).await.is_err());
    }

    #[test]
    fn test_load_env_config() {
        let _lock = lock_env();
        let _cleanup = EnvCleanup(&["OPENAI_API_KEY", "OPENAI_MODEL"]);
        std::env::set_var("OPENAI_API_KEY", "test-env-key");
        std::env::set_var("OPENAI_MODEL", "gpt-3.5-turbo");

        let loader = ConfigLoader::new(".");
        let config = loader.load_env_config(&ApiProvider::OpenAI).unwrap();

        assert_eq!(config.api_key, Some("test-env-key".to_string()));
        assert_eq!(config.model, Some("gpt-3.5-turbo".to_string()));
    }

    #[tokio::test]
    async fn test_discover_configs_prefers_json_over_env() {
        let _lock = lock_env();
        let _cleanup = EnvCleanup(&["OPENAI_API_KEY", "OPENAI_MODEL"]);
        std::env::set_var("OPENAI_API_KEY", "env-key");
        std::env::set_var("OPENAI_MODEL", "env-model");

        let temp_dir = tempdir().unwrap();
        fs::write(temp_dir.path().join("openai.json"), OPENAI_JSON)
            .await
            .unwrap();

        let loader = ConfigLoader::new(temp_dir.path());
        let configs = loader.discover_configs().await.unwrap();

        let openai = configs
            .get(&ApiProvider::OpenAI)
            .expect("openai config missing");
        assert_eq!(openai.api_key.as_deref(), Some("json-key"));
        assert_eq!(openai.model.as_deref(), Some("gpt-4"));
    }

    #[tokio::test]
    async fn test_select_config_single_and_cache() {
        let temp_dir = tempdir().unwrap();
        let mut loader = ConfigLoader::new(temp_dir.path());

        let mut single = HashMap::new();
        single.insert(
            ApiProvider::OpenAI,
            ApiProviderConfig {
                base_url: Some("https://api.openai.com/v1".into()),
                api_key: Some("k".into()),
                model: Some("gpt-4".into()),
                extra: HashMap::new(),
            },
        );
        let (prov, _cfg) = loader.select_config(single).await.unwrap();
        assert!(matches!(prov, ApiProvider::OpenAI));

        let mut many = HashMap::new();
        many.insert(
            ApiProvider::OpenAI,
            ApiProviderConfig {
                base_url: None,
                api_key: Some("k1".into()),
                model: Some("m1".into()),
                extra: HashMap::new(),
            },
        );
        many.insert(
            ApiProvider::Anthropic,
            ApiProviderConfig {
                base_url: None,
                api_key: Some("k2".into()),
                model: Some("m2".into()),
                extra: HashMap::new(),
            },
        );

        // With a cached choice, select_config must not fall through to the stdin prompt.
        loader.cache.set_selected_provider("openai".into());
        let (prov2, _cfg2) = loader.select_config(many).await.unwrap();
        assert!(matches!(prov2, ApiProvider::OpenAI));
    }

    #[tokio::test]
    async fn test_load_config_uses_cache_file() {
        // load_config reads env vars during discovery; keep other tests from mutating them.
        let _lock = lock_env();

        let temp_dir = tempdir().unwrap();
        fs::write(temp_dir.path().join("openai.json"), OPENAI_JSON)
            .await
            .unwrap();

        let cache_path = temp_dir.path().join("cache.json");
        let mut cache = ConfigCache::new();
        cache.set_selected_provider("openai".into());
        cache.save(&cache_path).await.unwrap();

        let mut loader = ConfigLoader::new(temp_dir.path());
        loader.cache_path = cache_path;
        let (prov, cfg) = loader.load_config().await.unwrap();

        assert!(matches!(prov, ApiProvider::OpenAI));
        assert_eq!(cfg.api_key.as_deref(), Some("json-key"));
    }

    #[tokio::test]
    async fn test_load_config_no_config_returns_error() {
        let _lock = lock_env();
        for key in [
            "OPENAI_API_KEY",
            "OPENAI_BASE_URL",
            "OPENAI_MODEL",
            "ANTHROPIC_API_KEY",
            "ANTHROPIC_BASE_URL",
            "ANTHROPIC_MODEL",
            "GOOGLE_API_KEY",
            "GOOGLE_BASE_URL",
            "GOOGLE_MODEL",
        ] {
            std::env::remove_var(key);
        }

        let temp_dir = tempdir().unwrap();
        let mut loader = ConfigLoader::new(temp_dir.path());
        loader.cache_path = temp_dir.path().join("cache.json");

        let err = loader.load_config().await.err().expect("expected error");
        let msg = format!("{}", err);
        assert!(msg.contains("No configuration"), "unexpected error: {}", msg);
    }
}