Skip to main content

sh_layer1/
config_manager.rs

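//! Configuration management module.
//!
//! Layered configuration (built-in defaults, user- and project-level TOML files, and
//! `CONTINUUM_*` environment variables), `${VAR}` reference resolution, and management of
//! multiple API providers.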
4
use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
9
10/// 提供商配置
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderConfig {
13    /// API 密钥
14    pub api_key: String,
15    /// API 基础 URL
16    pub base_url: String,
17    /// 默认模型
18    pub model: String,
19    /// 默认最大 token 数
20    #[serde(default = "default_max_tokens")]
21    pub default_max_tokens: u32,
22    /// 默认温度
23    #[serde(default = "default_temperature")]
24    pub default_temperature: f32,
25}
26
27fn default_max_tokens() -> u32 {
28    4096
29}
30
31fn default_temperature() -> f32 {
32    0.7
33}
34
35impl Default for ProviderConfig {
36    fn default() -> Self {
37        Self {
38            api_key: String::new(),
39            base_url: String::new(),
40            model: "claude-sonnet-4-6".to_string(),
41            default_max_tokens: 4096,
42            default_temperature: 0.7,
43        }
44    }
45}
46
47/// 全局设置
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct GlobalSettings {
50    /// 会话自动保存
51    #[serde(default = "default_true")]
52    pub session_auto_save: bool,
53    /// 会话最大历史数
54    #[serde(default = "default_session_max_history")]
55    pub session_max_history: usize,
56    /// 检查点启用
57    #[serde(default = "default_true")]
58    pub checkpoint_enabled: bool,
59    /// 检查点间隔(秒)
60    #[serde(default = "default_checkpoint_interval")]
61    pub checkpoint_interval_sec: u32,
62    /// 审计日志启用
63    #[serde(default = "default_true")]
64    pub audit_enabled: bool,
65    /// MCP 启用
66    #[serde(default)]
67    pub mcp_enabled: bool,
68}
69
70impl Default for GlobalSettings {
71    fn default() -> Self {
72        Self {
73            session_auto_save: true,
74            session_max_history: 100,
75            checkpoint_enabled: true,
76            checkpoint_interval_sec: 60,
77            audit_enabled: true,
78            mcp_enabled: false,
79        }
80    }
81}
82
83fn default_true() -> bool {
84    true
85}
86
87fn default_session_max_history() -> usize {
88    100
89}
90
91fn default_checkpoint_interval() -> u32 {
92    60
93}
94
95/// 配置管理器
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct ConfigManager {
98    /// 当前激活的提供商
99    #[serde(default = "default_provider")]
100    pub active_provider: String,
101    /// 提供商注册表
102    #[serde(default)]
103    pub providers: HashMap<String, ProviderConfig>,
104    /// 全局设置
105    #[serde(default)]
106    pub settings: GlobalSettings,
107    /// 其他配置项
108    #[serde(default)]
109    pub extra: HashMap<String, String>,
110}
111
112fn default_provider() -> String {
113    "anthropic".to_string()
114}
115
116impl Default for ConfigManager {
117    fn default() -> Self {
118        Self {
119            active_provider: "anthropic".to_string(),
120            providers: HashMap::new(),
121            settings: GlobalSettings::default(),
122            extra: HashMap::new(),
123        }
124    }
125}
126
127impl ConfigManager {
128    /// 创建新的配置管理器
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// 从环境变量加载配置
134    pub fn from_env() -> Self {
135        let mut config = Self::default();
136
137        // 读取环境变量
138        if let Ok(provider) = std::env::var("CONTINUUM_PROVIDER") {
139            config.active_provider = provider;
140        }
141
142        // 读取 API 密钥并添加到当前提供商
143        if let Ok(api_key) = std::env::var("CONTINUUM_API_KEY") {
144            let provider_name = config.active_provider.clone();
145            let provider_config = config.providers.entry(provider_name).or_default();
146            provider_config.api_key = api_key;
147        }
148
149        // 读取基础 URL
150        if let Ok(base_url) = std::env::var("CONTINUUM_BASE_URL") {
151            let provider_name = config.active_provider.clone();
152            let provider_config = config.providers.entry(provider_name).or_default();
153            provider_config.base_url = base_url;
154        }
155
156        // 读取模型
157        if let Ok(model) = std::env::var("CONTINUUM_MODEL") {
158            let provider_name = config.active_provider.clone();
159            let provider_config = config.providers.entry(provider_name).or_default();
160            provider_config.model = model;
161        }
162
163        // 读取检查点配置
164        if let Ok(val) = std::env::var("CONTINUUM_CHECKPOINT_ENABLED") {
165            if let Ok(enabled) = val.parse::<bool>() {
166                config.settings.checkpoint_enabled = enabled;
167            }
168        }
169
170        if let Ok(val) = std::env::var("CONTINUUM_AUDIT_ENABLED") {
171            if let Ok(enabled) = val.parse::<bool>() {
172                config.settings.audit_enabled = enabled;
173            }
174        }
175
176        config
177    }
178
179    /// 从文件加载配置
180    pub async fn load_from_file(&mut self, path: &Path) -> Result<()> {
181        if !path.exists() {
182            return Ok(());
183        }
184
185        let content = tokio::fs::read_to_string(path).await?;
186        let loaded: ConfigManager = toml::from_str(&content)?;
187
188        // 合并配置
189        self.merge(loaded);
190        Ok(())
191    }
192
193    /// 从文件同步加载(用于非 async 环境)
194    pub fn load_from_file_sync(&mut self, path: &Path) -> Result<()> {
195        if !path.exists() {
196            return Ok(());
197        }
198
199        let content = std::fs::read_to_string(path)?;
200        let loaded: ConfigManager = toml::from_str(&content)?;
201        self.merge(loaded);
202        Ok(())
203    }
204
205    /// 合并配置 (优先级: env > file > default,所以 file 合入 self)
206    pub fn merge(&mut self, other: ConfigManager) {
207        // 合并提供商配置
208        for (name, provider) in other.providers {
209            // 只合并有 API 密钥的提供商
210            if !provider.api_key.is_empty() {
211                self.providers.insert(name, provider);
212            }
213        }
214
215        // 合并设置(保留已从环境变量读取的值)
216        if other.settings.session_max_history > 0 {
217            self.settings.session_max_history = other.settings.session_max_history;
218        }
219        if other.settings.checkpoint_interval_sec > 0 {
220            self.settings.checkpoint_interval_sec = other.settings.checkpoint_interval_sec;
221        }
222
223        // 合并其他配置
224        self.extra.extend(other.extra);
225
226        // 设置活跃提供商(如果指定的提供商存在)
227        if !other.active_provider.is_empty() && self.providers.contains_key(&other.active_provider)
228        {
229            self.active_provider = other.active_provider;
230        }
231    }
232
233    /// 获取默认配置路径
234    pub fn default_config_path() -> PathBuf {
235        // 用户级配置: ~/.continuum/config.toml
236        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
237        home.join(".continuum").join("config.toml")
238    }
239
240    /// 获取项目级配置路径
241    pub fn project_config_path() -> PathBuf {
242        PathBuf::from(".continuum").join("config.toml")
243    }
244
245    /// 加载完整配置(环境变量 + 项目级 + 用户级)
246    pub async fn load_full() -> Result<Self> {
247        // 1. 从默认配置开始
248        let mut config = Self::new();
249
250        // 2. 加载用户级配置
251        let user_path = Self::default_config_path();
252        config.load_from_file(&user_path).await?;
253
254        // 3. 加载项目级配置(覆盖用户级)
255        let project_path = Self::project_config_path();
256        config.load_from_file(&project_path).await?;
257
258        // 4. 从环境变量加载(最高优先级,覆盖文件配置)
259        let env_config = Self::from_env();
260        config.merge_env(env_config);
261
262        Ok(config)
263    }
264
265    /// 合并环境变量配置(最高优先级)
266    fn merge_env(&mut self, env: ConfigManager) {
267        // 环境变量配置优先级最高,直接覆盖
268        if !env.active_provider.is_empty() {
269            self.active_provider = env.active_provider;
270        }
271
272        // 合并提供商(环境变量的提供商直接覆盖)
273        for (name, provider) in env.providers {
274            self.providers.insert(name, provider);
275        }
276
277        // 合并设置
278        self.settings.audit_enabled = env.settings.audit_enabled;
279        self.settings.checkpoint_enabled = env.settings.checkpoint_enabled;
280    }
281
282    /// 切换提供商
283    pub fn use_provider(&mut self, name: &str) -> Result<()> {
284        if !self.providers.contains_key(name) {
285            return Err(anyhow!(
286                "Provider '{}' not found. Use 'config add-provider' first.",
287                name
288            ));
289        }
290        self.active_provider = name.to_string();
291        Ok(())
292    }
293
294    /// 获取当前提供商配置
295    pub fn current(&self) -> Result<&ProviderConfig> {
296        self.providers
297            .get(&self.active_provider)
298            .ok_or_else(|| anyhow!("No provider '{}' configured", self.active_provider))
299    }
300
301    /// 添加提供商
302    pub fn add_provider(&mut self, name: &str, config: ProviderConfig) {
303        self.providers.insert(name.to_string(), config);
304    }
305
306    /// 列出所有提供商
307    pub fn list_providers(&self) -> Vec<&String> {
308        self.providers.keys().collect()
309    }
310
311    /// 获取配置值
312    pub fn get(&self, key: &str) -> Option<&String> {
313        self.extra.get(key)
314    }
315
316    /// 设置配置值
317    pub fn set(&mut self, key: String, value: String) {
318        self.extra.insert(key, value);
319    }
320
321    /// 保存配置到文件
322    pub async fn save(&self, path: &Path) -> Result<()> {
323        // 确保父目录存在
324        if let Some(parent) = path.parent() {
325            tokio::fs::create_dir_all(parent).await?;
326        }
327
328        let content = toml::to_string_pretty(&self)?;
329        tokio::fs::write(path, content).await?;
330        Ok(())
331    }
332
333    /// 同步保存配置到文件
334    pub fn save_sync(&self, path: &Path) -> Result<()> {
335        if let Some(parent) = path.parent() {
336            std::fs::create_dir_all(parent)?;
337        }
338
339        let content = toml::to_string_pretty(&self)?;
340        std::fs::write(path, content)?;
341        Ok(())
342    }
343
344    /// 解析环境变量引用 ${VAR_NAME}
345    pub fn resolve_env_refs(&mut self) {
346        // 解析提供商配置中的环境变量引用
347        for provider in self.providers.values_mut() {
348            provider.api_key = Self::resolve_env_string(&provider.api_key);
349            provider.base_url = Self::resolve_env_string(&provider.base_url);
350            provider.model = Self::resolve_env_string(&provider.model);
351        }
352
353        // 解析其他配置中的环境变量引用
354        for value in self.extra.values_mut() {
355            *value = Self::resolve_env_string(value);
356        }
357    }
358
359    /// 解析单个字符串中的环境变量引用
360    fn resolve_env_string(s: &str) -> String {
361        let mut result = s.to_string();
362        // 查找 ${VAR_NAME} 并替换
363        while let Some(start) = result.find("${") {
364            if let Some(end) = result[start..].find('}') {
365                let var_name = &result[start + 2..start + end];
366                if let Ok(val) = std::env::var(var_name) {
367                    result.replace_range(start..start + end + 1, &val);
368                } else {
369                    // 环境变量不存在,移除引用标记
370                    result.replace_range(start..start + end + 1, "");
371                }
372            } else {
373                break;
374            }
375        }
376        result
377    }
378
379    /// 初始化默认配置文件
380    pub fn init_default_config(&self) -> Result<PathBuf> {
381        let path = Self::default_config_path();
382
383        if path.exists() {
384            return Err(anyhow!("Config file already exists at {:?}", path));
385        }
386
387        // 创建默认配置
388        let default_config = Self {
389            active_provider: "anthropic".to_string(),
390            providers: {
391                let mut map = HashMap::new();
392                map.insert(
393                    "anthropic".to_string(),
394                    ProviderConfig {
395                        api_key: "${ANTHROPIC_API_KEY}".to_string(),
396                        base_url: "https://api.anthropic.com/v1".to_string(),
397                        model: "claude-sonnet-4-6".to_string(),
398                        default_max_tokens: 4096,
399                        default_temperature: 0.7,
400                    },
401                );
402                map.insert(
403                    "openai".to_string(),
404                    ProviderConfig {
405                        api_key: "${OPENAI_API_KEY}".to_string(),
406                        base_url: "https://api.openai.com/v1".to_string(),
407                        model: "gpt-4".to_string(),
408                        default_max_tokens: 4096,
409                        default_temperature: 0.7,
410                    },
411                );
412                map.insert(
413                    "gemini".to_string(),
414                    ProviderConfig {
415                        api_key: "${GEMINI_API_KEY}".to_string(),
416                        base_url: "https://generativelanguage.googleapis.com/v1".to_string(),
417                        model: "gemini-pro".to_string(),
418                        default_max_tokens: 4096,
419                        default_temperature: 0.7,
420                    },
421                );
422                map
423            },
424            settings: GlobalSettings::default(),
425            extra: HashMap::new(),
426        };
427
428        default_config.save_sync(&path)?;
429        Ok(path)
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider entry from its five fields.
    fn sample_provider(
        api_key: &str,
        base_url: &str,
        model: &str,
        max_tokens: u32,
        temperature: f32,
    ) -> ProviderConfig {
        ProviderConfig {
            api_key: api_key.to_string(),
            base_url: base_url.to_string(),
            model: model.to_string(),
            default_max_tokens: max_tokens,
            default_temperature: temperature,
        }
    }

    #[test]
    fn test_config_manager_creation() {
        assert_eq!(ConfigManager::new().active_provider, "anthropic");
    }

    #[test]
    fn test_provider_config_default() {
        let ProviderConfig {
            default_max_tokens,
            default_temperature,
            ..
        } = ProviderConfig::default();
        assert_eq!(default_max_tokens, 4096);
        assert_eq!(default_temperature, 0.7);
    }

    #[test]
    fn test_global_settings_default() {
        let defaults = GlobalSettings::default();
        assert!(defaults.session_auto_save);
        assert!(defaults.checkpoint_enabled);
    }

    #[test]
    fn test_add_provider() {
        let mut config = ConfigManager::new();
        config.add_provider(
            "test",
            sample_provider("test_key", "https://test.api.com", "test-model", 8192, 0.5),
        );
        assert!(config.providers.contains_key("test"));
    }

    #[test]
    fn test_use_provider() {
        let mut config = ConfigManager::new();
        config.add_provider(
            "test",
            sample_provider("test_key", "https://test.api.com", "test-model", 4096, 0.7),
        );
        config.use_provider("test").unwrap();
        assert_eq!(config.active_provider, "test");
    }

    #[test]
    fn test_use_provider_not_found() {
        let mut config = ConfigManager::new();
        assert!(config.use_provider("nonexistent").is_err());
    }

    #[test]
    fn test_resolve_env_string() {
        std::env::set_var("TEST_VAR", "test_value");
        let resolved = ConfigManager::resolve_env_string("${TEST_VAR}");
        std::env::remove_var("TEST_VAR");
        assert_eq!(resolved, "test_value");
    }

    #[test]
    fn test_set_get_config() {
        let mut config = ConfigManager::new();
        config.set("test_key".to_string(), "test_value".to_string());
        assert_eq!(
            config.get("test_key").map(String::as_str),
            Some("test_value")
        );
    }

    #[test]
    fn test_list_providers() {
        let mut config = ConfigManager::new();
        config.add_provider(
            "provider1",
            sample_provider("key1", "url1", "model1", 4096, 0.7),
        );
        let names = config.list_providers();
        assert!(names.iter().any(|name| name.as_str() == "provider1"));
    }

    #[test]
    fn test_config_serialization() {
        let serialized =
            toml::to_string(&ConfigManager::new()).expect("default config must serialize");
        assert!(serialized.contains("active_provider"));
    }
}