1use super::{IntegrationConfig, IntegrationError, MemoryStrategy, PrecisionLevel};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::Path;
11#[allow(unused_imports)]
12pub struct ConfigManager {
14 config: GlobalConfig,
16 sources: Vec<ConfigSource>,
18 cache: HashMap<String, ConfigValue>,
20 watch_enabled: bool,
22}
23
24impl ConfigManager {
25 pub fn new() -> Self {
27 Self {
28 config: GlobalConfig::default(),
29 sources: Vec::new(),
30 cache: HashMap::new(),
31 watch_enabled: false,
32 }
33 }
34
35 pub fn load_from_file<P: AsRef<Path>>(&mut self, path: P) -> Result<(), IntegrationError> {
37 let content = std::fs::read_to_string(path.as_ref()).map_err(|e| {
38 IntegrationError::ConfigMismatch(format!("Failed to read config file: {e}"))
39 })?;
40
41 let file_config: FileConfig = toml::from_str(&content).map_err(|e| {
42 IntegrationError::ConfigMismatch(format!("Failed to parse _config file: {e}"))
43 })?;
44
45 self.merge_file_config(file_config)?;
47
48 self.sources
50 .push(ConfigSource::File(path.as_ref().to_path_buf()));
51
52 Ok(())
53 }
54
55 pub fn load_from_env(&mut self) -> Result<(), IntegrationError> {
57 let mut env_config = HashMap::new();
58
59 for (key, value) in std::env::vars() {
61 if key.starts_with("SCIRS2_") {
62 let config_key = key
63 .strip_prefix("SCIRS2_")
64 .expect("Operation failed")
65 .to_lowercase();
66 env_config.insert(config_key, value);
67 }
68 }
69
70 self.merge_env_config(env_config)?;
71 self.sources.push(ConfigSource::Environment);
72
73 Ok(())
74 }
75
76 pub fn set<T: Into<ConfigValue>>(&mut self, key: &str, value: T) {
78 self.cache.insert(key.to_string(), value.into());
79 self.apply_cached_values();
80 }
81
82 pub fn get(&self, key: &str) -> Option<&ConfigValue> {
84 self.cache.get(key)
85 }
86
87 pub fn get_or_default<T: From<ConfigValue> + Default>(&self, key: &str) -> T {
89 self.cache
90 .get(key)
91 .cloned()
92 .map(T::from)
93 .unwrap_or_default()
94 }
95
96 pub fn integration_config(&self) -> &IntegrationConfig {
98 &self.config.integration
99 }
100
101 pub fn update_integration_config(&mut self, config: IntegrationConfig) {
103 self.config.integration = config;
104 }
105
106 pub fn module_config(&self, module_name: &str) -> Option<&ModuleConfig> {
108 self.config.modules.get(module_name)
109 }
110
111 pub fn set_module_config(&mut self, module_name: String, config: ModuleConfig) {
113 self.config.modules.insert(module_name, config);
114 }
115
116 pub fn enable_watch(&mut self) -> Result<(), IntegrationError> {
118 self.watch_enabled = true;
119 Ok(())
121 }
122
123 pub fn validate(&self) -> Result<(), IntegrationError> {
125 if self.config.integration.strict_compatibility {
127 for (module_name, module_config) in &self.config.modules {
129 if !module_config.enabled {
130 continue;
131 }
132
133 if let Some(required_version) = &module_config.required_version {
135 self.validate_module_version(module_name, required_version)?;
137 }
138 }
139 }
140
141 self.check_conflicting_settings()?;
143
144 Ok(())
145 }
146
147 pub fn export_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), IntegrationError> {
149 let file_config = self.to_file_config();
150 let content = toml::to_string_pretty(&file_config).map_err(|e| {
151 IntegrationError::ConfigMismatch(format!("Failed to serialize config: {e}"))
152 })?;
153
154 std::fs::write(path.as_ref(), content).map_err(|e| {
155 IntegrationError::ConfigMismatch(format!("Failed to write config file: {e}"))
156 })?;
157
158 Ok(())
159 }
160
161 pub fn reset_to_defaults(&mut self) {
163 self.config = GlobalConfig::default();
164 self.cache.clear();
165 }
166
167 pub fn summary(&self) -> ConfigSummary {
169 ConfigSummary {
170 total_modules: self.config.modules.len(),
171 enabled_modules: self.config.modules.values().filter(|m| m.enabled).count(),
172 precision_level: self.config.integration.default_precision,
173 memory_strategy: self.config.integration.memory_strategy,
174 sources: self.sources.clone(),
175 cache_size: self.cache.len(),
176 }
177 }
178
179 fn merge_file_config(&mut self, file_config: FileConfig) -> Result<(), IntegrationError> {
181 if let Some(integration) = file_config.integration {
183 self.config.integration = self.merge_integration_config(integration)?;
184 }
185
186 for (name, module_config) in file_config.modules.unwrap_or_default() {
188 self.config.modules.insert(name, module_config);
189 }
190
191 Ok(())
192 }
193
194 fn merge_env_config(
195 &mut self,
196 env_config: HashMap<String, String>,
197 ) -> Result<(), IntegrationError> {
198 for (key, value) in env_config {
199 match key.as_str() {
200 "auto_convert_tensors" => {
201 let val = value.parse::<bool>().map_err(|_| {
202 IntegrationError::ConfigMismatch(format!(
203 "Invalid boolean value for {key}: {value}"
204 ))
205 })?;
206 self.config.integration.auto_convert_tensors = val;
207 }
208 "strict_compatibility" => {
209 let val = value.parse::<bool>().map_err(|_| {
210 IntegrationError::ConfigMismatch(format!(
211 "Invalid boolean value for {key}: {value}"
212 ))
213 })?;
214 self.config.integration.strict_compatibility = val;
215 }
216 "default_precision" => {
217 self.config.integration.default_precision = match value.as_str() {
218 "float32" => PrecisionLevel::Float32,
219 "float64" => PrecisionLevel::Float64,
220 "mixed" => PrecisionLevel::Mixed,
221 _ => {
222 return Err(IntegrationError::ConfigMismatch(format!(
223 "Invalid precision level: {value}"
224 )))
225 }
226 };
227 }
228 "memory_strategy" => {
229 self.config.integration.memory_strategy = match value.as_str() {
230 "shared" => MemoryStrategy::Shared,
231 "copy" => MemoryStrategy::Copy,
232 "memory_mapped" => MemoryStrategy::MemoryMapped,
233 _ => {
234 return Err(IntegrationError::ConfigMismatch(format!(
235 "Invalid memory strategy: {value}"
236 )))
237 }
238 };
239 }
240 _ => {
241 self.cache.insert(key, ConfigValue::String(value));
243 }
244 }
245 }
246
247 Ok(())
248 }
249
250 fn merge_integration_config(
251 &self,
252 file_integration: FileIntegrationConfig,
253 ) -> Result<IntegrationConfig, IntegrationError> {
254 let mut config = self.config.integration.clone();
255
256 if let Some(val) = file_integration.auto_convert_tensors {
257 config.auto_convert_tensors = val;
258 }
259
260 if let Some(val) = file_integration.strict_compatibility {
261 config.strict_compatibility = val;
262 }
263
264 if let Some(precision) = file_integration.default_precision {
265 config.default_precision = match precision.as_str() {
266 "float32" => PrecisionLevel::Float32,
267 "float64" => PrecisionLevel::Float64,
268 "mixed" => PrecisionLevel::Mixed,
269 _ => {
270 return Err(IntegrationError::ConfigMismatch(format!(
271 "Invalid precision level: {precision}"
272 )))
273 }
274 };
275 }
276
277 if let Some(strategy) = file_integration.memory_strategy {
278 config.memory_strategy = match strategy.as_str() {
279 "shared" => MemoryStrategy::Shared,
280 "copy" => MemoryStrategy::Copy,
281 "memory_mapped" => MemoryStrategy::MemoryMapped,
282 _ => {
283 return Err(IntegrationError::ConfigMismatch(format!(
284 "Invalid memory strategy: {strategy}"
285 )))
286 }
287 };
288 }
289
290 Ok(config)
291 }
292
293 fn apply_cached_values(&mut self) {
294 for (key, value) in &self.cache {
296 match key.as_str() {
297 "auto_convert_tensors" => {
298 if let ConfigValue::Bool(val) = value {
299 self.config.integration.auto_convert_tensors = *val;
300 }
301 }
302 "strict_compatibility" => {
303 if let ConfigValue::Bool(val) = value {
304 self.config.integration.strict_compatibility = *val;
305 }
306 }
307 _ => {} }
309 }
310 }
311
312 fn validate_module_version(
313 &self,
314 _module_name: &str,
315 _version: &str,
316 ) -> Result<(), IntegrationError> {
317 Ok(())
319 }
320
321 fn check_conflicting_settings(&self) -> Result<(), IntegrationError> {
322 if self.config.integration.memory_strategy == MemoryStrategy::Shared
324 && !self.config.integration.auto_convert_tensors
325 {
326 return Err(IntegrationError::ConfigMismatch(
327 "Shared memory strategy requires auto tensor conversion".to_string(),
328 ));
329 }
330
331 Ok(())
332 }
333
334 fn to_file_config(&self) -> FileConfig {
335 let integration = FileIntegrationConfig {
336 auto_convert_tensors: Some(self.config.integration.auto_convert_tensors),
337 strict_compatibility: Some(self.config.integration.strict_compatibility),
338 default_precision: Some(
339 format!("{:?}", self.config.integration.default_precision).to_lowercase(),
340 ),
341 memory_strategy: Some(
342 format!("{:?}", self.config.integration.memory_strategy).to_lowercase(),
343 ),
344 error_mode: Some(format!("{:?}", self.config.integration.error_mode).to_lowercase()),
345 };
346
347 FileConfig {
348 integration: Some(integration),
349 modules: Some(self.config.modules.clone()),
350 }
351 }
352}
353
354impl Default for ConfigManager {
355 fn default() -> Self {
356 Self::new()
357 }
358}
359
360#[derive(Debug, Clone)]
362pub struct GlobalConfig {
363 pub integration: IntegrationConfig,
365 pub modules: HashMap<String, ModuleConfig>,
367 pub performance: PerformanceConfig,
369 pub logging: LoggingConfig,
371}
372
373impl Default for GlobalConfig {
374 fn default() -> Self {
375 let mut modules = HashMap::new();
376
377 modules.insert("scirs2-neural".to_string(), ModuleConfig::default_neural());
379 modules.insert("scirs2-optim".to_string(), ModuleConfig::default_optim());
380 modules.insert("scirs2-linalg".to_string(), ModuleConfig::default_linalg());
381
382 Self {
383 integration: IntegrationConfig::default(),
384 modules,
385 performance: PerformanceConfig::default(),
386 logging: LoggingConfig::default(),
387 }
388 }
389}
390
391#[derive(Debug, Clone, Serialize, Deserialize)]
393pub struct ModuleConfig {
394 pub enabled: bool,
396 pub required_version: Option<String>,
398 pub settings: HashMap<String, ConfigValue>,
400 pub features: Vec<String>,
402 pub resource_limits: ResourceLimits,
404}
405
406impl ModuleConfig {
407 pub fn new() -> Self {
409 Self {
410 enabled: true,
411 required_version: None,
412 settings: HashMap::new(),
413 features: Vec::new(),
414 resource_limits: ResourceLimits::default(),
415 }
416 }
417
418 pub fn default_neural() -> Self {
420 let mut config = Self::new();
421 config.features = vec![
422 "automatic_differentiation".to_string(),
423 "gradient_checkpointing".to_string(),
424 "mixed_precision".to_string(),
425 ];
426 config
427 .settings
428 .insert("batch_size".to_string(), ConfigValue::Int(32));
429 config
430 .settings
431 .insert("learning_rate".to_string(), ConfigValue::Float(0.001));
432 config
433 }
434
435 pub fn default_optim() -> Self {
437 let mut config = Self::new();
438 config.features = vec![
439 "adaptive_optimizers".to_string(),
440 "learning_rate_scheduling".to_string(),
441 "gradient_clipping".to_string(),
442 ];
443 config.settings.insert(
444 "default_optimizer".to_string(),
445 ConfigValue::String("adam".to_string()),
446 );
447 config
448 .settings
449 .insert("weight_decay".to_string(), ConfigValue::Float(1e-4));
450 config
451 }
452
453 pub fn default_linalg() -> Self {
455 let mut config = Self::new();
456 config.features = vec![
457 "blas_acceleration".to_string(),
458 "gpu_support".to_string(),
459 "numerical_stability".to_string(),
460 ];
461 config
462 .settings
463 .insert("use_blas".to_string(), ConfigValue::Bool(true));
464 config
465 .settings
466 .insert("pivot_threshold".to_string(), ConfigValue::Float(1e-3));
467 config
468 }
469}
470
471impl Default for ModuleConfig {
472 fn default() -> Self {
473 Self::new()
474 }
475}
476
477#[derive(Debug, Clone, Default, Serialize, Deserialize)]
479pub struct ResourceLimits {
480 pub max_memory: Option<usize>,
482 pub max_threads: Option<usize>,
484 pub max_compute_time: Option<f64>,
486 pub max_gpu_memory: Option<usize>,
488}
489
/// Performance tuning options.
#[derive(Clone, Debug)]
pub struct PerformanceConfig {
    /// Whether SIMD code paths may be used.
    pub enable_simd: bool,
    /// Explicit thread count; `None` leaves the choice unset.
    pub num_threads: Option<usize>,
    /// Cache size budget (units not specified here — presumably bytes).
    pub cache_size: usize,
    /// Whether GPU acceleration may be used.
    pub enable_gpu: bool,
}
502
503impl Default for PerformanceConfig {
504 fn default() -> Self {
505 Self {
506 enable_simd: true,
507 num_threads: None, cache_size: 1024 * 1024, enable_gpu: false, }
511 }
512}
513
514#[derive(Debug, Clone)]
516pub struct LoggingConfig {
517 pub level: LogLevel,
519 pub module_logging: bool,
521 pub log_file: Option<String>,
523 pub performance_logging: bool,
525}
526
527impl Default for LoggingConfig {
528 fn default() -> Self {
529 Self {
530 level: LogLevel::Info,
531 module_logging: true,
532 log_file: None,
533 performance_logging: false,
534 }
535 }
536}
537
/// Log levels, listed from `Error` to `Trace`.
///
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so levels can be used
/// as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LogLevel {
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}
547
548#[derive(Debug, Clone, Serialize, Deserialize)]
550pub enum ConfigValue {
551 Bool(bool),
552 Int(i64),
553 Float(f64),
554 String(String),
555 Array(Vec<ConfigValue>),
556 Object(HashMap<String, ConfigValue>),
557}
558
559impl From<bool> for ConfigValue {
560 fn from(value: bool) -> Self {
561 ConfigValue::Bool(value)
562 }
563}
564
565impl From<i64> for ConfigValue {
566 fn from(value: i64) -> Self {
567 ConfigValue::Int(value)
568 }
569}
570
571impl From<f64> for ConfigValue {
572 fn from(value: f64) -> Self {
573 ConfigValue::Float(value)
574 }
575}
576
577impl From<String> for ConfigValue {
578 fn from(value: String) -> Self {
579 ConfigValue::String(value)
580 }
581}
582
583impl From<&str> for ConfigValue {
584 fn from(value: &str) -> Self {
585 ConfigValue::String(value.to_string())
586 }
587}
588
589impl Default for ConfigValue {
590 fn default() -> Self {
591 ConfigValue::String(String::new())
592 }
593}
594
/// Where a piece of configuration came from.
///
/// Derives `PartialEq`/`Eq` so callers can check, e.g., whether
/// `ConfigSummary::sources` contains a particular file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigSource {
    /// A TOML file at this path.
    File(std::path::PathBuf),
    /// `SCIRS2_*` environment variables.
    Environment,
    /// Values set programmatically at runtime.
    Runtime,
    /// Built-in defaults.
    Default,
}
603
604#[derive(Debug, Clone)]
606pub struct ConfigSummary {
607 pub total_modules: usize,
608 pub enabled_modules: usize,
609 pub precision_level: PrecisionLevel,
610 pub memory_strategy: MemoryStrategy,
611 pub sources: Vec<ConfigSource>,
612 pub cache_size: usize,
613}
614
615#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct FileConfig {
618 pub integration: Option<FileIntegrationConfig>,
619 pub modules: Option<HashMap<String, ModuleConfig>>,
620}
621
622#[derive(Debug, Clone, Serialize, Deserialize)]
624pub struct FileIntegrationConfig {
625 pub auto_convert_tensors: Option<bool>,
626 pub strict_compatibility: Option<bool>,
627 pub default_precision: Option<String>,
628 pub memory_strategy: Option<String>,
629 pub error_mode: Option<String>,
630}
631
632static GLOBAL_CONFIG_MANAGER: std::sync::OnceLock<std::sync::Mutex<ConfigManager>> =
634 std::sync::OnceLock::new();
635
636#[allow(dead_code)]
638pub fn init_config_manager() -> &'static std::sync::Mutex<ConfigManager> {
639 GLOBAL_CONFIG_MANAGER.get_or_init(|| {
640 let mut manager = ConfigManager::new();
641
642 let _ = manager.load_from_env();
644
645 if let Ok(config_path) = std::env::var("SCIRS2_CONFIG_PATH") {
647 let _ = manager.load_from_file(config_path);
648 }
649
650 std::sync::Mutex::new(manager)
651 })
652}
653
654#[allow(dead_code)]
656pub fn get_config_value(key: &str) -> Result<Option<ConfigValue>, IntegrationError> {
657 let manager = init_config_manager();
658 let manager_guard = manager.lock().map_err(|_| {
659 IntegrationError::ConfigMismatch("Failed to acquire config lock".to_string())
660 })?;
661 Ok(manager_guard.get(key).cloned())
662}
663
664#[allow(dead_code)]
666pub fn set_config_value<T: Into<ConfigValue>>(key: &str, value: T) -> Result<(), IntegrationError> {
667 let manager = init_config_manager();
668 let mut manager_guard = manager.lock().map_err(|_| {
669 IntegrationError::ConfigMismatch("Failed to acquire config lock".to_string())
670 })?;
671 manager_guard.set(key, value);
672 Ok(())
673}
674
675#[allow(dead_code)]
677pub fn get_module_config(modulename: &str) -> Result<Option<ModuleConfig>, IntegrationError> {
678 let manager = init_config_manager();
679 let manager_guard = manager.lock().map_err(|_| {
680 IntegrationError::ConfigMismatch("Failed to acquire config lock".to_string())
681 })?;
682 Ok(manager_guard.module_config(modulename).cloned())
683}
684
685#[allow(dead_code)]
687pub fn update_integration_config(config: IntegrationConfig) -> Result<(), IntegrationError> {
688 let manager = init_config_manager();
689 let mut manager_guard = manager.lock().map_err(|_| {
690 IntegrationError::ConfigMismatch("Failed to acquire _config lock".to_string())
691 })?;
692 manager_guard.update_integration_config(config);
693 Ok(())
694}
695
696#[allow(dead_code)]
698pub fn load_config_from_file<P: AsRef<Path>>(path: P) -> Result<(), IntegrationError> {
699 let manager = init_config_manager();
700 let mut manager_guard = manager.lock().map_err(|_| {
701 IntegrationError::ConfigMismatch("Failed to acquire config lock".to_string())
702 })?;
703 manager_guard.load_from_file(path)
704}
705
706#[allow(dead_code)]
708pub fn export_config_to_file<P: AsRef<Path>>(path: P) -> Result<(), IntegrationError> {
709 let manager = init_config_manager();
710 let manager_guard = manager.lock().map_err(|_| {
711 IntegrationError::ConfigMismatch("Failed to acquire config lock".to_string())
712 })?;
713 manager_guard.export_to_file(path)
714}
715
716#[allow(dead_code)]
718pub fn get_config_summary() -> Result<ConfigSummary, IntegrationError> {
719 let manager = init_config_manager();
720 let manager_guard = manager.lock().map_err(|_| {
721 IntegrationError::ConfigMismatch("Failed to acquire config lock".to_string())
722 })?;
723 Ok(manager_guard.summary())
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_config_manager_creation() {
        let manager = ConfigManager::new();
        // GlobalConfig::default() registers the neural, optim and linalg profiles.
        assert_eq!(manager.config.modules.len(), 3);
    }

    #[test]
    fn test_config_value_types() {
        let bool_val = ConfigValue::Bool(true);
        let int_val = ConfigValue::Int(42);
        let float_val = ConfigValue::Float(std::f64::consts::PI);
        let string_val = ConfigValue::String("test".to_string());

        assert!(matches!(bool_val, ConfigValue::Bool(true)));
        assert!(matches!(int_val, ConfigValue::Int(42)));
        assert!(
            matches!(float_val, ConfigValue::Float(f) if (f - std::f64::consts::PI).abs() < 1e-10)
        );
        assert!(matches!(string_val, ConfigValue::String(ref s) if s == "test"));
    }

    #[test]
    fn test_module_config() {
        let neural_config = ModuleConfig::default_neural();
        assert!(neural_config.enabled);
        assert!(neural_config
            .features
            .contains(&"automatic_differentiation".to_string()));
        assert!(neural_config.settings.contains_key("batch_size"));
    }

    #[test]
    fn test_config_value_conversions() {
        let bool_val: ConfigValue = true.into();
        let int_val: ConfigValue = 42i64.into();
        let float_val: ConfigValue = std::f64::consts::PI.into();
        let string_val: ConfigValue = "test".into();

        assert!(matches!(bool_val, ConfigValue::Bool(true)));
        assert!(matches!(int_val, ConfigValue::Int(42)));
        assert!(
            matches!(float_val, ConfigValue::Float(f) if (f - std::f64::consts::PI).abs() < 1e-10)
        );
        assert!(matches!(string_val, ConfigValue::String(ref s) if s == "test"));
    }

    #[test]
    fn test_resource_limits() {
        let limits = ResourceLimits {
            max_memory: Some(1024 * 1024 * 1024),
            max_threads: Some(8),
            max_compute_time: Some(60.0),
            max_gpu_memory: Some(512 * 1024 * 1024),
        };

        assert_eq!(limits.max_memory, Some(1024 * 1024 * 1024));
        assert_eq!(limits.max_threads, Some(8));
        assert_eq!(limits.max_compute_time, Some(60.0));
        assert_eq!(limits.max_gpu_memory, Some(512 * 1024 * 1024));
    }

    #[test]
    fn test_global_config_default() {
        let config = GlobalConfig::default();
        assert!(config.modules.contains_key("scirs2-neural"));
        assert!(config.modules.contains_key("scirs2-optim"));
        assert!(config.modules.contains_key("scirs2-linalg"));
        assert!(config.performance.enable_simd);
    }

    #[test]
    fn test_config_manager_set_get() {
        let mut manager = ConfigManager::new();

        manager.set("test_key", ConfigValue::String("test_value".to_string()));
        let retrieved = manager.get("test_key");

        assert!(retrieved.is_some());
        if let Some(ConfigValue::String(val)) = retrieved {
            assert_eq!(val, "test_value");
        } else {
            panic!("Expected string value");
        }
    }

    #[test]
    fn test_get_or_default_missing_key() {
        let manager = ConfigManager::new();
        let value: ConfigValue = manager.get_or_default("missing");
        assert!(matches!(value, ConfigValue::String(ref s) if s.is_empty()));
    }

    #[test]
    fn test_env_config_merge() {
        let mut manager = ConfigManager::new();
        let mut env_vars = HashMap::new();
        env_vars.insert("auto_convert_tensors".to_string(), "false".to_string());
        env_vars.insert("strict_compatibility".to_string(), "true".to_string());

        manager
            .merge_env_config(env_vars)
            .expect("valid env values must merge");

        assert!(!manager.config.integration.auto_convert_tensors);
        assert!(manager.config.integration.strict_compatibility);
    }

    #[test]
    fn test_env_config_rejects_invalid_boolean() {
        let mut manager = ConfigManager::new();
        let env_vars = HashMap::from([("strict_compatibility".to_string(), "yes".to_string())]);

        let err = manager.merge_env_config(env_vars).unwrap_err();
        assert!(matches!(err, IntegrationError::ConfigMismatch(_)));
    }

    #[test]
    fn test_env_config_rejects_unknown_memory_strategy() {
        let mut manager = ConfigManager::new();
        let env_vars = HashMap::from([("memory_strategy".to_string(), "bogus".to_string())]);

        let err = manager.merge_env_config(env_vars).unwrap_err();
        assert!(
            matches!(err, IntegrationError::ConfigMismatch(ref msg) if msg.contains("memory strategy"))
        );
    }

    #[test]
    fn test_env_config_caches_unknown_keys() {
        let mut manager = ConfigManager::new();
        let env_vars = HashMap::from([("custom_flag".to_string(), "on".to_string())]);

        manager
            .merge_env_config(env_vars)
            .expect("unknown keys must not fail");
        assert!(matches!(manager.get("custom_flag"), Some(ConfigValue::String(s)) if s == "on"));
    }

    #[test]
    fn test_module_config_and_summary() {
        let mut manager = ConfigManager::new();
        let mut custom = ModuleConfig::new();
        custom.enabled = false;
        manager.set_module_config("custom".to_string(), custom);

        assert!(
            !manager
                .module_config("custom")
                .expect("module was just inserted")
                .enabled
        );
        let summary = manager.summary();
        assert_eq!(summary.total_modules, 4);
        assert_eq!(summary.enabled_modules, 3);
    }

    #[test]
    fn test_reset_to_defaults() {
        let mut manager = ConfigManager::new();
        manager.set_module_config("extra".to_string(), ModuleConfig::new());
        manager.set("custom", 1i64);

        manager.reset_to_defaults();

        assert_eq!(manager.config.modules.len(), 3);
        assert!(manager.get("custom").is_none());
    }

    #[test]
    fn test_config_validation() {
        let manager = ConfigManager::new();

        assert!(manager.validate().is_ok());
    }
}
838}