// sklears_model_selection/config_management.rs

//! Configuration Management System for Model Selection
//!
//! This module provides configuration management for model selection
//! operations: JSON serialization (YAML file extensions are accepted, but their
//! content is currently read and written as JSON), configuration inheritance
//! from named templates via dot-notation overrides, and validation.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use thiserror::Error;
11
/// Configuration management error types.
///
/// Every fallible [`ConfigManager`] operation returns this error type.
#[derive(Error, Debug)]
pub enum ConfigError {
    /// Reading or writing a configuration file failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON (de)serialization failed; converted automatically via `?`.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Content declared as YAML could not be parsed. NOTE: YAML content is
    /// currently parsed with the JSON parser, so the message is a JSON parse error.
    #[error("YAML error: {0}")]
    Yaml(String),
    /// The configuration, an override value, or the requested format was rejected.
    #[error("Configuration validation error: {0}")]
    Validation(String),
    /// A template operation failed (e.g. the named template is not registered).
    #[error("Template error: {0}")]
    Template(String),
}
26
/// Main configuration structure for model selection operations.
///
/// When deserializing, the four sections are required; only `custom` may be
/// omitted (it then defaults to an empty map).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSelectionConfig {
    /// Cross-validation configuration
    pub cross_validation: CrossValidationConfig,
    /// Hyperparameter optimization configuration
    pub optimization: OptimizationConfig,
    /// Scoring and evaluation configuration
    pub scoring: ScoringConfig,
    /// Resource and performance configuration
    pub resources: ResourceConfig,
    /// Custom parameters and extensions.
    ///
    /// Template overrides whose dotted key does not name a recognized field
    /// are stored here under the full dotted key.
    #[serde(default)]
    pub custom: HashMap<String, serde_json::Value>,
}
42
/// Cross-validation configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationConfig {
    /// Type of cross-validation (kfold, stratified, etc.); required when deserializing.
    pub method: String,
    /// Number of folds; `default_n_folds()` is used when the field is omitted.
    /// `ConfigManager` validation rejects values below 2.
    #[serde(default = "default_n_folds")]
    pub n_folds: usize,
    /// Random state for reproducibility; `None` when omitted or `null`.
    pub random_state: Option<u64>,
    /// Shuffle data before splitting; `default_shuffle()` is used when omitted.
    #[serde(default = "default_shuffle")]
    pub shuffle: bool,
    /// Additional method-specific parameters (free-form JSON, empty when omitted).
    #[serde(default)]
    pub parameters: HashMap<String, serde_json::Value>,
}
60
61/// Hyperparameter optimization configuration
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct OptimizationConfig {
64    /// Optimization method (grid_search, bayesian, evolutionary, etc.)
65    pub method: String,
66    /// Maximum number of iterations/evaluations
67    #[serde(default = "default_max_iter")]
68    pub max_iter: usize,
69    /// Early stopping configuration
70    pub early_stopping: Option<EarlyStoppingConfig>,
71    /// Parameter space definition
72    pub parameter_space: HashMap<String, ParameterDefinition>,
73    /// Optimization-specific parameters
74    #[serde(default)]
75    pub parameters: HashMap<String, serde_json::Value>,
76}
77
/// Early stopping configuration.
///
/// Note that `enabled` defaults to `false`: an `early_stopping` section that
/// omits `enabled` is present but disabled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Enable early stopping (defaults to `false` when omitted).
    #[serde(default)]
    pub enabled: bool,
    /// Patience (number of iterations without improvement); `default_patience()`
    /// is used when omitted.
    #[serde(default = "default_patience")]
    pub patience: usize,
    /// Minimum improvement required; `default_min_delta()` is used when omitted.
    #[serde(default = "default_min_delta")]
    pub min_delta: f64,
}
91
92/// Parameter definition for optimization
93#[derive(Debug, Clone, Serialize, Deserialize)]
94#[serde(tag = "type")]
95pub enum ParameterDefinition {
96    #[serde(rename = "uniform")]
97    Uniform { low: f64, high: f64 },
98    #[serde(rename = "log_uniform")]
99    LogUniform { low: f64, high: f64 },
100    #[serde(rename = "categorical")]
101    Categorical { choices: Vec<serde_json::Value> },
102    #[serde(rename = "integer")]
103    Integer { low: i64, high: i64 },
104    #[serde(rename = "choice")]
105    Choice { choices: Vec<serde_json::Value> },
106}
107
/// Scoring configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringConfig {
    /// Primary scoring metric; required when deserializing.
    pub primary: String,
    /// Additional metrics to compute (empty when omitted).
    #[serde(default)]
    pub additional: Vec<String>,
    /// Custom scoring functions, as free-form JSON (empty when omitted).
    #[serde(default)]
    pub custom_scorers: HashMap<String, serde_json::Value>,
    /// Scoring parameters, as free-form JSON (empty when omitted).
    #[serde(default)]
    pub parameters: HashMap<String, serde_json::Value>,
}
123
/// Resource and performance configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConfig {
    /// Number of parallel jobs (-1 for all cores); `default_n_jobs()` is used
    /// when omitted. `ConfigManager` validation rejects 0.
    #[serde(default = "default_n_jobs")]
    pub n_jobs: i32,
    /// Memory limit in MB; `None` when omitted or `null`.
    pub memory_limit: Option<usize>,
    /// Enable GPU acceleration (defaults to `false`).
    #[serde(default)]
    pub use_gpu: bool,
    /// Batch size for parallel processing; `default_batch_size()` is used when
    /// omitted. `ConfigManager` validation rejects 0.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    /// Enable memory optimization (defaults to `false`).
    #[serde(default)]
    pub memory_efficient: bool,
}
142
143/// Configuration manager for loading, saving, and validating configurations
144pub struct ConfigManager {
145    base_config: Option<ModelSelectionConfig>,
146    template_registry: HashMap<String, ModelSelectionConfig>,
147}
148
149impl ConfigManager {
150    /// Create a new configuration manager
151    pub fn new() -> Self {
152        Self {
153            base_config: None,
154            template_registry: HashMap::new(),
155        }
156    }
157
158    /// Load configuration from a file (supports JSON and YAML)
159    pub fn load_from_file<P: AsRef<Path>>(
160        &mut self,
161        path: P,
162    ) -> Result<ModelSelectionConfig, ConfigError> {
163        let path = path.as_ref();
164        let content = std::fs::read_to_string(path)?;
165
166        let config = match path.extension().and_then(|s| s.to_str()) {
167            Some("json") => serde_json::from_str(&content)?,
168            Some("yaml") | Some("yml") => {
169                // For now, use JSON as a placeholder since YAML requires additional dependency
170                serde_json::from_str(&content).map_err(|e| ConfigError::Yaml(e.to_string()))?
171            }
172            _ => {
173                return Err(ConfigError::Validation(
174                    "Unsupported file format".to_string(),
175                ))
176            }
177        };
178
179        self.validate_config(&config)?;
180        Ok(config)
181    }
182
183    /// Save configuration to a file
184    pub fn save_to_file<P: AsRef<Path>>(
185        &self,
186        config: &ModelSelectionConfig,
187        path: P,
188    ) -> Result<(), ConfigError> {
189        let path = path.as_ref();
190        let content = match path.extension().and_then(|s| s.to_str()) {
191            Some("json") => serde_json::to_string_pretty(config)?,
192            Some("yaml") | Some("yml") => {
193                // For now, use JSON as a placeholder
194                serde_json::to_string_pretty(config)?
195            }
196            _ => {
197                return Err(ConfigError::Validation(
198                    "Unsupported file format".to_string(),
199                ))
200            }
201        };
202
203        std::fs::write(path, content)?;
204        Ok(())
205    }
206
207    /// Load configuration from a string
208    pub fn load_from_string(
209        &mut self,
210        content: &str,
211        format: &str,
212    ) -> Result<ModelSelectionConfig, ConfigError> {
213        let config = match format {
214            "json" => serde_json::from_str(content)?,
215            "yaml" | "yml" => {
216                serde_json::from_str(content).map_err(|e| ConfigError::Yaml(e.to_string()))?
217            }
218            _ => return Err(ConfigError::Validation("Unsupported format".to_string())),
219        };
220
221        self.validate_config(&config)?;
222        Ok(config)
223    }
224
225    /// Register a configuration template
226    pub fn register_template(
227        &mut self,
228        name: String,
229        config: ModelSelectionConfig,
230    ) -> Result<(), ConfigError> {
231        self.validate_config(&config)?;
232        self.template_registry.insert(name, config);
233        Ok(())
234    }
235
236    /// Get a configuration template
237    pub fn get_template(&self, name: &str) -> Option<&ModelSelectionConfig> {
238        self.template_registry.get(name)
239    }
240
241    /// Create configuration from template with overrides
242    pub fn from_template(
243        &self,
244        template_name: &str,
245        overrides: HashMap<String, serde_json::Value>,
246    ) -> Result<ModelSelectionConfig, ConfigError> {
247        let template = self.get_template(template_name).ok_or_else(|| {
248            ConfigError::Template(format!("Template '{}' not found", template_name))
249        })?;
250
251        let mut config = template.clone();
252        self.apply_overrides(&mut config, overrides)?;
253        self.validate_config(&config)?;
254
255        Ok(config)
256    }
257
258    /// Validate configuration
259    fn validate_config(&self, config: &ModelSelectionConfig) -> Result<(), ConfigError> {
260        // Validate cross-validation configuration
261        if config.cross_validation.n_folds < 2 {
262            return Err(ConfigError::Validation(
263                "n_folds must be at least 2".to_string(),
264            ));
265        }
266
267        // Validate optimization configuration
268        if config.optimization.max_iter == 0 {
269            return Err(ConfigError::Validation(
270                "max_iter must be greater than 0".to_string(),
271            ));
272        }
273
274        // Validate resource configuration
275        if config.resources.n_jobs == 0 {
276            return Err(ConfigError::Validation("n_jobs cannot be 0".to_string()));
277        }
278
279        if config.resources.batch_size == 0 {
280            return Err(ConfigError::Validation(
281                "batch_size must be greater than 0".to_string(),
282            ));
283        }
284
285        Ok(())
286    }
287
288    /// Apply configuration overrides
289    fn apply_overrides(
290        &self,
291        config: &mut ModelSelectionConfig,
292        overrides: HashMap<String, serde_json::Value>,
293    ) -> Result<(), ConfigError> {
294        for (key, value) in overrides {
295            self.apply_override(config, &key, value)?;
296        }
297        Ok(())
298    }
299
300    /// Apply a single override using dot notation (e.g., "cross_validation.n_folds")
301    fn apply_override(
302        &self,
303        config: &mut ModelSelectionConfig,
304        key: &str,
305        value: serde_json::Value,
306    ) -> Result<(), ConfigError> {
307        let parts: Vec<&str> = key.split('.').collect();
308        match parts.as_slice() {
309            ["cross_validation", "n_folds"] => {
310                if let Some(n) = value.as_u64() {
311                    config.cross_validation.n_folds = n as usize;
312                }
313            }
314            ["cross_validation", "random_state"] => {
315                if let Some(n) = value.as_u64() {
316                    config.cross_validation.random_state = Some(n);
317                }
318            }
319            ["optimization", "max_iter"] => {
320                if let Some(n) = value.as_u64() {
321                    config.optimization.max_iter = n as usize;
322                }
323            }
324            ["resources", "n_jobs"] => {
325                if let Some(n) = value.as_i64() {
326                    config.resources.n_jobs = n as i32;
327                }
328            }
329            _ => {
330                // Store in custom parameters
331                config.custom.insert(key.to_string(), value);
332            }
333        }
334        Ok(())
335    }
336
337    /// Get default configuration templates
338    pub fn load_default_templates(&mut self) {
339        // Grid search template
340        let grid_search_config = ModelSelectionConfig {
341            cross_validation: CrossValidationConfig {
342                method: "kfold".to_string(),
343                n_folds: 5,
344                random_state: Some(42),
345                shuffle: true,
346                parameters: HashMap::new(),
347            },
348            optimization: OptimizationConfig {
349                method: "grid_search".to_string(),
350                max_iter: 100,
351                early_stopping: None,
352                parameter_space: HashMap::new(),
353                parameters: HashMap::new(),
354            },
355            scoring: ScoringConfig {
356                primary: "accuracy".to_string(),
357                additional: vec!["precision".to_string(), "recall".to_string()],
358                custom_scorers: HashMap::new(),
359                parameters: HashMap::new(),
360            },
361            resources: ResourceConfig {
362                n_jobs: -1,
363                memory_limit: None,
364                use_gpu: false,
365                batch_size: 32,
366                memory_efficient: false,
367            },
368            custom: HashMap::new(),
369        };
370
371        // Bayesian optimization template
372        let bayesian_config = ModelSelectionConfig {
373            optimization: OptimizationConfig {
374                method: "bayesian".to_string(),
375                max_iter: 50,
376                early_stopping: Some(EarlyStoppingConfig {
377                    enabled: true,
378                    patience: 5,
379                    min_delta: 0.001,
380                }),
381                parameter_space: HashMap::new(),
382                parameters: HashMap::new(),
383            },
384            ..grid_search_config.clone()
385        };
386
387        self.template_registry
388            .insert("grid_search".to_string(), grid_search_config);
389        self.template_registry
390            .insert("bayesian".to_string(), bayesian_config);
391    }
392}
393
394impl Default for ConfigManager {
395    fn default() -> Self {
396        let mut manager = Self::new();
397        manager.load_default_templates();
398        manager
399    }
400}
401
// Default value functions.
//
// `#[serde(default = "...")]` needs a function path, so each default is exposed
// through a zero-argument function that returns a named constant.

/// Default number of cross-validation folds.
const DEFAULT_N_FOLDS: usize = 5;
/// Whether data is shuffled before splitting by default.
const DEFAULT_SHUFFLE: bool = true;
/// Default cap on optimization iterations/evaluations.
const DEFAULT_MAX_ITER: usize = 100;
/// Default early-stopping patience.
const DEFAULT_PATIENCE: usize = 5;
/// Default minimum improvement for early stopping.
const DEFAULT_MIN_DELTA: f64 = 0.001;
/// Default number of parallel jobs (-1 for all cores).
const DEFAULT_N_JOBS: i32 = -1;
/// Default batch size for parallel processing.
const DEFAULT_BATCH_SIZE: usize = 32;

fn default_n_folds() -> usize {
    DEFAULT_N_FOLDS
}

fn default_shuffle() -> bool {
    DEFAULT_SHUFFLE
}

fn default_max_iter() -> usize {
    DEFAULT_MAX_ITER
}

fn default_patience() -> usize {
    DEFAULT_PATIENCE
}

fn default_min_delta() -> f64 {
    DEFAULT_MIN_DELTA
}

fn default_n_jobs() -> i32 {
    DEFAULT_N_JOBS
}

fn default_batch_size() -> usize {
    DEFAULT_BATCH_SIZE
}
424
425impl Default for ModelSelectionConfig {
426    fn default() -> Self {
427        Self {
428            cross_validation: CrossValidationConfig {
429                method: "kfold".to_string(),
430                n_folds: default_n_folds(),
431                random_state: Some(42),
432                shuffle: default_shuffle(),
433                parameters: HashMap::new(),
434            },
435            optimization: OptimizationConfig {
436                method: "grid_search".to_string(),
437                max_iter: default_max_iter(),
438                early_stopping: None,
439                parameter_space: HashMap::new(),
440                parameters: HashMap::new(),
441            },
442            scoring: ScoringConfig {
443                primary: "accuracy".to_string(),
444                additional: vec!["precision".to_string(), "recall".to_string()],
445                custom_scorers: HashMap::new(),
446                parameters: HashMap::new(),
447            },
448            resources: ResourceConfig {
449                n_jobs: default_n_jobs(),
450                memory_limit: None,
451                use_gpu: false,
452                batch_size: default_batch_size(),
453                memory_efficient: false,
454            },
455            custom: HashMap::new(),
456        }
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_manager_creation() {
        let manager = ConfigManager::new();
        assert!(manager.base_config.is_none());
        assert!(manager.template_registry.is_empty());
    }

    #[test]
    fn test_default_config() {
        let config = ModelSelectionConfig::default();
        assert_eq!(config.cross_validation.method, "kfold");
        assert_eq!(config.cross_validation.n_folds, 5);
        assert_eq!(config.optimization.method, "grid_search");
        assert_eq!(config.optimization.max_iter, 100);
        assert_eq!(config.scoring.primary, "accuracy");
        assert_eq!(config.resources.n_jobs, -1);
    }

    #[test]
    fn test_config_validation() {
        let manager = ConfigManager::new();
        let mut config = ModelSelectionConfig::default();

        // Valid configuration should pass
        assert!(manager.validate_config(&config).is_ok());

        // Invalid n_folds should fail
        config.cross_validation.n_folds = 1;
        assert!(manager.validate_config(&config).is_err());

        // Reset and test invalid max_iter
        config.cross_validation.n_folds = 5;
        config.optimization.max_iter = 0;
        assert!(manager.validate_config(&config).is_err());
    }

    #[test]
    fn test_template_registration() {
        let mut manager = ConfigManager::new();
        let config = ModelSelectionConfig::default();

        assert!(manager
            .register_template("test_template".to_string(), config)
            .is_ok());
        assert!(manager.get_template("test_template").is_some());
        assert!(manager.get_template("nonexistent").is_none());
    }

    #[test]
    fn test_json_serialization() {
        let config = ModelSelectionConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ModelSelectionConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(
            config.cross_validation.method,
            deserialized.cross_validation.method
        );
        assert_eq!(
            config.optimization.max_iter,
            deserialized.optimization.max_iter
        );
    }

    #[test]
    fn test_override_application() {
        let mut manager = ConfigManager::default();
        let template_config = ModelSelectionConfig::default();
        manager
            .register_template("test".to_string(), template_config)
            .unwrap();

        let mut overrides = HashMap::new();
        overrides.insert(
            "cross_validation.n_folds".to_string(),
            serde_json::Value::from(10),
        );
        overrides.insert(
            "optimization.max_iter".to_string(),
            serde_json::Value::from(200),
        );

        let config = manager.from_template("test", overrides).unwrap();
        assert_eq!(config.cross_validation.n_folds, 10);
        assert_eq!(config.optimization.max_iter, 200);
    }

    #[test]
    fn test_unrecognized_override_is_stored_in_custom() {
        let manager = ConfigManager::default();
        let mut overrides = HashMap::new();
        overrides.insert(
            "experiment.tag".to_string(),
            serde_json::Value::from("baseline"),
        );

        let config = manager.from_template("grid_search", overrides).unwrap();
        assert_eq!(
            config.custom.get("experiment.tag"),
            Some(&serde_json::Value::from("baseline"))
        );
    }

    #[test]
    fn test_from_template_missing_template() {
        let manager = ConfigManager::new();
        let result = manager.from_template("missing", HashMap::new());
        assert!(matches!(result, Err(ConfigError::Template(_))));
    }

    #[test]
    fn test_load_from_string_json() {
        let mut manager = ConfigManager::new();
        let json = serde_json::to_string(&ModelSelectionConfig::default()).unwrap();

        let config = manager.load_from_string(&json, "json").unwrap();
        assert_eq!(config.cross_validation.n_folds, 5);
        assert_eq!(config.resources.batch_size, 32);
    }

    #[test]
    fn test_load_from_string_rejects_unsupported_format() {
        let mut manager = ConfigManager::new();
        let result = manager.load_from_string("{}", "toml");
        assert!(matches!(result, Err(ConfigError::Validation(_))));
    }

    #[test]
    fn test_load_from_string_rejects_invalid_config() {
        let mut manager = ConfigManager::new();
        let mut config = ModelSelectionConfig::default();
        config.cross_validation.n_folds = 1;
        let json = serde_json::to_string(&config).unwrap();

        let result = manager.load_from_string(&json, "json");
        assert!(matches!(result, Err(ConfigError::Validation(_))));
    }

    #[test]
    fn test_file_round_trip() {
        let path = std::env::temp_dir().join(format!(
            "sklears_config_round_trip_{}.json",
            std::process::id()
        ));
        let mut manager = ConfigManager::new();
        let config = ModelSelectionConfig::default();

        manager.save_to_file(&config, &path).unwrap();
        let loaded = manager.load_from_file(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);

        let loaded = loaded.unwrap();
        assert_eq!(loaded.optimization.method, "grid_search");
        assert_eq!(loaded.cross_validation.n_folds, 5);
    }

    #[test]
    fn test_default_templates() {
        let manager = ConfigManager::default();

        assert!(manager.get_template("grid_search").is_some());
        assert!(manager.get_template("bayesian").is_some());

        let grid_template = manager.get_template("grid_search").unwrap();
        assert_eq!(grid_template.optimization.method, "grid_search");

        let bayesian_template = manager.get_template("bayesian").unwrap();
        assert_eq!(bayesian_template.optimization.method, "bayesian");
        assert!(bayesian_template.optimization.early_stopping.is_some());
    }
}