1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::Path;
10use thiserror::Error;
11
/// Errors produced while loading, saving, validating or templating a
/// [`ModelSelectionConfig`].
#[derive(Error, Debug)]
pub enum ConfigError {
    /// Reading or writing a configuration file failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// `serde_json` failed to serialize or deserialize a configuration.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    /// Parsing failed on the YAML code path; carries the underlying parser
    /// message as text.
    #[error("YAML error: {0}")]
    Yaml(String),
    /// A configuration value (or the requested file/string format) was rejected.
    #[error("Configuration validation error: {0}")]
    Validation(String),
    /// A template operation failed, e.g. the requested template is not registered.
    #[error("Template error: {0}")]
    Template(String),
}
26
/// Top-level model-selection configuration: cross-validation, optimization,
/// scoring and resource sections plus free-form extras.
///
/// When deserializing, the four section fields are required; `custom` may be
/// omitted and then defaults to an empty map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSelectionConfig {
    /// Cross-validation settings.
    pub cross_validation: CrossValidationConfig,
    /// Search strategy, iteration budget and parameter space.
    pub optimization: OptimizationConfig,
    /// Metric names used for scoring.
    pub scoring: ScoringConfig,
    /// Parallelism, memory and batching settings.
    pub resources: ResourceConfig,
    /// Extra settings not modelled by the typed sections. Dotted override keys
    /// that `ConfigManager` does not recognise are stored here verbatim.
    #[serde(default)]
    pub custom: HashMap<String, serde_json::Value>,
}
42
/// Cross-validation settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossValidationConfig {
    /// Splitting strategy name, e.g. `"kfold"` in the built-in defaults. This is a
    /// free-form string; it is not checked against known strategies here.
    pub method: String,
    /// Number of folds; `default_n_folds()` (5) when omitted.
    /// `ConfigManager` validation rejects values below 2.
    #[serde(default = "default_n_folds")]
    pub n_folds: usize,
    /// Optional random seed; absent or `null` deserializes to `None`.
    pub random_state: Option<u64>,
    /// Whether to shuffle; `default_shuffle()` (true) when omitted.
    #[serde(default = "default_shuffle")]
    pub shuffle: bool,
    /// Strategy-specific extra parameters; schema is not defined in this module.
    #[serde(default)]
    pub parameters: HashMap<String, serde_json::Value>,
}
60
/// Hyperparameter-optimization settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationConfig {
    /// Search strategy name, e.g. `"grid_search"` or `"bayesian"` as used by
    /// the built-in templates. Free-form; not checked here.
    pub method: String,
    /// Iteration budget; `default_max_iter()` (100) when omitted.
    /// `ConfigManager` validation rejects 0.
    #[serde(default = "default_max_iter")]
    pub max_iter: usize,
    /// Optional early-stopping settings; absent or `null` means `None`.
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// Search space definitions. Required when deserializing (no serde default).
    /// Presumably keyed by hyperparameter name; confirm with the consumer.
    pub parameter_space: HashMap<String, ParameterDefinition>,
    /// Strategy-specific extra parameters; schema is not defined in this module.
    #[serde(default)]
    pub parameters: HashMap<String, serde_json::Value>,
}
77
/// Early-stopping settings for the optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Defaults to `false` when omitted. Supplying this section without
    /// `"enabled": true` therefore leaves early stopping disabled.
    #[serde(default)]
    pub enabled: bool,
    /// `default_patience()` (5) when omitted. Presumably the number of
    /// non-improving iterations tolerated before stopping; confirm with the consumer.
    #[serde(default = "default_patience")]
    pub patience: usize,
    /// `default_min_delta()` (0.001) when omitted. Presumably the minimum score
    /// change that counts as an improvement; confirm with the consumer.
    #[serde(default = "default_min_delta")]
    pub min_delta: f64,
}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
94#[serde(tag = "type")]
95pub enum ParameterDefinition {
96 #[serde(rename = "uniform")]
97 Uniform { low: f64, high: f64 },
98 #[serde(rename = "log_uniform")]
99 LogUniform { low: f64, high: f64 },
100 #[serde(rename = "categorical")]
101 Categorical { choices: Vec<serde_json::Value> },
102 #[serde(rename = "integer")]
103 Integer { low: i64, high: i64 },
104 #[serde(rename = "choice")]
105 Choice { choices: Vec<serde_json::Value> },
106}
107
/// Scoring (metric) settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringConfig {
    /// Primary metric name, e.g. `"accuracy"` in the built-in defaults.
    /// Required when deserializing.
    pub primary: String,
    /// Additional metric names; empty when omitted.
    #[serde(default)]
    pub additional: Vec<String>,
    /// Custom scorer definitions as opaque JSON; schema is not defined in this module.
    #[serde(default)]
    pub custom_scorers: HashMap<String, serde_json::Value>,
    /// Extra scoring parameters; schema is not defined in this module.
    #[serde(default)]
    pub parameters: HashMap<String, serde_json::Value>,
}
123
/// Compute-resource settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConfig {
    /// Worker count; `default_n_jobs()` (-1) when omitted. `ConfigManager`
    /// validation rejects 0. NOTE(review): negative values presumably follow the
    /// scikit-learn convention (-1 = all cores); confirm with the executor.
    #[serde(default = "default_n_jobs")]
    pub n_jobs: i32,
    /// Optional memory cap. NOTE(review): the unit (bytes? MiB?) is not
    /// specified in this module; confirm with the consumer.
    pub memory_limit: Option<usize>,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub use_gpu: bool,
    /// `default_batch_size()` (32) when omitted; `ConfigManager` validation rejects 0.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    /// Defaults to `false` when omitted.
    #[serde(default)]
    pub memory_efficient: bool,
}
142
/// Loads, saves and validates [`ModelSelectionConfig`]s and keeps a registry of
/// named template configurations.
pub struct ConfigManager {
    /// NOTE(review): always initialised to `None` by `new` and not read outside
    /// the unit tests in this file; confirm whether it is still needed.
    base_config: Option<ModelSelectionConfig>,
    /// Named templates, looked up by `get_template` / `from_template`.
    template_registry: HashMap<String, ModelSelectionConfig>,
}
148
149impl ConfigManager {
150 pub fn new() -> Self {
152 Self {
153 base_config: None,
154 template_registry: HashMap::new(),
155 }
156 }
157
158 pub fn load_from_file<P: AsRef<Path>>(
160 &mut self,
161 path: P,
162 ) -> Result<ModelSelectionConfig, ConfigError> {
163 let path = path.as_ref();
164 let content = std::fs::read_to_string(path)?;
165
166 let config = match path.extension().and_then(|s| s.to_str()) {
167 Some("json") => serde_json::from_str(&content)?,
168 Some("yaml") | Some("yml") => {
169 serde_json::from_str(&content).map_err(|e| ConfigError::Yaml(e.to_string()))?
171 }
172 _ => {
173 return Err(ConfigError::Validation(
174 "Unsupported file format".to_string(),
175 ))
176 }
177 };
178
179 self.validate_config(&config)?;
180 Ok(config)
181 }
182
183 pub fn save_to_file<P: AsRef<Path>>(
185 &self,
186 config: &ModelSelectionConfig,
187 path: P,
188 ) -> Result<(), ConfigError> {
189 let path = path.as_ref();
190 let content = match path.extension().and_then(|s| s.to_str()) {
191 Some("json") => serde_json::to_string_pretty(config)?,
192 Some("yaml") | Some("yml") => {
193 serde_json::to_string_pretty(config)?
195 }
196 _ => {
197 return Err(ConfigError::Validation(
198 "Unsupported file format".to_string(),
199 ))
200 }
201 };
202
203 std::fs::write(path, content)?;
204 Ok(())
205 }
206
207 pub fn load_from_string(
209 &mut self,
210 content: &str,
211 format: &str,
212 ) -> Result<ModelSelectionConfig, ConfigError> {
213 let config = match format {
214 "json" => serde_json::from_str(content)?,
215 "yaml" | "yml" => {
216 serde_json::from_str(content).map_err(|e| ConfigError::Yaml(e.to_string()))?
217 }
218 _ => return Err(ConfigError::Validation("Unsupported format".to_string())),
219 };
220
221 self.validate_config(&config)?;
222 Ok(config)
223 }
224
225 pub fn register_template(
227 &mut self,
228 name: String,
229 config: ModelSelectionConfig,
230 ) -> Result<(), ConfigError> {
231 self.validate_config(&config)?;
232 self.template_registry.insert(name, config);
233 Ok(())
234 }
235
236 pub fn get_template(&self, name: &str) -> Option<&ModelSelectionConfig> {
238 self.template_registry.get(name)
239 }
240
241 pub fn from_template(
243 &self,
244 template_name: &str,
245 overrides: HashMap<String, serde_json::Value>,
246 ) -> Result<ModelSelectionConfig, ConfigError> {
247 let template = self.get_template(template_name).ok_or_else(|| {
248 ConfigError::Template(format!("Template '{}' not found", template_name))
249 })?;
250
251 let mut config = template.clone();
252 self.apply_overrides(&mut config, overrides)?;
253 self.validate_config(&config)?;
254
255 Ok(config)
256 }
257
258 fn validate_config(&self, config: &ModelSelectionConfig) -> Result<(), ConfigError> {
260 if config.cross_validation.n_folds < 2 {
262 return Err(ConfigError::Validation(
263 "n_folds must be at least 2".to_string(),
264 ));
265 }
266
267 if config.optimization.max_iter == 0 {
269 return Err(ConfigError::Validation(
270 "max_iter must be greater than 0".to_string(),
271 ));
272 }
273
274 if config.resources.n_jobs == 0 {
276 return Err(ConfigError::Validation("n_jobs cannot be 0".to_string()));
277 }
278
279 if config.resources.batch_size == 0 {
280 return Err(ConfigError::Validation(
281 "batch_size must be greater than 0".to_string(),
282 ));
283 }
284
285 Ok(())
286 }
287
288 fn apply_overrides(
290 &self,
291 config: &mut ModelSelectionConfig,
292 overrides: HashMap<String, serde_json::Value>,
293 ) -> Result<(), ConfigError> {
294 for (key, value) in overrides {
295 self.apply_override(config, &key, value)?;
296 }
297 Ok(())
298 }
299
300 fn apply_override(
302 &self,
303 config: &mut ModelSelectionConfig,
304 key: &str,
305 value: serde_json::Value,
306 ) -> Result<(), ConfigError> {
307 let parts: Vec<&str> = key.split('.').collect();
308 match parts.as_slice() {
309 ["cross_validation", "n_folds"] => {
310 if let Some(n) = value.as_u64() {
311 config.cross_validation.n_folds = n as usize;
312 }
313 }
314 ["cross_validation", "random_state"] => {
315 if let Some(n) = value.as_u64() {
316 config.cross_validation.random_state = Some(n);
317 }
318 }
319 ["optimization", "max_iter"] => {
320 if let Some(n) = value.as_u64() {
321 config.optimization.max_iter = n as usize;
322 }
323 }
324 ["resources", "n_jobs"] => {
325 if let Some(n) = value.as_i64() {
326 config.resources.n_jobs = n as i32;
327 }
328 }
329 _ => {
330 config.custom.insert(key.to_string(), value);
332 }
333 }
334 Ok(())
335 }
336
337 pub fn load_default_templates(&mut self) {
339 let grid_search_config = ModelSelectionConfig {
341 cross_validation: CrossValidationConfig {
342 method: "kfold".to_string(),
343 n_folds: 5,
344 random_state: Some(42),
345 shuffle: true,
346 parameters: HashMap::new(),
347 },
348 optimization: OptimizationConfig {
349 method: "grid_search".to_string(),
350 max_iter: 100,
351 early_stopping: None,
352 parameter_space: HashMap::new(),
353 parameters: HashMap::new(),
354 },
355 scoring: ScoringConfig {
356 primary: "accuracy".to_string(),
357 additional: vec!["precision".to_string(), "recall".to_string()],
358 custom_scorers: HashMap::new(),
359 parameters: HashMap::new(),
360 },
361 resources: ResourceConfig {
362 n_jobs: -1,
363 memory_limit: None,
364 use_gpu: false,
365 batch_size: 32,
366 memory_efficient: false,
367 },
368 custom: HashMap::new(),
369 };
370
371 let bayesian_config = ModelSelectionConfig {
373 optimization: OptimizationConfig {
374 method: "bayesian".to_string(),
375 max_iter: 50,
376 early_stopping: Some(EarlyStoppingConfig {
377 enabled: true,
378 patience: 5,
379 min_delta: 0.001,
380 }),
381 parameter_space: HashMap::new(),
382 parameters: HashMap::new(),
383 },
384 ..grid_search_config.clone()
385 };
386
387 self.template_registry
388 .insert("grid_search".to_string(), grid_search_config);
389 self.template_registry
390 .insert("bayesian".to_string(), bayesian_config);
391 }
392}
393
394impl Default for ConfigManager {
395 fn default() -> Self {
396 let mut manager = Self::new();
397 manager.load_default_templates();
398 manager
399 }
400}
401
// Serde `default = "..."` helpers. Each function just returns the matching
// named constant, so the value is visible in one place.

/// Default number of cross-validation folds.
const DEFAULT_N_FOLDS: usize = 5;
/// Shuffle before splitting unless told otherwise.
const DEFAULT_SHUFFLE: bool = true;
/// Default optimization iteration budget.
const DEFAULT_MAX_ITER: usize = 100;
/// Default early-stopping patience.
const DEFAULT_PATIENCE: usize = 5;
/// Default early-stopping minimum delta.
const DEFAULT_MIN_DELTA: f64 = 0.001;
/// Default worker count.
const DEFAULT_N_JOBS: i32 = -1;
/// Default batch size.
const DEFAULT_BATCH_SIZE: usize = 32;

fn default_n_folds() -> usize {
    DEFAULT_N_FOLDS
}

fn default_shuffle() -> bool {
    DEFAULT_SHUFFLE
}

fn default_max_iter() -> usize {
    DEFAULT_MAX_ITER
}

fn default_patience() -> usize {
    DEFAULT_PATIENCE
}

fn default_min_delta() -> f64 {
    DEFAULT_MIN_DELTA
}

fn default_n_jobs() -> i32 {
    DEFAULT_N_JOBS
}

fn default_batch_size() -> usize {
    DEFAULT_BATCH_SIZE
}
424
425impl Default for ModelSelectionConfig {
426 fn default() -> Self {
427 Self {
428 cross_validation: CrossValidationConfig {
429 method: "kfold".to_string(),
430 n_folds: default_n_folds(),
431 random_state: Some(42),
432 shuffle: default_shuffle(),
433 parameters: HashMap::new(),
434 },
435 optimization: OptimizationConfig {
436 method: "grid_search".to_string(),
437 max_iter: default_max_iter(),
438 early_stopping: None,
439 parameter_space: HashMap::new(),
440 parameters: HashMap::new(),
441 },
442 scoring: ScoringConfig {
443 primary: "accuracy".to_string(),
444 additional: vec!["precision".to_string(), "recall".to_string()],
445 custom_scorers: HashMap::new(),
446 parameters: HashMap::new(),
447 },
448 resources: ResourceConfig {
449 n_jobs: default_n_jobs(),
450 memory_limit: None,
451 use_gpu: false,
452 batch_size: default_batch_size(),
453 memory_efficient: false,
454 },
455 custom: HashMap::new(),
456 }
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_manager_creation() {
        let manager = ConfigManager::new();
        assert!(manager.base_config.is_none());
        assert!(manager.template_registry.is_empty());
    }

    #[test]
    fn test_default_config() {
        let config = ModelSelectionConfig::default();
        assert_eq!(config.cross_validation.method, "kfold");
        assert_eq!(config.cross_validation.n_folds, 5);
        assert_eq!(config.optimization.method, "grid_search");
        assert_eq!(config.optimization.max_iter, 100);
        assert_eq!(config.scoring.primary, "accuracy");
        assert_eq!(config.resources.n_jobs, -1);
    }

    #[test]
    fn test_config_validation() {
        let manager = ConfigManager::new();
        let mut config = ModelSelectionConfig::default();

        assert!(manager.validate_config(&config).is_ok());

        config.cross_validation.n_folds = 1;
        assert!(manager.validate_config(&config).is_err());

        config.cross_validation.n_folds = 5;
        config.optimization.max_iter = 0;
        assert!(manager.validate_config(&config).is_err());
    }

    #[test]
    fn test_validation_rejects_zero_jobs_and_batch_size() {
        let manager = ConfigManager::new();

        let mut config = ModelSelectionConfig::default();
        config.resources.n_jobs = 0;
        assert!(matches!(
            manager.validate_config(&config),
            Err(ConfigError::Validation(_))
        ));

        let mut config = ModelSelectionConfig::default();
        config.resources.batch_size = 0;
        assert!(matches!(
            manager.validate_config(&config),
            Err(ConfigError::Validation(_))
        ));
    }

    #[test]
    fn test_template_registration() {
        let mut manager = ConfigManager::new();
        let config = ModelSelectionConfig::default();

        assert!(manager
            .register_template("test_template".to_string(), config)
            .is_ok());
        assert!(manager.get_template("test_template").is_some());
        assert!(manager.get_template("nonexistent").is_none());
    }

    #[test]
    fn test_json_serialization() {
        let config = ModelSelectionConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ModelSelectionConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(
            config.cross_validation.method,
            deserialized.cross_validation.method
        );
        assert_eq!(
            config.optimization.max_iter,
            deserialized.optimization.max_iter
        );
    }

    #[test]
    fn test_load_from_string_round_trip() {
        let mut manager = ConfigManager::new();
        let json = serde_json::to_string(&ModelSelectionConfig::default()).unwrap();

        let config = manager.load_from_string(&json, "json").unwrap();
        assert_eq!(config.cross_validation.n_folds, 5);
        assert_eq!(config.resources.batch_size, 32);

        // JSON text is also accepted on the YAML path.
        let config = manager.load_from_string(&json, "yaml").unwrap();
        assert_eq!(config.optimization.method, "grid_search");
    }

    #[test]
    fn test_load_from_string_rejects_unknown_format() {
        let mut manager = ConfigManager::new();
        let result = manager.load_from_string("{}", "toml");
        assert!(matches!(result, Err(ConfigError::Validation(_))));
    }

    #[test]
    fn test_load_from_string_rejects_invalid_config() {
        let mut manager = ConfigManager::new();
        let mut config = ModelSelectionConfig::default();
        config.resources.batch_size = 0;
        let json = serde_json::to_string(&config).unwrap();

        assert!(matches!(
            manager.load_from_string(&json, "json"),
            Err(ConfigError::Validation(_))
        ));
    }

    #[test]
    fn test_override_application() {
        let mut manager = ConfigManager::default();
        let template_config = ModelSelectionConfig::default();
        manager
            .register_template("test".to_string(), template_config)
            .unwrap();

        let mut overrides = HashMap::new();
        overrides.insert(
            "cross_validation.n_folds".to_string(),
            serde_json::Value::from(10),
        );
        overrides.insert(
            "optimization.max_iter".to_string(),
            serde_json::Value::from(200),
        );

        let config = manager.from_template("test", overrides).unwrap();
        assert_eq!(config.cross_validation.n_folds, 10);
        assert_eq!(config.optimization.max_iter, 200);
    }

    #[test]
    fn test_unknown_override_goes_to_custom() {
        let manager = ConfigManager::default();
        let mut overrides = HashMap::new();
        overrides.insert("model.depth".to_string(), serde_json::Value::from(3));

        let config = manager.from_template("grid_search", overrides).unwrap();
        assert_eq!(
            config.custom.get("model.depth"),
            Some(&serde_json::Value::from(3))
        );
    }

    #[test]
    fn test_from_template_missing_template() {
        let manager = ConfigManager::new();
        let result = manager.from_template("missing", HashMap::new());
        assert!(matches!(result, Err(ConfigError::Template(_))));
    }

    #[test]
    fn test_default_templates() {
        let manager = ConfigManager::default();

        assert!(manager.get_template("grid_search").is_some());
        assert!(manager.get_template("bayesian").is_some());

        let grid_template = manager.get_template("grid_search").unwrap();
        assert_eq!(grid_template.optimization.method, "grid_search");

        let bayesian_template = manager.get_template("bayesian").unwrap();
        assert_eq!(bayesian_template.optimization.method, "bayesian");
        assert!(bayesian_template.optimization.early_stopping.is_some());
    }

    #[test]
    fn test_parameter_definition_tagging() {
        let definition = ParameterDefinition::LogUniform {
            low: 0.001,
            high: 1.0,
        };
        let value = serde_json::to_value(&definition).unwrap();
        assert_eq!(value["type"], "log_uniform");

        let parsed: ParameterDefinition =
            serde_json::from_str(r#"{"type": "integer", "low": 1, "high": 10}"#).unwrap();
        assert!(matches!(
            parsed,
            ParameterDefinition::Integer { low: 1, high: 10 }
        ));
    }
}
566}