Skip to main content

sh_layer1/
config_manager.rs

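//! 配置管理模块 (configuration management).
//!
//! 多环境配置、热更新、验证、多提供商管理。
//! Multi-environment configuration, hot reloading, validation and multi-provider management.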
4
use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
9
10/// 提供商配置
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderConfig {
13    /// API 密钥
14    pub api_key: String,
15    /// API 基础 URL
16    pub base_url: String,
17    /// 默认模型
18    pub model: String,
19    /// API 格式: "anthropic" | "openai" | "google"
20    #[serde(default = "default_api_format")]
21    pub api_format: String,
22    /// 默认最大 token 数
23    #[serde(default = "default_max_tokens")]
24    pub default_max_tokens: u32,
25    /// 默认温度
26    #[serde(default = "default_temperature")]
27    pub default_temperature: f32,
28}
29
/// Serde default for `ProviderConfig::api_format`: the `"openai"` format.
fn default_api_format() -> String {
    String::from("openai")
}
33
/// Serde default for `ProviderConfig::default_max_tokens`.
fn default_max_tokens() -> u32 {
    4_096
}
37
/// Serde default for `ProviderConfig::default_temperature`.
fn default_temperature() -> f32 {
    0.7_f32
}
41
42impl Default for ProviderConfig {
43    fn default() -> Self {
44        Self {
45            api_key: String::new(),
46            base_url: String::new(),
47            model: "claude-sonnet-4-6".to_string(),
48            api_format: default_api_format(),
49            default_max_tokens: 4096,
50            default_temperature: 0.7,
51        }
52    }
53}
54
55/// 全局设置
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct GlobalSettings {
58    /// 会话自动保存
59    #[serde(default = "default_true")]
60    pub session_auto_save: bool,
61    /// 会话最大历史数
62    #[serde(default = "default_session_max_history")]
63    pub session_max_history: usize,
64    /// 检查点启用
65    #[serde(default = "default_true")]
66    pub checkpoint_enabled: bool,
67    /// 检查点间隔(秒)
68    #[serde(default = "default_checkpoint_interval")]
69    pub checkpoint_interval_sec: u32,
70    /// 审计日志启用
71    #[serde(default = "default_true")]
72    pub audit_enabled: bool,
73    /// MCP 启用
74    #[serde(default)]
75    pub mcp_enabled: bool,
76}
77
78impl Default for GlobalSettings {
79    fn default() -> Self {
80        Self {
81            session_auto_save: true,
82            session_max_history: 100,
83            checkpoint_enabled: true,
84            checkpoint_interval_sec: 60,
85            audit_enabled: true,
86            mcp_enabled: false,
87        }
88    }
89}
90
/// Serde default for the `GlobalSettings` flags that are enabled unless configured
/// otherwise (`session_auto_save`, `checkpoint_enabled`, `audit_enabled`).
fn default_true() -> bool {
    true
}
94
/// Serde default for `GlobalSettings::session_max_history`.
fn default_session_max_history() -> usize {
    100_usize
}
98
/// Serde default for `GlobalSettings::checkpoint_interval_sec`, in seconds.
fn default_checkpoint_interval() -> u32 {
    60_u32
}
102
103/// 配置管理器
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct ConfigManager {
106    /// 当前激活的提供商
107    #[serde(default = "default_provider")]
108    pub active_provider: String,
109    /// 提供商注册表
110    #[serde(default)]
111    pub providers: HashMap<String, ProviderConfig>,
112    /// 全局设置
113    #[serde(default)]
114    pub settings: GlobalSettings,
115    /// 其他配置项
116    #[serde(default)]
117    pub extra: HashMap<String, String>,
118}
119
/// Serde default for `ConfigManager::active_provider`.
fn default_provider() -> String {
    String::from("anthropic")
}
123
124impl Default for ConfigManager {
125    fn default() -> Self {
126        Self {
127            active_provider: "anthropic".to_string(),
128            providers: HashMap::new(),
129            settings: GlobalSettings::default(),
130            extra: HashMap::new(),
131        }
132    }
133}
134
135impl ConfigManager {
136    /// 创建新的配置管理器
137    pub fn new() -> Self {
138        Self::default()
139    }
140
141    /// 从环境变量加载配置
142    pub fn from_env() -> Self {
143        let mut config = Self::default();
144
145        // 读取环境变量
146        if let Ok(provider) = std::env::var("CONTINUUM_PROVIDER") {
147            config.active_provider = provider;
148        }
149
150        // 读取 API 密钥并添加到当前提供商
151        if let Ok(api_key) = std::env::var("CONTINUUM_API_KEY") {
152            let provider_name = config.active_provider.clone();
153            let provider_config = config.providers.entry(provider_name).or_default();
154            provider_config.api_key = api_key;
155        }
156
157        // 读取基础 URL
158        if let Ok(base_url) = std::env::var("CONTINUUM_BASE_URL") {
159            let provider_name = config.active_provider.clone();
160            let provider_config = config.providers.entry(provider_name).or_default();
161            provider_config.base_url = base_url;
162        }
163
164        // 读取模型
165        if let Ok(model) = std::env::var("CONTINUUM_MODEL") {
166            let provider_name = config.active_provider.clone();
167            let provider_config = config.providers.entry(provider_name).or_default();
168            provider_config.model = model;
169        }
170
171        // 读取检查点配置
172        if let Ok(val) = std::env::var("CONTINUUM_CHECKPOINT_ENABLED") {
173            if let Ok(enabled) = val.parse::<bool>() {
174                config.settings.checkpoint_enabled = enabled;
175            }
176        }
177
178        if let Ok(val) = std::env::var("CONTINUUM_AUDIT_ENABLED") {
179            if let Ok(enabled) = val.parse::<bool>() {
180                config.settings.audit_enabled = enabled;
181            }
182        }
183
184        config
185    }
186
187    /// 从文件加载配置
188    pub async fn load_from_file(&mut self, path: &Path) -> Result<()> {
189        if !path.exists() {
190            return Ok(());
191        }
192
193        let content = tokio::fs::read_to_string(path).await?;
194        let loaded: ConfigManager = toml::from_str(&content)?;
195
196        // 合并配置
197        self.merge(loaded);
198        Ok(())
199    }
200
201    /// 从文件同步加载(用于非 async 环境)
202    pub fn load_from_file_sync(&mut self, path: &Path) -> Result<()> {
203        if !path.exists() {
204            return Ok(());
205        }
206
207        let content = std::fs::read_to_string(path)?;
208        let loaded: ConfigManager = toml::from_str(&content)?;
209        self.merge(loaded);
210        Ok(())
211    }
212
213    /// 合并配置 (优先级: env > file > default,所以 file 合入 self)
214    pub fn merge(&mut self, other: ConfigManager) {
215        // 合并提供商配置
216        for (name, provider) in other.providers {
217            // 只合并有 API 密钥的提供商
218            if !provider.api_key.is_empty() {
219                self.providers.insert(name, provider);
220            }
221        }
222
223        // 合并设置(保留已从环境变量读取的值)
224        if other.settings.session_max_history > 0 {
225            self.settings.session_max_history = other.settings.session_max_history;
226        }
227        if other.settings.checkpoint_interval_sec > 0 {
228            self.settings.checkpoint_interval_sec = other.settings.checkpoint_interval_sec;
229        }
230
231        // 合并其他配置
232        self.extra.extend(other.extra);
233
234        // 设置活跃提供商(如果指定的提供商存在)
235        if !other.active_provider.is_empty() && self.providers.contains_key(&other.active_provider)
236        {
237            self.active_provider = other.active_provider;
238        }
239    }
240
241    /// 获取默认配置路径
242    pub fn default_config_path() -> PathBuf {
243        // 用户级配置: ~/.continuum/config.toml
244        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
245        home.join(".continuum").join("config.toml")
246    }
247
248    /// 获取项目级配置路径
249    pub fn project_config_path() -> PathBuf {
250        PathBuf::from(".continuum").join("config.toml")
251    }
252
253    /// 加载完整配置(环境变量 + 项目级 + 用户级)
254    pub async fn load_full() -> Result<Self> {
255        // 1. 从默认配置开始
256        let mut config = Self::new();
257
258        // 2. 加载用户级配置
259        let user_path = Self::default_config_path();
260        config.load_from_file(&user_path).await?;
261
262        // 3. 加载项目级配置(覆盖用户级)
263        let project_path = Self::project_config_path();
264        config.load_from_file(&project_path).await?;
265
266        // 4. 从环境变量加载(最高优先级,覆盖文件配置)
267        let env_config = Self::from_env();
268        config.merge_env(env_config);
269
270        Ok(config)
271    }
272
273    /// 合并环境变量配置(最高优先级)
274    fn merge_env(&mut self, env: ConfigManager) {
275        // 环境变量配置优先级最高,直接覆盖
276        if !env.active_provider.is_empty() {
277            self.active_provider = env.active_provider;
278        }
279
280        // 合并提供商(环境变量的提供商直接覆盖)
281        for (name, provider) in env.providers {
282            self.providers.insert(name, provider);
283        }
284
285        // 合并设置
286        self.settings.audit_enabled = env.settings.audit_enabled;
287        self.settings.checkpoint_enabled = env.settings.checkpoint_enabled;
288    }
289
290    /// 切换提供商
291    pub fn use_provider(&mut self, name: &str) -> Result<()> {
292        if !self.providers.contains_key(name) {
293            return Err(anyhow!(
294                "Provider '{}' not found. Use 'config add-provider' first.",
295                name
296            ));
297        }
298        self.active_provider = name.to_string();
299        Ok(())
300    }
301
302    /// 获取当前提供商配置
303    pub fn current(&self) -> Result<&ProviderConfig> {
304        self.providers
305            .get(&self.active_provider)
306            .ok_or_else(|| anyhow!("No provider '{}' configured", self.active_provider))
307    }
308
309    /// 添加提供商
310    pub fn add_provider(&mut self, name: &str, config: ProviderConfig) {
311        self.providers.insert(name.to_string(), config);
312    }
313
314    /// 列出所有提供商
315    pub fn list_providers(&self) -> Vec<&String> {
316        self.providers.keys().collect()
317    }
318
319    /// 获取配置值
320    pub fn get(&self, key: &str) -> Option<&String> {
321        self.extra.get(key)
322    }
323
324    /// 设置配置值
325    pub fn set(&mut self, key: String, value: String) {
326        self.extra.insert(key, value);
327    }
328
329    /// 保存配置到文件
330    pub async fn save(&self, path: &Path) -> Result<()> {
331        // 确保父目录存在
332        if let Some(parent) = path.parent() {
333            tokio::fs::create_dir_all(parent).await?;
334        }
335
336        let content = toml::to_string_pretty(&self)?;
337        tokio::fs::write(path, content).await?;
338        Ok(())
339    }
340
341    /// 同步保存配置到文件
342    pub fn save_sync(&self, path: &Path) -> Result<()> {
343        if let Some(parent) = path.parent() {
344            std::fs::create_dir_all(parent)?;
345        }
346
347        let content = toml::to_string_pretty(&self)?;
348        std::fs::write(path, content)?;
349        Ok(())
350    }
351
352    /// 解析环境变量引用 ${VAR_NAME}
353    pub fn resolve_env_refs(&mut self) {
354        // 解析提供商配置中的环境变量引用
355        for provider in self.providers.values_mut() {
356            provider.api_key = Self::resolve_env_string(&provider.api_key);
357            provider.base_url = Self::resolve_env_string(&provider.base_url);
358            provider.model = Self::resolve_env_string(&provider.model);
359        }
360
361        // 解析其他配置中的环境变量引用
362        for value in self.extra.values_mut() {
363            *value = Self::resolve_env_string(value);
364        }
365    }
366
367    /// 解析单个字符串中的环境变量引用
368    fn resolve_env_string(s: &str) -> String {
369        let mut result = s.to_string();
370        // 查找 ${VAR_NAME} 并替换
371        while let Some(start) = result.find("${") {
372            if let Some(end) = result[start..].find('}') {
373                let var_name = &result[start + 2..start + end];
374                if let Ok(val) = std::env::var(var_name) {
375                    result.replace_range(start..start + end + 1, &val);
376                } else {
377                    // 环境变量不存在,移除引用标记
378                    result.replace_range(start..start + end + 1, "");
379                }
380            } else {
381                break;
382            }
383        }
384        result
385    }
386
387    /// 初始化默认配置文件
388    pub fn init_default_config(&self) -> Result<PathBuf> {
389        let path = Self::default_config_path();
390
391        if path.exists() {
392            return Err(anyhow!("Config file already exists at {:?}", path));
393        }
394
395        // 创建默认配置
396        let default_config = Self {
397            active_provider: "anthropic".to_string(),
398            providers: {
399                let mut map = HashMap::new();
400                map.insert(
401                    "anthropic".to_string(),
402                    ProviderConfig {
403                        api_key: "${ANTHROPIC_API_KEY}".to_string(),
404                        base_url: "https://api.anthropic.com/v1".to_string(),
405                        model: "claude-sonnet-4-6".to_string(),
406                        api_format: "anthropic".to_string(),
407                        default_max_tokens: 4096,
408                        default_temperature: 0.7,
409                    },
410                );
411                map.insert(
412                    "openai".to_string(),
413                    ProviderConfig {
414                        api_key: "${OPENAI_API_KEY}".to_string(),
415                        base_url: "https://api.openai.com/v1".to_string(),
416                        model: "gpt-4".to_string(),
417                        api_format: "openai".to_string(),
418                        default_max_tokens: 4096,
419                        default_temperature: 0.7,
420                    },
421                );
422                map.insert(
423                    "gemini".to_string(),
424                    ProviderConfig {
425                        api_key: "${GEMINI_API_KEY}".to_string(),
426                        base_url: "https://generativelanguage.googleapis.com/v1".to_string(),
427                        model: "gemini-pro".to_string(),
428                        api_format: "google".to_string(),
429                        default_max_tokens: 4096,
430                        default_temperature: 0.7,
431                    },
432                );
433                map
434            },
435            settings: GlobalSettings::default(),
436            extra: HashMap::new(),
437        };
438
439        default_config.save_sync(&path)?;
440        Ok(path)
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with the given API key and fixed test values for everything else.
    fn provider_with_key(api_key: &str) -> ProviderConfig {
        ProviderConfig {
            api_key: api_key.to_string(),
            base_url: "https://test.api.com".to_string(),
            model: "test-model".to_string(),
            api_format: "openai".to_string(),
            default_max_tokens: 4096,
            default_temperature: 0.7,
        }
    }

    #[test]
    fn test_config_manager_creation() {
        let config = ConfigManager::new();
        assert_eq!(config.active_provider, "anthropic");
    }

    #[test]
    fn test_provider_config_default() {
        let provider = ProviderConfig::default();
        assert_eq!(provider.default_max_tokens, 4096);
        assert_eq!(provider.default_temperature, 0.7);
    }

    #[test]
    fn test_global_settings_default() {
        let settings = GlobalSettings::default();
        assert!(settings.session_auto_save);
        assert!(settings.checkpoint_enabled);
    }

    #[test]
    fn test_add_provider() {
        let mut config = ConfigManager::new();
        let provider = ProviderConfig {
            api_key: "test_key".to_string(),
            base_url: "https://test.api.com".to_string(),
            model: "test-model".to_string(),
            api_format: "openai".to_string(),
            default_max_tokens: 8192,
            default_temperature: 0.5,
        };
        config.add_provider("test", provider);
        assert!(config.providers.contains_key("test"));
    }

    #[test]
    fn test_use_provider() {
        let mut config = ConfigManager::new();
        config.add_provider("test", provider_with_key("test_key"));

        config.use_provider("test").unwrap();
        assert_eq!(config.active_provider, "test");
    }

    #[test]
    fn test_use_provider_not_found() {
        let mut config = ConfigManager::new();
        let result = config.use_provider("nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_resolve_env_string() {
        // A name unique to this test, so tests running in parallel (or an outer
        // environment that happens to define a generic name) cannot interfere.
        std::env::set_var("CONTINUUM_TEST_RESOLVE_ENV_STRING", "test_value");
        let whole = ConfigManager::resolve_env_string("${CONTINUUM_TEST_RESOLVE_ENV_STRING}");
        let embedded =
            ConfigManager::resolve_env_string("a-${CONTINUUM_TEST_RESOLVE_ENV_STRING}-b");
        std::env::remove_var("CONTINUUM_TEST_RESOLVE_ENV_STRING");
        assert_eq!(whole, "test_value");
        assert_eq!(embedded, "a-test_value-b");
    }

    #[test]
    fn test_resolve_env_string_unset_and_unterminated() {
        assert_eq!(
            ConfigManager::resolve_env_string("x${CONTINUUM_TEST_SURELY_UNSET_VAR}y"),
            "xy"
        );
        assert_eq!(ConfigManager::resolve_env_string("keep ${OPEN"), "keep ${OPEN");
        assert_eq!(ConfigManager::resolve_env_string("plain"), "plain");
    }

    #[test]
    fn test_merge_skips_providers_without_key() {
        let mut config = ConfigManager::new();
        let mut other = ConfigManager::new();
        other.add_provider("keyed", provider_with_key("k"));
        other.add_provider("empty", provider_with_key(""));

        config.merge(other);
        assert!(config.providers.contains_key("keyed"));
        assert!(!config.providers.contains_key("empty"));
    }

    #[test]
    fn test_merge_ignores_unknown_active_provider() {
        let mut config = ConfigManager::new();
        let mut other = ConfigManager::new();
        other.active_provider = "missing".to_string();

        config.merge(other);
        assert_eq!(config.active_provider, "anthropic");
    }

    #[test]
    fn test_set_get_config() {
        let mut config = ConfigManager::new();
        config.set("test_key".to_string(), "test_value".to_string());
        assert_eq!(config.get("test_key"), Some(&"test_value".to_string()));
    }

    #[test]
    fn test_list_providers() {
        let mut config = ConfigManager::new();
        let provider = ProviderConfig {
            api_key: "key1".to_string(),
            base_url: "url1".to_string(),
            model: "model1".to_string(),
            api_format: "openai".to_string(),
            default_max_tokens: 4096,
            default_temperature: 0.7,
        };
        config.add_provider("provider1", provider);

        let list = config.list_providers();
        assert!(list.contains(&&"provider1".to_string()));
    }

    #[test]
    fn test_config_serialization() {
        let config = ConfigManager::new();
        let toml_str = toml::to_string(&config).unwrap();
        assert!(toml_str.contains("active_provider"));

        // The serialized form must load back into an equivalent configuration.
        let back: ConfigManager = toml::from_str(&toml_str).unwrap();
        assert_eq!(back.active_provider, config.active_provider);
        assert_eq!(
            back.settings.session_max_history,
            config.settings.session_max_history
        );
    }
}