// trae_agent_rs_core/config/loader.rs

//! Configuration loader for API providers

use std::collections::HashMap;
use std::path::{Path, PathBuf};

use serde_json;
use tokio::fs;
use tracing::{debug, info, warn};

use super::{ApiProvider, ApiProviderConfig, ConfigCache};
use crate::error::{ConfigError, Result};

11/// Configuration loader that handles multiple API providers
12pub struct ConfigLoader {
13    /// Base directory to search for config files
14    config_dir: PathBuf,
15
16    /// Cache for user selections
17    cache: ConfigCache,
18
19    /// Cache file path
20    cache_path: PathBuf,
21}
22
23impl ConfigLoader {
24    /// Create a new configuration loader
25    pub fn new<P: AsRef<Path>>(config_dir: P) -> Self {
26        let config_dir = config_dir.as_ref().to_path_buf();
27        let cache_path = ConfigCache::default_cache_path();
28
29        Self {
30            config_dir,
31            cache: ConfigCache::new(),
32            cache_path,
33        }
34    }
35
36    /// Initialize the loader by loading the cache
37    pub async fn init(&mut self) -> Result<()> {
38        self.cache = ConfigCache::load(&self.cache_path).await?;
39        debug!("Loaded configuration cache from: {}", self.cache_path.display());
40        Ok(())
41    }
42
43    /// Discover available API provider configurations
44    pub async fn discover_configs(&self) -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
45        let mut configs = HashMap::new();
46
47        // Check for JSON config files
48        let json_configs = self.load_json_configs().await?;
49        configs.extend(json_configs);
50
51        // Check for environment variables for providers that don't have JSON configs
52        let env_configs = self.load_env_configs(&configs).await?;
53        configs.extend(env_configs);
54
55        debug!("Discovered {} provider configurations", configs.len());
56        Ok(configs)
57    }
58
59    /// Load configurations from JSON files
60    async fn load_json_configs(&self) -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
61        let mut configs = HashMap::new();
62
63        let providers = [
64            ApiProvider::OpenAI,
65            ApiProvider::Anthropic,
66            ApiProvider::Google,
67        ];
68
69        for provider in providers {
70            let config_path = self.config_dir.join(provider.config_filename());
71
72            if config_path.exists() {
73                match self.load_json_config(&config_path).await {
74                    Ok(config) => {
75                        info!("Loaded {} configuration from: {}", provider, config_path.display());
76                        configs.insert(provider, config);
77                    }
78                    Err(e) => {
79                        warn!("Failed to load {} configuration from {}: {}",
80                              provider, config_path.display(), e);
81                    }
82                }
83            }
84        }
85
86        Ok(configs)
87    }
88
89    /// Load a single JSON configuration file
90    async fn load_json_config(&self, path: &Path) -> Result<ApiProviderConfig> {
91        let content = fs::read_to_string(path).await?;
92        let config: ApiProviderConfig = serde_json::from_str(&content)
93            .map_err(|e| ConfigError::InvalidFormat)?;
94        Ok(config)
95    }
96
97    /// Load configurations from environment variables
98    async fn load_env_configs(&self, existing_configs: &HashMap<ApiProvider, ApiProviderConfig>)
99        -> Result<HashMap<ApiProvider, ApiProviderConfig>> {
100        let mut configs = HashMap::new();
101
102        let providers = [
103            ApiProvider::OpenAI,
104            ApiProvider::Anthropic,
105            ApiProvider::Google,
106        ];
107
108        for provider in providers {
109            // Skip if we already have a JSON config for this provider
110            if existing_configs.contains_key(&provider) {
111                continue;
112            }
113
114            if let Some(config) = self.load_env_config(&provider) {
115                info!("Loaded {} configuration from environment variables", provider);
116                configs.insert(provider, config);
117            }
118        }
119
120        Ok(configs)
121    }
122
123    /// Load configuration for a specific provider from environment variables
124    fn load_env_config(&self, provider: &ApiProvider) -> Option<ApiProviderConfig> {
125        let prefix = provider.env_prefix();
126
127        let base_url = std::env::var(format!("{}_BASE_URL", prefix)).ok();
128        let api_key = std::env::var(format!("{}_API_KEY", prefix)).ok();
129        let model = std::env::var(format!("{}_MODEL", prefix)).ok();
130
131        // Only create config if at least one environment variable is set
132        if base_url.is_some() || api_key.is_some() || model.is_some() {
133            Some(ApiProviderConfig {
134                base_url,
135                api_key,
136                model,
137                extra: HashMap::new(),
138            })
139        } else {
140            None
141        }
142    }
143
144    /// Select a configuration, handling multiple options
145    pub async fn select_config(&mut self, configs: HashMap<ApiProvider, ApiProviderConfig>)
146        -> Result<(ApiProvider, ApiProviderConfig)> {
147
148        if configs.is_empty() {
149            return Err(ConfigError::NoConfigFound.into());
150        }
151
152        // If only one config, use it
153        if configs.len() == 1 {
154            let (provider, config) = configs.into_iter().next().unwrap();
155            info!("Using single available configuration: {}", provider);
156            return Ok((provider, config));
157        }
158
159        // Check cache for previous selection
160        if let Some(cached_provider) = self.cache.get_selected_provider() {
161            if !self.cache.is_expired() {
162                if let Ok(provider) = cached_provider.parse::<ApiProvider>() {
163                    if let Some(config) = configs.get(&provider) {
164                        info!("Using cached provider selection: {}", provider);
165                        return Ok((provider, config.clone()));
166                    }
167                }
168            }
169        }
170
171        // Multiple configs available, need user selection
172        self.prompt_user_selection(configs).await
173    }
174
175    /// Prompt user to select from multiple configurations
176    async fn prompt_user_selection(&mut self, configs: HashMap<ApiProvider, ApiProviderConfig>)
177        -> Result<(ApiProvider, ApiProviderConfig)> {
178
179        println!("Multiple API provider configurations found:");
180
181        let providers: Vec<_> = configs.keys().collect();
182        for (i, provider) in providers.iter().enumerate() {
183            let config = configs.get(provider).unwrap();
184            let source = if config.api_key.is_some() { "configured" } else { "env vars" };
185            println!("  {}. {} ({})", i + 1, provider, source);
186        }
187
188        println!("Please select a provider (1-{}): ", providers.len());
189
190        // Read user input
191        let mut input = String::new();
192        std::io::stdin().read_line(&mut input)
193            .map_err(|_| ConfigError::InvalidValue {
194                field: "user_selection".to_string(),
195                value: "failed to read input".to_string(),
196            })?;
197
198        let selection: usize = input.trim().parse()
199            .map_err(|_| ConfigError::InvalidValue {
200                field: "user_selection".to_string(),
201                value: input.trim().to_string(),
202            })?;
203
204        if selection == 0 || selection > providers.len() {
205            return Err(ConfigError::InvalidValue {
206                field: "user_selection".to_string(),
207                value: selection.to_string(),
208            }.into());
209        }
210
211        let selected_provider = providers[selection - 1].clone();
212        let selected_config = configs.get(&selected_provider).unwrap().clone();
213
214        // Cache the selection
215        self.cache.set_selected_provider(selected_provider.to_string());
216        if let Err(e) = self.cache.save(&self.cache_path).await {
217            warn!("Failed to save configuration cache: {}", e);
218        }
219
220        info!("Selected provider: {}", selected_provider);
221        Ok((selected_provider, selected_config))
222    }
223
224    /// Load configuration with automatic discovery and selection
225    pub async fn load_config(&mut self) -> Result<(ApiProvider, ApiProviderConfig)> {
226        self.init().await?;
227        let configs = self.discover_configs().await?;
228        self.select_config(configs).await
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;

    /// Provider environment variables cleared by tests that need a clean
    /// environment (same list the original tests used).
    const PROVIDER_ENV_VARS: [&str; 9] = [
        "OPENAI_API_KEY",
        "OPENAI_BASE_URL",
        "OPENAI_MODEL",
        "ANTHROPIC_API_KEY",
        "ANTHROPIC_BASE_URL",
        "ANTHROPIC_MODEL",
        "GOOGLE_API_KEY",
        "GOOGLE_BASE_URL",
        "GOOGLE_MODEL",
    ];

    /// Serializes tests that touch process-wide environment variables; the test
    /// harness runs tests on parallel threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` for the test's duration and, on drop (including while
    /// unwinding from a failed assertion), restores the recorded variables.
    struct EnvGuard {
        saved: Vec<(String, Option<OsString>)>,
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        fn new(keys: &[&str]) -> Self {
            // A test that panicked while holding the lock poisons it; the `()`
            // payload has no state to corrupt, so take the guard anyway.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = keys
                .iter()
                .map(|&key| (key.to_string(), std::env::var_os(key)))
                .collect();
            Self { saved, _lock: lock }
        }

        /// Remove every recorded variable (they are restored on drop).
        fn clear(&self) {
            for (key, _) in &self.saved {
                std::env::remove_var(key);
            }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // Runs before `_lock` is dropped, so restoration happens under the lock.
            for (key, previous) in &self.saved {
                match previous {
                    Some(value) => std::env::set_var(key, value),
                    None => std::env::remove_var(key),
                }
            }
        }
    }

    fn provider_config(api_key: &str, model: &str) -> ApiProviderConfig {
        ApiProviderConfig {
            base_url: None,
            api_key: Some(api_key.into()),
            model: Some(model.into()),
            extra: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_load_json_config() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("openai.json");

        let config_content = r#"{
            "base_url": "https://api.openai.com/v1",
            "api_key": "test-key",
            "model": "gpt-4"
        }"#;

        fs::write(&config_path, config_content).await.unwrap();

        let loader = ConfigLoader::new(temp_dir.path());
        let config = loader.load_json_config(&config_path).await.unwrap();

        assert_eq!(config.base_url, Some("https://api.openai.com/v1".to_string()));
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert_eq!(config.model, Some("gpt-4".to_string()));
    }

    #[test]
    fn test_load_env_config() {
        let _env = EnvGuard::new(&["OPENAI_API_KEY", "OPENAI_MODEL"]);
        std::env::set_var("OPENAI_API_KEY", "test-env-key");
        std::env::set_var("OPENAI_MODEL", "gpt-3.5-turbo");

        let loader = ConfigLoader::new(".");
        let config = loader
            .load_env_config(&ApiProvider::OpenAI)
            .expect("env vars are set, so a config must be produced");

        assert_eq!(config.api_key, Some("test-env-key".to_string()));
        assert_eq!(config.model, Some("gpt-3.5-turbo".to_string()));
    }

    #[tokio::test]
    async fn test_discover_configs_prefers_json_over_env() {
        // Env vars are set, but the JSON file must win.
        let _env = EnvGuard::new(&["OPENAI_API_KEY", "OPENAI_MODEL"]);
        std::env::set_var("OPENAI_API_KEY", "env-key");
        std::env::set_var("OPENAI_MODEL", "env-model");

        let temp_dir = tempdir().unwrap();
        let openai_json = temp_dir.path().join("openai.json");
        let content = r#"{
            "base_url": "https://api.openai.com/v1",
            "api_key": "json-key",
            "model": "gpt-4"
        }"#;
        fs::write(&openai_json, content).await.unwrap();

        let loader = ConfigLoader::new(temp_dir.path());
        let configs = loader.discover_configs().await.unwrap();

        let openai = configs
            .get(&ApiProvider::OpenAI)
            .expect("openai config missing");
        assert_eq!(openai.api_key.as_deref(), Some("json-key"));
        assert_eq!(openai.model.as_deref(), Some("gpt-4"));
    }

    #[tokio::test]
    async fn test_select_config_single_and_cache() {
        let temp_dir = tempdir().unwrap();
        let mut loader = ConfigLoader::new(temp_dir.path());
        // Keep the test away from the user's real cache file.
        loader.cache_path = temp_dir.path().join("cache.json");

        // A single config is selected directly.
        let mut single = HashMap::new();
        single.insert(
            ApiProvider::OpenAI,
            ApiProviderConfig {
                base_url: Some("https://api.openai.com/v1".into()),
                api_key: Some("k".into()),
                model: Some("gpt-4".into()),
                extra: HashMap::new(),
            },
        );
        let (prov, _cfg) = loader.select_config(single).await.unwrap();
        assert!(matches!(prov, ApiProvider::OpenAI));

        // Several configs: the cached choice is used.
        let mut many = HashMap::new();
        many.insert(ApiProvider::OpenAI, provider_config("k1", "m1"));
        many.insert(ApiProvider::Anthropic, provider_config("k2", "m2"));

        // Pre-seeded cache avoids the interactive prompt.
        loader.cache.set_selected_provider("openai".into());
        let (prov2, _cfg2) = loader.select_config(many).await.unwrap();
        assert!(matches!(prov2, ApiProvider::OpenAI));
    }

    #[tokio::test]
    async fn test_load_config_uses_cache_file() {
        // Clear provider env vars so only the JSON files below are discovered.
        let env = EnvGuard::new(&PROVIDER_ENV_VARS);
        env.clear();

        let temp_dir = tempdir().unwrap();

        // Two providers, so `load_config` must consult the cache instead of
        // taking the single-config shortcut.
        let openai_json = temp_dir
            .path()
            .join(ApiProvider::OpenAI.config_filename());
        let openai_content = r#"{
            "base_url": "https://api.openai.com/v1",
            "api_key": "json-key",
            "model": "gpt-4"
        }"#;
        fs::write(&openai_json, openai_content).await.unwrap();

        let anthropic_json = temp_dir
            .path()
            .join(ApiProvider::Anthropic.config_filename());
        let anthropic_content = r#"{
            "base_url": "https://api.anthropic.com",
            "api_key": "anthropic-key",
            "model": "claude"
        }"#;
        fs::write(&anthropic_json, anthropic_content).await.unwrap();

        // Cache file with a fresh timestamp selecting OpenAI.
        let cache_path = temp_dir.path().join("cache.json");
        let mut cache = ConfigCache::new();
        cache.set_selected_provider("openai".into());
        cache.save(&cache_path).await.unwrap();

        let mut loader = ConfigLoader::new(temp_dir.path());
        loader.cache_path = cache_path; // private field, reachable from this child module

        let (prov, cfg) = loader.load_config().await.unwrap();
        assert!(matches!(prov, ApiProvider::OpenAI));
        assert_eq!(cfg.api_key.as_deref(), Some("json-key"));
    }

    #[tokio::test]
    async fn test_load_config_no_config_returns_error() {
        // Ensure the environment does not provide configs; restored on drop.
        let env = EnvGuard::new(&PROVIDER_ENV_VARS);
        env.clear();

        let temp_dir = tempdir().unwrap();
        let mut loader = ConfigLoader::new(temp_dir.path());
        // Avoid using any real cache path.
        loader.cache_path = temp_dir.path().join("cache.json");

        let err = loader.load_config().await.err().expect("expected error");
        let msg = err.to_string();
        assert!(msg.contains("No configuration"), "unexpected error: {}", msg);
    }
}