// sklears_neural/validation.rs

//! Hyperparameter validation and configuration management for neural networks.
//!
//! This module validates neural network hyperparameters against declarative rules,
//! provides configuration templates for common architectures, and suggests
//! search ranges for hyperparameter tuning.

6use crate::NeuralResult;
7use sklears_core::error::SklearsError;
8use std::collections::HashMap;
9
10#[cfg(feature = "serde")]
11use serde_json;
12
/// A bound that a numeric hyperparameter has to satisfy.
///
/// The same enum is used for floating-point (`RangeConstraint<f64>`) and
/// integer (`RangeConstraint<i64>`) parameters.
#[derive(Debug, Clone, PartialEq)]
pub enum RangeConstraint<T> {
    /// Strict lower bound: the value must be `> T`.
    GreaterThan(T),
    /// Inclusive lower bound: the value must be `>= T`.
    GreaterEqualThan(T),
    /// Strict upper bound: the value must be `< T`.
    LessThan(T),
    /// Inclusive upper bound: the value must be `<= T`.
    LessEqualThan(T),
    /// Closed interval `[min, max]`; both endpoints are accepted.
    Range(T, T),
    /// The value must equal one of the listed entries.
    OneOf(Vec<T>),
    /// The value must be strictly greater than zero.
    Positive,
    /// The value must be zero or greater.
    NonNegative,
    /// Every value is accepted.
    Any,
}
35
36impl RangeConstraint<f64> {
37    /// Validate that a value satisfies the constraint (f64 version)
38    pub fn validate_f64(&self, value: f64, param_name: &str) -> NeuralResult<()> {
39        match self {
40            RangeConstraint::GreaterThan(threshold) => {
41                if value <= *threshold {
42                    return Err(SklearsError::InvalidParameter {
43                        name: param_name.to_string(),
44                        reason: format!("Value {} must be greater than {}", value, threshold),
45                    });
46                }
47            }
48            RangeConstraint::GreaterEqualThan(threshold) => {
49                if value < *threshold {
50                    return Err(SklearsError::InvalidParameter {
51                        name: param_name.to_string(),
52                        reason: format!(
53                            "Value {} must be greater than or equal to {}",
54                            value, threshold
55                        ),
56                    });
57                }
58            }
59            RangeConstraint::LessThan(threshold) => {
60                if value >= *threshold {
61                    return Err(SklearsError::InvalidParameter {
62                        name: param_name.to_string(),
63                        reason: format!("Value {} must be less than {}", value, threshold),
64                    });
65                }
66            }
67            RangeConstraint::LessEqualThan(threshold) => {
68                if value > *threshold {
69                    return Err(SklearsError::InvalidParameter {
70                        name: param_name.to_string(),
71                        reason: format!(
72                            "Value {} must be less than or equal to {}",
73                            value, threshold
74                        ),
75                    });
76                }
77            }
78            RangeConstraint::Range(min_val, max_val) => {
79                if value < *min_val || value > *max_val {
80                    return Err(SklearsError::InvalidParameter {
81                        name: param_name.to_string(),
82                        reason: format!(
83                            "Value {} must be between {} and {}",
84                            value, min_val, max_val
85                        ),
86                    });
87                }
88            }
89            RangeConstraint::OneOf(valid_values) => {
90                if !valid_values.contains(&value) {
91                    return Err(SklearsError::InvalidParameter {
92                        name: param_name.to_string(),
93                        reason: format!("Value {} must be one of: {:?}", value, valid_values),
94                    });
95                }
96            }
97            RangeConstraint::Positive => {
98                if value <= 0.0 {
99                    return Err(SklearsError::InvalidParameter {
100                        name: param_name.to_string(),
101                        reason: format!("Value {} must be positive", value),
102                    });
103                }
104            }
105            RangeConstraint::NonNegative => {
106                if value < 0.0 {
107                    return Err(SklearsError::InvalidParameter {
108                        name: param_name.to_string(),
109                        reason: format!("Value {} must be non-negative", value),
110                    });
111                }
112            }
113            RangeConstraint::Any => {
114                // No constraint
115            }
116        }
117        Ok(())
118    }
119}
120
121impl RangeConstraint<i64> {
122    /// Validate that a value satisfies the constraint (i64 version)
123    pub fn validate_i64(&self, value: i64, param_name: &str) -> NeuralResult<()> {
124        match self {
125            RangeConstraint::GreaterThan(threshold) => {
126                if value <= *threshold {
127                    return Err(SklearsError::InvalidParameter {
128                        name: param_name.to_string(),
129                        reason: format!("Value {} must be greater than {}", value, threshold),
130                    });
131                }
132            }
133            RangeConstraint::GreaterEqualThan(threshold) => {
134                if value < *threshold {
135                    return Err(SklearsError::InvalidParameter {
136                        name: param_name.to_string(),
137                        reason: format!(
138                            "Value {} must be greater than or equal to {}",
139                            value, threshold
140                        ),
141                    });
142                }
143            }
144            RangeConstraint::LessThan(threshold) => {
145                if value >= *threshold {
146                    return Err(SklearsError::InvalidParameter {
147                        name: param_name.to_string(),
148                        reason: format!("Value {} must be less than {}", value, threshold),
149                    });
150                }
151            }
152            RangeConstraint::LessEqualThan(threshold) => {
153                if value > *threshold {
154                    return Err(SklearsError::InvalidParameter {
155                        name: param_name.to_string(),
156                        reason: format!(
157                            "Value {} must be less than or equal to {}",
158                            value, threshold
159                        ),
160                    });
161                }
162            }
163            RangeConstraint::Range(min_val, max_val) => {
164                if value < *min_val || value > *max_val {
165                    return Err(SklearsError::InvalidParameter {
166                        name: param_name.to_string(),
167                        reason: format!(
168                            "Value {} must be between {} and {}",
169                            value, min_val, max_val
170                        ),
171                    });
172                }
173            }
174            RangeConstraint::OneOf(valid_values) => {
175                if !valid_values.contains(&value) {
176                    return Err(SklearsError::InvalidParameter {
177                        name: param_name.to_string(),
178                        reason: format!("Value {} must be one of: {:?}", value, valid_values),
179                    });
180                }
181            }
182            RangeConstraint::Positive => {
183                if value <= 0 {
184                    return Err(SklearsError::InvalidParameter {
185                        name: param_name.to_string(),
186                        reason: format!("Value {} must be positive", value),
187                    });
188                }
189            }
190            RangeConstraint::NonNegative => {
191                if value < 0 {
192                    return Err(SklearsError::InvalidParameter {
193                        name: param_name.to_string(),
194                        reason: format!("Value {} must be non-negative", value),
195                    });
196                }
197            }
198            RangeConstraint::Any => {
199                // No constraint
200            }
201        }
202        Ok(())
203    }
204}
205
206/// Parameter validation rule
207#[derive(Debug, Clone)]
208pub struct ValidationRule {
209    /// Parameter name
210    pub name: String,
211    /// Description of the parameter
212    pub description: String,
213    /// Whether the parameter is required
214    pub required: bool,
215    /// Numeric constraints (for numeric parameters)
216    pub numeric_constraint: Option<RangeConstraint<f64>>,
217    /// Integer constraints (for integer parameters)
218    pub integer_constraint: Option<RangeConstraint<i64>>,
219    /// String constraints (for string parameters)
220    pub string_constraint: Option<Vec<String>>,
221    /// Custom validation function
222    #[cfg(feature = "serde")]
223    pub custom_validator: Option<fn(&serde_json::Value) -> NeuralResult<()>>,
224    /// Default value (if not required)
225    #[cfg(feature = "serde")]
226    pub default_value: Option<serde_json::Value>,
227}
228
229impl ValidationRule {
230    /// Create a new validation rule (basic version)
231    pub fn new(name: String, description: String) -> Self {
232        Self {
233            name,
234            description,
235            required: false,
236            numeric_constraint: None,
237            integer_constraint: None,
238            string_constraint: None,
239            #[cfg(feature = "serde")]
240            custom_validator: None,
241            #[cfg(feature = "serde")]
242            default_value: None,
243        }
244    }
245
246    /// Mark parameter as required
247    pub fn required(mut self) -> Self {
248        self.required = true;
249        self
250    }
251
252    /// Add numeric constraint
253    pub fn with_numeric_constraint(mut self, constraint: RangeConstraint<f64>) -> Self {
254        self.numeric_constraint = Some(constraint);
255        self
256    }
257
258    /// Add integer constraint
259    pub fn with_integer_constraint(mut self, constraint: RangeConstraint<i64>) -> Self {
260        self.integer_constraint = Some(constraint);
261        self
262    }
263
264    /// Add string constraint (allowed values)
265    pub fn with_string_constraint(mut self, allowed_values: Vec<String>) -> Self {
266        self.string_constraint = Some(allowed_values);
267        self
268    }
269}
270
#[cfg(feature = "serde")]
impl ValidationRule {
    /// Attach a custom validator that runs after the built-in constraints.
    pub fn with_custom_validator(
        mut self,
        validator: fn(&serde_json::Value) -> NeuralResult<()>,
    ) -> Self {
        self.custom_validator = Some(validator);
        self
    }

    /// Set the value used when the parameter is absent.
    ///
    /// A parameter with a default can never be missing, so this also clears
    /// the `required` flag.
    pub fn with_default(mut self, default_value: serde_json::Value) -> Self {
        self.default_value = Some(default_value);
        self.required = false; // Can't be required if has default
        self
    }

    /// Validate a parameter value.
    ///
    /// * `None` (parameter absent): an error if the rule is required, otherwise OK.
    /// * JSON `null`: an error if the rule is required. For an optional rule the
    ///   typed constraints (numeric, integer, string) are skipped, because
    ///   `null` stands for "unset" (e.g. a `null` default for a seed). This
    ///   now applies uniformly regardless of which typed constraint is set.
    /// * Any other value: checked against every typed constraint that is set.
    ///
    /// The custom validator, if any, runs on every present value (including
    /// `null`) once the built-in checks have passed.
    ///
    /// # Errors
    ///
    /// Returns `SklearsError::InvalidParameter` naming this rule when any check
    /// fails.
    pub fn validate(&self, value: Option<&serde_json::Value>) -> NeuralResult<()> {
        let val = match value {
            Some(val) => val,
            None if self.required => {
                return Err(self.invalid_parameter("Required parameter is missing"))
            }
            None => return Ok(()),
        };

        if val.is_null() {
            if self.required {
                return Err(self.invalid_parameter("Required parameter cannot be null"));
            }
        } else {
            self.check_typed_constraints(val)?;
        }

        if let Some(validator) = self.custom_validator {
            validator(val)?;
        }
        Ok(())
    }

    /// Apply the numeric, integer and string constraints to a non-null value.
    fn check_typed_constraints(&self, val: &serde_json::Value) -> NeuralResult<()> {
        if let Some(ref constraint) = self.numeric_constraint {
            let num_val = val
                .as_f64()
                .ok_or_else(|| self.invalid_parameter("Expected numeric value"))?;
            constraint.validate_f64(num_val, &self.name)?;
        }

        if let Some(ref constraint) = self.integer_constraint {
            let int_val = val
                .as_i64()
                .ok_or_else(|| self.invalid_parameter("Expected integer value"))?;
            constraint.validate_i64(int_val, &self.name)?;
        }

        if let Some(ref allowed_values) = self.string_constraint {
            let str_val = val
                .as_str()
                .ok_or_else(|| self.invalid_parameter("Expected string value"))?;
            // Compare as &str to avoid allocating a String per check.
            if !allowed_values.iter().any(|allowed| allowed.as_str() == str_val) {
                return Err(SklearsError::InvalidParameter {
                    name: self.name.clone(),
                    reason: format!(
                        "Value '{}' must be one of: {:?}",
                        str_val, allowed_values
                    ),
                });
            }
        }
        Ok(())
    }

    /// Build an `InvalidParameter` error for this rule with a fixed reason.
    fn invalid_parameter(&self, reason: &str) -> SklearsError {
        SklearsError::InvalidParameter {
            name: self.name.clone(),
            reason: reason.to_string(),
        }
    }
}
370
371/// Hyperparameter validator for neural networks
372pub struct HyperparameterValidator {
373    /// Validation rules for each parameter
374    rules: HashMap<String, ValidationRule>,
375    /// Model type this validator is for
376    model_type: String,
377}
378
379impl HyperparameterValidator {
380    /// Create a new hyperparameter validator
381    pub fn new(model_type: String) -> Self {
382        Self {
383            rules: HashMap::new(),
384            model_type,
385        }
386    }
387
388    /// Add a validation rule
389    pub fn add_rule(mut self, rule: ValidationRule) -> Self {
390        self.rules.insert(rule.name.clone(), rule);
391        self
392    }
393
394    /// Add multiple validation rules
395    pub fn add_rules(mut self, rules: Vec<ValidationRule>) -> Self {
396        for rule in rules {
397            self.rules.insert(rule.name.clone(), rule);
398        }
399        self
400    }
401}
402
#[cfg(feature = "serde")]
impl HyperparameterValidator {
    /// Validate hyperparameters against every registered rule.
    ///
    /// Rules are checked in ascending name order. When several parameters are
    /// invalid, the same (first) error is therefore reported on every run;
    /// iterating the `HashMap` directly would make it depend on hash order.
    /// Parameters without a matching rule are not an error. Each one is
    /// reported via `log::warn!`, also in sorted order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `ValidationRule::validate`.
    pub fn validate(&self, params: &HashMap<String, serde_json::Value>) -> NeuralResult<()> {
        for rule in self.sorted_rules() {
            rule.validate(params.get(&rule.name))?;
        }

        let mut unknown: Vec<&String> = params
            .keys()
            .filter(|param_name| !self.rules.contains_key(*param_name))
            .collect();
        unknown.sort_unstable();
        for param_name in unknown {
            log::warn!(
                "Unknown parameter '{}' for model type '{}'",
                param_name,
                self.model_type
            );
        }

        Ok(())
    }

    /// Return the supplied value for `param_name`, or the rule's default when
    /// the parameter is absent.
    ///
    /// Returns `Ok(None)` when the parameter is absent and either no rule
    /// exists for it or the rule has no default.
    pub fn get_parameter_with_default(
        &self,
        params: &HashMap<String, serde_json::Value>,
        param_name: &str,
    ) -> NeuralResult<Option<serde_json::Value>> {
        let value = params.get(param_name).cloned().or_else(|| {
            self.rules
                .get(param_name)
                .and_then(|rule| rule.default_value.clone())
        });
        Ok(value)
    }

    /// Fill in missing parameters with their rules' default values.
    ///
    /// Parameters already present in `params` are left untouched.
    pub fn apply_defaults(
        &self,
        params: &mut HashMap<String, serde_json::Value>,
    ) -> NeuralResult<()> {
        for rule in self.rules.values() {
            if let Some(default_value) = &rule.default_value {
                params
                    .entry(rule.name.clone())
                    .or_insert_with(|| default_value.clone());
            }
        }
        Ok(())
    }

    /// Build a documentation summary of all rules, split into required and
    /// optional parameters.
    ///
    /// Both lists are sorted by parameter name so the output is stable.
    pub fn get_validation_summary(&self) -> ValidationSummary {
        let mut required_params = Vec::new();
        let mut optional_params = Vec::new();

        for rule in self.sorted_rules() {
            let param_info = ParameterInfo {
                name: rule.name.clone(),
                description: rule.description.clone(),
                required: rule.required,
                default_value: rule.default_value.clone(),
                constraints: self.get_constraint_description(rule),
            };

            if rule.required {
                required_params.push(param_info);
            } else {
                optional_params.push(param_info);
            }
        }

        ValidationSummary {
            model_type: self.model_type.clone(),
            required_params,
            optional_params,
        }
    }

    /// All rules ordered by name (names are unique because they are the map keys).
    fn sorted_rules(&self) -> Vec<&ValidationRule> {
        let mut rules: Vec<&ValidationRule> = self.rules.values().collect();
        rules.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        rules
    }

    /// Human-readable list of the constraints attached to `rule`.
    fn get_constraint_description(&self, rule: &ValidationRule) -> Vec<String> {
        let mut constraints = Vec::new();

        if let Some(ref numeric_constraint) = rule.numeric_constraint {
            constraints.push(format!("Numeric: {:?}", numeric_constraint));
        }

        if let Some(ref integer_constraint) = rule.integer_constraint {
            constraints.push(format!("Integer: {:?}", integer_constraint));
        }

        if let Some(ref string_constraint) = rule.string_constraint {
            constraints.push(format!("String options: {:?}", string_constraint));
        }

        if rule.custom_validator.is_some() {
            constraints.push("Custom validation".to_string());
        }

        constraints
    }
}
507
/// Documentation record describing one hyperparameter of a model.
#[derive(Debug, Clone)]
pub struct ParameterInfo {
    /// Parameter name as used in the parameter map.
    pub name: String,
    /// Human-readable description taken from the validation rule.
    pub description: String,
    /// Whether the parameter must be supplied.
    pub required: bool,
    /// Value used when the parameter is omitted, if the rule defines one.
    #[cfg(feature = "serde")]
    pub default_value: Option<serde_json::Value>,
    /// Textual descriptions of the constraints attached to the parameter.
    pub constraints: Vec<String>,
}
518
519/// Validation summary for documentation
520#[derive(Debug, Clone)]
521pub struct ValidationSummary {
522    pub model_type: String,
523    pub required_params: Vec<ParameterInfo>,
524    pub optional_params: Vec<ParameterInfo>,
525}
526
/// Configuration templates for common neural network architectures.
///
/// Each constructor returns a `HyperparameterValidator` pre-populated with the
/// rules and defaults for one architecture.
pub struct ConfigurationTemplates;

#[cfg(feature = "serde")]
impl ConfigurationTemplates {
    /// Get MLP classifier template validator.
    ///
    /// `beta_1` and `beta_2` must lie in `[0, 1)`. Adam's bias correction
    /// divides by `1 - beta^t`, which is zero when `beta == 1`. The closed
    /// `RangeConstraint::Range(0.0, 1.0)` alone would accept `1.0`, so a custom
    /// validator rejects it.
    pub fn mlp_classifier() -> HyperparameterValidator {
        HyperparameterValidator::new("MLPClassifier".to_string()).add_rules(vec![
            ValidationRule::new(
                "hidden_layer_sizes".to_string(),
                "Number of neurons in each hidden layer".to_string(),
            )
            .with_default(serde_json::json!([100])),
            ValidationRule::new(
                "activation".to_string(),
                "Activation function for hidden layers".to_string(),
            )
            .with_string_constraint(vec![
                "relu".to_string(),
                "tanh".to_string(),
                "sigmoid".to_string(),
                "elu".to_string(),
                "gelu".to_string(),
                "swish".to_string(),
                "leaky_relu".to_string(),
                "mish".to_string(),
            ])
            .with_default(serde_json::json!("relu")),
            ValidationRule::new(
                "learning_rate".to_string(),
                "Initial learning rate".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(1e-6, 1.0))
            .with_default(serde_json::json!(0.001)),
            ValidationRule::new(
                "max_iter".to_string(),
                "Maximum number of training iterations".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(200)),
            ValidationRule::new(
                "batch_size".to_string(),
                "Size of minibatches for training".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(32)),
            ValidationRule::new("solver".to_string(), "Optimization algorithm".to_string())
                .with_string_constraint(vec![
                    "sgd".to_string(),
                    "adam".to_string(),
                    "adamw".to_string(),
                    "rmsprop".to_string(),
                    "nadam".to_string(),
                    "lamb".to_string(),
                    "lars".to_string(),
                ])
                .with_default(serde_json::json!("adam")),
            ValidationRule::new(
                "alpha".to_string(),
                "L2 regularization parameter".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::NonNegative)
            .with_default(serde_json::json!(0.0001)),
            ValidationRule::new(
                "random_state".to_string(),
                "Random seed for reproducibility".to_string(),
            )
            .with_integer_constraint(RangeConstraint::NonNegative)
            .with_default(serde_json::json!(null)),
            ValidationRule::new(
                "tol".to_string(),
                "Tolerance for optimization convergence".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(1e-4)),
            ValidationRule::new(
                "momentum".to_string(),
                "Momentum for SGD optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.9)),
            ValidationRule::new(
                "beta_1".to_string(),
                "Beta1 parameter for Adam optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_custom_validator(Self::validate_beta_1)
            .with_default(serde_json::json!(0.9)),
            ValidationRule::new(
                "beta_2".to_string(),
                "Beta2 parameter for Adam optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_custom_validator(Self::validate_beta_2)
            .with_default(serde_json::json!(0.999)),
            ValidationRule::new(
                "epsilon".to_string(),
                "Epsilon parameter for Adam optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(1e-8)),
            ValidationRule::new(
                "early_stopping".to_string(),
                "Whether to use early stopping".to_string(),
            )
            .with_default(serde_json::json!(false)),
            ValidationRule::new(
                "validation_fraction".to_string(),
                "Fraction of training data to use for validation".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.1)),
            ValidationRule::new(
                "n_iter_no_change".to_string(),
                "Maximum number of epochs without improvement for early stopping".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(10)),
        ])
    }

    /// Get MLP regressor template validator.
    ///
    /// Uses the same rules as [`Self::mlp_classifier`], relabelled with the
    /// `"MLPRegressor"` model type.
    pub fn mlp_regressor() -> HyperparameterValidator {
        let mut validator = Self::mlp_classifier();
        validator.model_type = "MLPRegressor".to_string();
        validator
    }

    /// Get CNN classifier template validator.
    pub fn cnn_classifier() -> HyperparameterValidator {
        HyperparameterValidator::new("CNNClassifier".to_string()).add_rules(vec![
            ValidationRule::new(
                "conv_layers".to_string(),
                "Configuration for convolutional layers".to_string(),
            )
            .required(),
            ValidationRule::new(
                "pool_size".to_string(),
                "Pooling layer kernel size".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(2)),
            ValidationRule::new(
                "kernel_size".to_string(),
                "Convolutional kernel size".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(3)),
            ValidationRule::new("stride".to_string(), "Convolutional stride".to_string())
                .with_integer_constraint(RangeConstraint::Positive)
                .with_default(serde_json::json!(1)),
            ValidationRule::new(
                "padding".to_string(),
                "Padding type for convolution".to_string(),
            )
            .with_string_constraint(vec!["valid".to_string(), "same".to_string()])
            .with_default(serde_json::json!("valid")),
            ValidationRule::new(
                "dropout_rate".to_string(),
                "Dropout rate for regularization".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.0)),
        ])
    }

    /// Get LSTM classifier template validator.
    pub fn lstm_classifier() -> HyperparameterValidator {
        HyperparameterValidator::new("LSTMClassifier".to_string()).add_rules(vec![
            ValidationRule::new(
                "hidden_size".to_string(),
                "Number of features in hidden state".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(128)),
            ValidationRule::new(
                "num_layers".to_string(),
                "Number of recurrent layers".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(1)),
            ValidationRule::new(
                "bidirectional".to_string(),
                "Whether to use bidirectional LSTM".to_string(),
            )
            .with_default(serde_json::json!(false)),
            ValidationRule::new(
                "sequence_length".to_string(),
                "Input sequence length".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .required(),
            ValidationRule::new(
                "dropout_rate".to_string(),
                "Dropout rate between LSTM layers".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.0)),
        ])
    }

    /// Custom validator for `beta_1`: rejects values `>= 1.0`.
    fn validate_beta_1(value: &serde_json::Value) -> NeuralResult<()> {
        Self::reject_one_or_more(value, "beta_1")
    }

    /// Custom validator for `beta_2`: rejects values `>= 1.0`.
    fn validate_beta_2(value: &serde_json::Value) -> NeuralResult<()> {
        Self::reject_one_or_more(value, "beta_2")
    }

    /// Fail when `value` is a number `>= 1.0`, labelling the error with `name`.
    ///
    /// Non-numeric values pass here, including JSON `null` for an unset
    /// optional parameter. Type checking is left to the rule's numeric
    /// constraint, which `ValidationRule::validate` applies before running
    /// custom validators.
    fn reject_one_or_more(value: &serde_json::Value, name: &str) -> NeuralResult<()> {
        match value.as_f64() {
            Some(v) if v >= 1.0 => Err(SklearsError::InvalidParameter {
                name: name.to_string(),
                reason: format!("Value {} must be less than 1", v),
            }),
            _ => Ok(()),
        }
    }
}
726
/// Heuristic search spaces for hyperparameter optimization.
///
/// Suggestions are chosen by well-known parameter name. Rules with any other
/// name receive no suggestion.
pub struct ParameterTuner;

#[cfg(feature = "serde")]
impl ParameterTuner {
    /// Suggest a search range for every rule of `validator` that has a
    /// built-in heuristic, keyed by parameter name.
    ///
    /// `base_params` is accepted for API stability and forwarded per rule.
    /// The current heuristics are fixed and do not depend on the supplied
    /// values.
    pub fn suggest_ranges(
        validator: &HyperparameterValidator,
        base_params: &HashMap<String, serde_json::Value>,
    ) -> HashMap<String, ParameterRange> {
        validator
            .rules
            .values()
            .filter_map(|rule| {
                Self::suggest_range_for_rule(rule, base_params.get(&rule.name))
                    .map(|range| (rule.name.clone(), range))
            })
            .collect()
    }

    /// Fixed heuristic range for a parameter, looked up by rule name.
    ///
    /// `_current_value` is currently unused (the heuristics do not adapt to
    /// the caller's value); the underscore prefix silences the
    /// `unused_variables` lint while keeping the hook in place.
    fn suggest_range_for_rule(
        rule: &ValidationRule,
        _current_value: Option<&serde_json::Value>,
    ) -> Option<ParameterRange> {
        match rule.name.as_str() {
            "learning_rate" => Some(ParameterRange::LogUniform(1e-6, 1e-1)),
            "batch_size" => Some(ParameterRange::Choice(vec![
                serde_json::json!(16),
                serde_json::json!(32),
                serde_json::json!(64),
                serde_json::json!(128),
                serde_json::json!(256),
            ])),
            "hidden_layer_sizes" => Some(ParameterRange::Choice(vec![
                serde_json::json!([50]),
                serde_json::json!([100]),
                serde_json::json!([100, 50]),
                serde_json::json!([200, 100]),
                serde_json::json!([300, 200, 100]),
            ])),
            "alpha" => Some(ParameterRange::LogUniform(1e-6, 1e-1)),
            "momentum" => Some(ParameterRange::Uniform(0.0, 1.0)),
            "beta_1" => Some(ParameterRange::Uniform(0.8, 0.999)),
            "beta_2" => Some(ParameterRange::Uniform(0.9, 0.9999)),
            "dropout_rate" => Some(ParameterRange::Uniform(0.0, 0.5)),
            _ => None,
        }
    }
}
777
/// A search space for one hyperparameter, as suggested by the tuner.
#[derive(Clone, Debug)]
pub enum ParameterRange {
    /// Sample uniformly between the bounds `(low, high)`.
    Uniform(f64, f64),
    /// Sample uniformly in log-space between the bounds `(low, high)`.
    LogUniform(f64, f64),
    /// Pick one of an explicit list of JSON values.
    #[cfg(feature = "serde")]
    Choice(Vec<serde_json::Value>),
    /// Integer range given as `(low, high)`.
    /// NOTE(review): endpoint inclusivity is not fixed here; confirm with consumers.
    IntRange(i64, i64),
}
791
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;
    use serde_json::json;

    /// `Range` rejects values on either side of its bounds; `Positive` rejects
    /// zero as well as negative values.
    #[test]
    fn test_range_constraint_validation() {
        let constraint = RangeConstraint::Range(0.0, 1.0);
        assert!(constraint.validate_f64(0.5, "test_param").is_ok());
        assert!(constraint.validate_f64(-0.1, "test_param").is_err());
        assert!(constraint.validate_f64(1.1, "test_param").is_err());

        let positive_constraint = RangeConstraint::Positive;
        assert!(positive_constraint.validate_f64(1.0, "test_param").is_ok());
        assert!(positive_constraint.validate_f64(0.0, "test_param").is_err());
        assert!(positive_constraint.validate_f64(-1.0, "test_param").is_err());
    }

    /// A single rule accepts in-range numbers, rejects out-of-range and
    /// non-numeric values, and tolerates a missing value when it has a default.
    #[test]
    fn test_validation_rule() {
        let rule = ValidationRule::new(
            "learning_rate".to_string(),
            "Learning rate parameter".to_string(),
        )
        .with_numeric_constraint(RangeConstraint::Range(1e-6, 1.0))
        .with_default(json!(0.001));

        // Valid value
        assert!(rule.validate(Some(&json!(0.01))).is_ok());

        // Invalid value (too high)
        assert!(rule.validate(Some(&json!(2.0))).is_err());

        // Missing value with default
        assert!(rule.validate(None).is_ok());

        // Non-numeric value
        assert!(rule.validate(Some(&json!("invalid"))).is_err());
    }

    /// A validator enforces required parameters and fills in defaults for
    /// optional ones.
    #[test]
    fn test_hyperparameter_validator() {
        let validator = HyperparameterValidator::new("TestModel".to_string())
            .add_rule(
                ValidationRule::new("learning_rate".to_string(), "Learning rate".to_string())
                    .with_numeric_constraint(RangeConstraint::Positive)
                    .required(),
            )
            .add_rule(
                ValidationRule::new("batch_size".to_string(), "Batch size".to_string())
                    .with_integer_constraint(RangeConstraint::Positive)
                    .with_default(json!(32)),
            );

        let mut valid_params = HashMap::new();
        valid_params.insert("learning_rate".to_string(), json!(0.01));
        assert!(validator.validate(&valid_params).is_ok());

        let invalid_params = HashMap::new(); // Missing required parameter
        assert!(validator.validate(&invalid_params).is_err());

        let mut params_with_defaults = HashMap::new();
        params_with_defaults.insert("learning_rate".to_string(), json!(0.01));
        validator.apply_defaults(&mut params_with_defaults).unwrap();
        assert!(params_with_defaults.contains_key("batch_size"));
    }

    /// The MLP classifier template supplies defaults that pass its own
    /// validation and rejects an unknown activation.
    #[test]
    fn test_mlp_classifier_template() {
        let validator = ConfigurationTemplates::mlp_classifier();

        let mut params = HashMap::new();
        validator.apply_defaults(&mut params).unwrap();

        // Should have all default values
        assert!(params.contains_key("hidden_layer_sizes"));
        assert!(params.contains_key("activation"));
        assert!(params.contains_key("learning_rate"));

        // Should validate successfully
        assert!(validator.validate(&params).is_ok());

        // Test invalid activation
        params.insert("activation".to_string(), json!("invalid_activation"));
        assert!(validator.validate(&params).is_err());
    }

    /// The tuner suggests ranges for the template's well-known parameters, with
    /// a positive log-uniform range for the learning rate.
    #[test]
    fn test_parameter_tuner() {
        let validator = ConfigurationTemplates::mlp_classifier();
        let params = HashMap::new();

        let suggestions = ParameterTuner::suggest_ranges(&validator, &params);

        assert!(suggestions.contains_key("learning_rate"));
        assert!(suggestions.contains_key("batch_size"));
        assert!(suggestions.contains_key("hidden_layer_sizes"));

        if let Some(ParameterRange::LogUniform(min, max)) = suggestions.get("learning_rate") {
            assert!(min < max);
            assert!(*min > 0.0);
        } else {
            panic!("Expected LogUniform range for learning_rate");
        }
    }

    /// Parameters without a known tuning range are left out of the suggestions
    /// rather than given a placeholder range.
    #[test]
    fn test_parameter_tuner_skips_unknown_params() {
        let validator = HyperparameterValidator::new("TestModel".to_string()).add_rule(
            ValidationRule::new("unknown_param".to_string(), "Unknown parameter".to_string()),
        );
        let params = HashMap::new();

        let suggestions = ParameterTuner::suggest_ranges(&validator, &params);

        assert!(!suggestions.contains_key("unknown_param"));
    }

    /// The summary reports the model type and classifies required and
    /// optional parameters.
    #[test]
    fn test_validation_summary() {
        let rule = ValidationRule::new("test_param".to_string(), "Test parameter".to_string())
            .required()
            .with_numeric_constraint(RangeConstraint::Positive);

        let validator = HyperparameterValidator::new("TestModel".to_string()).add_rule(rule);

        let summary = validator.get_validation_summary();
        assert_eq!(summary.model_type, "TestModel");
        assert_eq!(summary.required_params.len(), 1);
        assert_eq!(summary.optional_params.len(), 0);
        assert_eq!(summary.required_params[0].name, "test_param");
    }
}