1use crate::NeuralResult;
7use sklears_core::error::SklearsError;
8use std::collections::HashMap;
9
10#[cfg(feature = "serde")]
11use serde_json;
12
/// A bound that a numeric hyperparameter value must satisfy.
///
/// `T` is the numeric type being checked; validation is implemented for
/// `RangeConstraint<f64>` (`validate_f64`) and `RangeConstraint<i64>`
/// (`validate_i64`).
#[derive(Debug, Clone, PartialEq)]
pub enum RangeConstraint<T> {
    /// Value must be strictly greater than the threshold.
    GreaterThan(T),
    /// Value must be greater than or equal to the threshold.
    GreaterEqualThan(T),
    /// Value must be strictly less than the threshold.
    LessThan(T),
    /// Value must be less than or equal to the threshold.
    LessEqualThan(T),
    /// Value must lie in the closed interval `[min, max]`.
    Range(T, T),
    /// Value must equal one of the listed values.
    OneOf(Vec<T>),
    /// Value must be strictly greater than zero.
    Positive,
    /// Value must be greater than or equal to zero.
    NonNegative,
    /// No restriction: every value is accepted.
    Any,
}
35
36impl RangeConstraint<f64> {
37 pub fn validate_f64(&self, value: f64, param_name: &str) -> NeuralResult<()> {
39 match self {
40 RangeConstraint::GreaterThan(threshold) => {
41 if value <= *threshold {
42 return Err(SklearsError::InvalidParameter {
43 name: param_name.to_string(),
44 reason: format!("Value {} must be greater than {}", value, threshold),
45 });
46 }
47 }
48 RangeConstraint::GreaterEqualThan(threshold) => {
49 if value < *threshold {
50 return Err(SklearsError::InvalidParameter {
51 name: param_name.to_string(),
52 reason: format!(
53 "Value {} must be greater than or equal to {}",
54 value, threshold
55 ),
56 });
57 }
58 }
59 RangeConstraint::LessThan(threshold) => {
60 if value >= *threshold {
61 return Err(SklearsError::InvalidParameter {
62 name: param_name.to_string(),
63 reason: format!("Value {} must be less than {}", value, threshold),
64 });
65 }
66 }
67 RangeConstraint::LessEqualThan(threshold) => {
68 if value > *threshold {
69 return Err(SklearsError::InvalidParameter {
70 name: param_name.to_string(),
71 reason: format!(
72 "Value {} must be less than or equal to {}",
73 value, threshold
74 ),
75 });
76 }
77 }
78 RangeConstraint::Range(min_val, max_val) => {
79 if value < *min_val || value > *max_val {
80 return Err(SklearsError::InvalidParameter {
81 name: param_name.to_string(),
82 reason: format!(
83 "Value {} must be between {} and {}",
84 value, min_val, max_val
85 ),
86 });
87 }
88 }
89 RangeConstraint::OneOf(valid_values) => {
90 if !valid_values.contains(&value) {
91 return Err(SklearsError::InvalidParameter {
92 name: param_name.to_string(),
93 reason: format!("Value {} must be one of: {:?}", value, valid_values),
94 });
95 }
96 }
97 RangeConstraint::Positive => {
98 if value <= 0.0 {
99 return Err(SklearsError::InvalidParameter {
100 name: param_name.to_string(),
101 reason: format!("Value {} must be positive", value),
102 });
103 }
104 }
105 RangeConstraint::NonNegative => {
106 if value < 0.0 {
107 return Err(SklearsError::InvalidParameter {
108 name: param_name.to_string(),
109 reason: format!("Value {} must be non-negative", value),
110 });
111 }
112 }
113 RangeConstraint::Any => {
114 }
116 }
117 Ok(())
118 }
119}
120
121impl RangeConstraint<i64> {
122 pub fn validate_i64(&self, value: i64, param_name: &str) -> NeuralResult<()> {
124 match self {
125 RangeConstraint::GreaterThan(threshold) => {
126 if value <= *threshold {
127 return Err(SklearsError::InvalidParameter {
128 name: param_name.to_string(),
129 reason: format!("Value {} must be greater than {}", value, threshold),
130 });
131 }
132 }
133 RangeConstraint::GreaterEqualThan(threshold) => {
134 if value < *threshold {
135 return Err(SklearsError::InvalidParameter {
136 name: param_name.to_string(),
137 reason: format!(
138 "Value {} must be greater than or equal to {}",
139 value, threshold
140 ),
141 });
142 }
143 }
144 RangeConstraint::LessThan(threshold) => {
145 if value >= *threshold {
146 return Err(SklearsError::InvalidParameter {
147 name: param_name.to_string(),
148 reason: format!("Value {} must be less than {}", value, threshold),
149 });
150 }
151 }
152 RangeConstraint::LessEqualThan(threshold) => {
153 if value > *threshold {
154 return Err(SklearsError::InvalidParameter {
155 name: param_name.to_string(),
156 reason: format!(
157 "Value {} must be less than or equal to {}",
158 value, threshold
159 ),
160 });
161 }
162 }
163 RangeConstraint::Range(min_val, max_val) => {
164 if value < *min_val || value > *max_val {
165 return Err(SklearsError::InvalidParameter {
166 name: param_name.to_string(),
167 reason: format!(
168 "Value {} must be between {} and {}",
169 value, min_val, max_val
170 ),
171 });
172 }
173 }
174 RangeConstraint::OneOf(valid_values) => {
175 if !valid_values.contains(&value) {
176 return Err(SklearsError::InvalidParameter {
177 name: param_name.to_string(),
178 reason: format!("Value {} must be one of: {:?}", value, valid_values),
179 });
180 }
181 }
182 RangeConstraint::Positive => {
183 if value <= 0 {
184 return Err(SklearsError::InvalidParameter {
185 name: param_name.to_string(),
186 reason: format!("Value {} must be positive", value),
187 });
188 }
189 }
190 RangeConstraint::NonNegative => {
191 if value < 0 {
192 return Err(SklearsError::InvalidParameter {
193 name: param_name.to_string(),
194 reason: format!("Value {} must be non-negative", value),
195 });
196 }
197 }
198 RangeConstraint::Any => {
199 }
201 }
202 Ok(())
203 }
204}
205
206#[derive(Debug, Clone)]
208pub struct ValidationRule {
209 pub name: String,
211 pub description: String,
213 pub required: bool,
215 pub numeric_constraint: Option<RangeConstraint<f64>>,
217 pub integer_constraint: Option<RangeConstraint<i64>>,
219 pub string_constraint: Option<Vec<String>>,
221 #[cfg(feature = "serde")]
223 pub custom_validator: Option<fn(&serde_json::Value) -> NeuralResult<()>>,
224 #[cfg(feature = "serde")]
226 pub default_value: Option<serde_json::Value>,
227}
228
229impl ValidationRule {
230 pub fn new(name: String, description: String) -> Self {
232 Self {
233 name,
234 description,
235 required: false,
236 numeric_constraint: None,
237 integer_constraint: None,
238 string_constraint: None,
239 #[cfg(feature = "serde")]
240 custom_validator: None,
241 #[cfg(feature = "serde")]
242 default_value: None,
243 }
244 }
245
246 pub fn required(mut self) -> Self {
248 self.required = true;
249 self
250 }
251
252 pub fn with_numeric_constraint(mut self, constraint: RangeConstraint<f64>) -> Self {
254 self.numeric_constraint = Some(constraint);
255 self
256 }
257
258 pub fn with_integer_constraint(mut self, constraint: RangeConstraint<i64>) -> Self {
260 self.integer_constraint = Some(constraint);
261 self
262 }
263
264 pub fn with_string_constraint(mut self, allowed_values: Vec<String>) -> Self {
266 self.string_constraint = Some(allowed_values);
267 self
268 }
269}
270
#[cfg(feature = "serde")]
impl ValidationRule {
    /// Attaches an extra check that runs on the raw JSON value after the
    /// built-in constraints.
    pub fn with_custom_validator(
        mut self,
        validator: fn(&serde_json::Value) -> NeuralResult<()>,
    ) -> Self {
        self.custom_validator = Some(validator);
        self
    }

    /// Sets the value used when the parameter is absent.
    ///
    /// A parameter with a default is optional by definition, so this clears
    /// any earlier `required()` call.
    pub fn with_default(mut self, default_value: serde_json::Value) -> Self {
        self.default_value = Some(default_value);
        self.required = false;
        self
    }

    /// Validates a parameter value against this rule.
    ///
    /// `value` is `None` when the parameter was not supplied at all.
    ///
    /// Checks, in order:
    /// - a missing parameter is an error only if the rule is required;
    /// - an explicit JSON `null` is an error if the rule is required, and
    ///   otherwise skips the typed (numeric / integer / string) checks, so
    ///   optional parameters may use `null` as "unset";
    /// - a non-null value must satisfy every attached typed constraint;
    /// - the custom validator, if any, runs on every supplied value
    ///   (including `null`).
    ///
    /// # Errors
    ///
    /// Returns `SklearsError::InvalidParameter` naming this rule for the first
    /// failed built-in check. It returns whatever error the custom validator
    /// produces if that fails.
    pub fn validate(&self, value: Option<&serde_json::Value>) -> NeuralResult<()> {
        let val = match value {
            Some(val) => val,
            None if self.required => {
                return Err(self.invalid("Required parameter is missing"));
            }
            None => return Ok(()),
        };

        // Handle null once for every kind of constraint. Previously a
        // required-but-null value slipped through when the rule had no
        // numeric/integer constraint, and an optional null was rejected by
        // string constraints but accepted by numeric ones.
        if val.is_null() {
            if self.required {
                return Err(self.invalid("Required parameter cannot be null"));
            }
        } else {
            self.check_typed_constraints(val)?;
        }

        if let Some(validator) = self.custom_validator {
            validator(val)?;
        }
        Ok(())
    }

    /// Builds an `InvalidParameter` error labelled with this rule's name.
    fn invalid(&self, reason: impl Into<String>) -> SklearsError {
        SklearsError::InvalidParameter {
            name: self.name.clone(),
            reason: reason.into(),
        }
    }

    /// Applies the numeric, integer and string constraints to a non-null value.
    fn check_typed_constraints(&self, val: &serde_json::Value) -> NeuralResult<()> {
        if let Some(ref constraint) = self.numeric_constraint {
            let num_val = val
                .as_f64()
                .ok_or_else(|| self.invalid("Expected numeric value"))?;
            constraint.validate_f64(num_val, &self.name)?;
        }

        if let Some(ref constraint) = self.integer_constraint {
            let int_val = val
                .as_i64()
                .ok_or_else(|| self.invalid("Expected integer value"))?;
            constraint.validate_i64(int_val, &self.name)?;
        }

        if let Some(ref allowed_values) = self.string_constraint {
            let str_val = val
                .as_str()
                .ok_or_else(|| self.invalid("Expected string value"))?;
            // Compare as &str to avoid allocating a String per call.
            if !allowed_values.iter().any(|allowed| allowed == str_val) {
                return Err(self.invalid(format!(
                    "Value '{}' must be one of: {:?}",
                    str_val, allowed_values
                )));
            }
        }

        Ok(())
    }
}
370
371pub struct HyperparameterValidator {
373 rules: HashMap<String, ValidationRule>,
375 model_type: String,
377}
378
379impl HyperparameterValidator {
380 pub fn new(model_type: String) -> Self {
382 Self {
383 rules: HashMap::new(),
384 model_type,
385 }
386 }
387
388 pub fn add_rule(mut self, rule: ValidationRule) -> Self {
390 self.rules.insert(rule.name.clone(), rule);
391 self
392 }
393
394 pub fn add_rules(mut self, rules: Vec<ValidationRule>) -> Self {
396 for rule in rules {
397 self.rules.insert(rule.name.clone(), rule);
398 }
399 self
400 }
401}
402
#[cfg(feature = "serde")]
impl HyperparameterValidator {
    /// Validates `params` against every registered rule.
    ///
    /// Rules are checked in ascending name order. When several parameters are
    /// invalid, the reported error is therefore deterministic; raw `HashMap`
    /// iteration order is not. Parameters with no matching rule are not an
    /// error. They are logged with `log::warn!`, also in name order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a rule's `validate`.
    pub fn validate(&self, params: &HashMap<String, serde_json::Value>) -> NeuralResult<()> {
        for rule in self.sorted_rules() {
            rule.validate(params.get(&rule.name))?;
        }

        let mut unknown: Vec<&String> = params
            .keys()
            .filter(|name| !self.rules.contains_key(*name))
            .collect();
        unknown.sort_unstable();
        for param_name in unknown {
            log::warn!(
                "Unknown parameter '{}' for model type '{}'",
                param_name,
                self.model_type
            );
        }

        Ok(())
    }

    /// Returns the caller-supplied value for `param_name`, falling back to the
    /// rule's default.
    ///
    /// `Ok(None)` means the parameter was not supplied and either has no rule
    /// or its rule has no default.
    pub fn get_parameter_with_default(
        &self,
        params: &HashMap<String, serde_json::Value>,
        param_name: &str,
    ) -> NeuralResult<Option<serde_json::Value>> {
        let value = params
            .get(param_name)
            .or_else(|| {
                self.rules
                    .get(param_name)
                    .and_then(|rule| rule.default_value.as_ref())
            })
            .cloned();
        Ok(value)
    }

    /// Inserts each rule's default into `params` wherever the parameter is
    /// absent; values already present are left untouched.
    pub fn apply_defaults(
        &self,
        params: &mut HashMap<String, serde_json::Value>,
    ) -> NeuralResult<()> {
        for rule in self.rules.values() {
            if let Some(ref default_value) = rule.default_value {
                if !params.contains_key(&rule.name) {
                    params.insert(rule.name.clone(), default_value.clone());
                }
            }
        }
        Ok(())
    }

    /// Describes every registered parameter, split into required and optional
    /// lists, each sorted by parameter name.
    pub fn get_validation_summary(&self) -> ValidationSummary {
        let (required_params, optional_params): (Vec<ParameterInfo>, Vec<ParameterInfo>) = self
            .sorted_rules()
            .into_iter()
            .map(|rule| ParameterInfo {
                name: rule.name.clone(),
                description: rule.description.clone(),
                required: rule.required,
                default_value: rule.default_value.clone(),
                constraints: self.get_constraint_description(rule),
            })
            .partition(|info| info.required);

        ValidationSummary {
            model_type: self.model_type.clone(),
            required_params,
            optional_params,
        }
    }

    /// Registered rules ordered by name (map keys are unique), giving a stable
    /// iteration order independent of `HashMap` hashing.
    fn sorted_rules(&self) -> Vec<&ValidationRule> {
        let mut entries: Vec<(&String, &ValidationRule)> = self.rules.iter().collect();
        entries.sort_unstable_by_key(|(key, _)| *key);
        entries.into_iter().map(|(_, rule)| rule).collect()
    }

    /// Human-readable descriptions of the constraints attached to `rule`.
    fn get_constraint_description(&self, rule: &ValidationRule) -> Vec<String> {
        let mut constraints = Vec::new();

        if let Some(ref numeric_constraint) = rule.numeric_constraint {
            constraints.push(format!("Numeric: {:?}", numeric_constraint));
        }

        if let Some(ref integer_constraint) = rule.integer_constraint {
            constraints.push(format!("Integer: {:?}", integer_constraint));
        }

        if let Some(ref string_constraint) = rule.string_constraint {
            constraints.push(format!("String options: {:?}", string_constraint));
        }

        if rule.custom_validator.is_some() {
            constraints.push("Custom validation".to_string());
        }

        constraints
    }
}
507
/// Summary-oriented view of one parameter's rule.
#[derive(Debug, Clone)]
pub struct ParameterInfo {
    /// Parameter name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Whether the parameter must be supplied.
    pub required: bool,
    /// Default value, if the rule defines one.
    #[cfg(feature = "serde")]
    pub default_value: Option<serde_json::Value>,
    /// Human-readable descriptions of the attached constraints.
    pub constraints: Vec<String>,
}
518
519#[derive(Debug, Clone)]
521pub struct ValidationSummary {
522 pub model_type: String,
523 pub required_params: Vec<ParameterInfo>,
524 pub optional_params: Vec<ParameterInfo>,
525}
526
/// Namespace for ready-made `HyperparameterValidator`s of built-in model types.
pub struct ConfigurationTemplates;
529
#[cfg(feature = "serde")]
impl ConfigurationTemplates {
    /// Validator for `MLPClassifier` hyperparameters, with defaults.
    pub fn mlp_classifier() -> HyperparameterValidator {
        HyperparameterValidator::new("MLPClassifier".to_string()).add_rules(vec![
            ValidationRule::new(
                "hidden_layer_sizes".to_string(),
                "Number of neurons in each hidden layer".to_string(),
            )
            .with_default(serde_json::json!([100])),
            ValidationRule::new(
                "activation".to_string(),
                "Activation function for hidden layers".to_string(),
            )
            .with_string_constraint(vec![
                "relu".to_string(),
                "tanh".to_string(),
                "sigmoid".to_string(),
                "elu".to_string(),
                "gelu".to_string(),
                "swish".to_string(),
                "leaky_relu".to_string(),
                "mish".to_string(),
            ])
            .with_default(serde_json::json!("relu")),
            ValidationRule::new(
                "learning_rate".to_string(),
                "Initial learning rate".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(1e-6, 1.0))
            .with_default(serde_json::json!(0.001)),
            ValidationRule::new(
                "max_iter".to_string(),
                "Maximum number of training iterations".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(200)),
            ValidationRule::new(
                "batch_size".to_string(),
                "Size of minibatches for training".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(32)),
            ValidationRule::new("solver".to_string(), "Optimization algorithm".to_string())
                .with_string_constraint(vec![
                    "sgd".to_string(),
                    "adam".to_string(),
                    "adamw".to_string(),
                    "rmsprop".to_string(),
                    "nadam".to_string(),
                    "lamb".to_string(),
                    "lars".to_string(),
                ])
                .with_default(serde_json::json!("adam")),
            ValidationRule::new(
                "alpha".to_string(),
                "L2 regularization parameter".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::NonNegative)
            .with_default(serde_json::json!(0.0001)),
            ValidationRule::new(
                "random_state".to_string(),
                "Random seed for reproducibility".to_string(),
            )
            .with_integer_constraint(RangeConstraint::NonNegative)
            .with_default(serde_json::json!(null)),
            ValidationRule::new(
                "tol".to_string(),
                "Tolerance for optimization convergence".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(1e-4)),
            ValidationRule::new(
                "momentum".to_string(),
                "Momentum for SGD optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.9)),
            // Adam decay rates must lie in [0, 1): the closed range handles the
            // lower bound, and the custom validator excludes 1.0.
            ValidationRule::new(
                "beta_1".to_string(),
                "Beta1 parameter for Adam optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_custom_validator(|value| Self::require_below_one("beta_1", value))
            .with_default(serde_json::json!(0.9)),
            ValidationRule::new(
                "beta_2".to_string(),
                "Beta2 parameter for Adam optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_custom_validator(|value| Self::require_below_one("beta_2", value))
            .with_default(serde_json::json!(0.999)),
            ValidationRule::new(
                "epsilon".to_string(),
                "Epsilon parameter for Adam optimizer".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(1e-8)),
            ValidationRule::new(
                "early_stopping".to_string(),
                "Whether to use early stopping".to_string(),
            )
            .with_default(serde_json::json!(false)),
            ValidationRule::new(
                "validation_fraction".to_string(),
                "Fraction of training data to use for validation".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.1)),
            ValidationRule::new(
                "n_iter_no_change".to_string(),
                "Maximum number of epochs without improvement for early stopping".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(10)),
        ])
    }

    /// Validator for `MLPRegressor`: same rules as `mlp_classifier`, with the
    /// model type renamed.
    pub fn mlp_regressor() -> HyperparameterValidator {
        let mut validator = Self::mlp_classifier();
        validator.model_type = "MLPRegressor".to_string();
        validator
    }

    /// Validator for `CNNClassifier` hyperparameters; `conv_layers` is required.
    pub fn cnn_classifier() -> HyperparameterValidator {
        HyperparameterValidator::new("CNNClassifier".to_string()).add_rules(vec![
            ValidationRule::new(
                "conv_layers".to_string(),
                "Configuration for convolutional layers".to_string(),
            )
            .required(),
            ValidationRule::new(
                "pool_size".to_string(),
                "Pooling layer kernel size".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(2)),
            ValidationRule::new(
                "kernel_size".to_string(),
                "Convolutional kernel size".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(3)),
            ValidationRule::new("stride".to_string(), "Convolutional stride".to_string())
                .with_integer_constraint(RangeConstraint::Positive)
                .with_default(serde_json::json!(1)),
            ValidationRule::new(
                "padding".to_string(),
                "Padding type for convolution".to_string(),
            )
            .with_string_constraint(vec!["valid".to_string(), "same".to_string()])
            .with_default(serde_json::json!("valid")),
            ValidationRule::new(
                "dropout_rate".to_string(),
                "Dropout rate for regularization".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.0)),
        ])
    }

    /// Validator for `LSTMClassifier` hyperparameters; `sequence_length` is
    /// required.
    pub fn lstm_classifier() -> HyperparameterValidator {
        HyperparameterValidator::new("LSTMClassifier".to_string()).add_rules(vec![
            ValidationRule::new(
                "hidden_size".to_string(),
                "Number of features in hidden state".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(128)),
            ValidationRule::new(
                "num_layers".to_string(),
                "Number of recurrent layers".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .with_default(serde_json::json!(1)),
            ValidationRule::new(
                "bidirectional".to_string(),
                "Whether to use bidirectional LSTM".to_string(),
            )
            .with_default(serde_json::json!(false)),
            ValidationRule::new(
                "sequence_length".to_string(),
                "Input sequence length".to_string(),
            )
            .with_integer_constraint(RangeConstraint::Positive)
            .required(),
            ValidationRule::new(
                "dropout_rate".to_string(),
                "Dropout rate between LSTM layers".to_string(),
            )
            .with_numeric_constraint(RangeConstraint::Range(0.0, 1.0))
            .with_default(serde_json::json!(0.0)),
        ])
    }

    /// Custom-validator body for Adam's `beta_*` rules: rejects numeric values
    /// `>= 1.0`.
    ///
    /// scikit-learn documents `beta_1`/`beta_2` as lying in `[0, 1)`. Adam's
    /// bias correction divides by `1 - beta^t`, which is zero at `beta = 1`.
    /// `RangeConstraint::Range` is closed, so it cannot express the open upper
    /// end on its own.
    fn require_below_one(param_name: &str, value: &serde_json::Value) -> NeuralResult<()> {
        match value.as_f64() {
            Some(v) if v >= 1.0 => Err(SklearsError::InvalidParameter {
                name: param_name.to_string(),
                reason: format!("Value {} must be less than 1", v),
            }),
            // Non-numeric and null values are left to the numeric constraint
            // and null handling in `ValidationRule::validate`.
            _ => Ok(()),
        }
    }
}
726
/// Namespace for heuristics that propose hyperparameter search spaces.
pub struct ParameterTuner;
729
#[cfg(feature = "serde")]
impl ParameterTuner {
    /// Proposes a search range for every rule in `validator` that has a known
    /// heuristic, keyed by parameter name.
    ///
    /// Parameters without a heuristic are omitted. `base_params` is forwarded
    /// per parameter, but the current heuristics do not read it.
    pub fn suggest_ranges(
        validator: &HyperparameterValidator,
        base_params: &HashMap<String, serde_json::Value>,
    ) -> HashMap<String, ParameterRange> {
        validator
            .rules
            .values()
            .filter_map(|rule| {
                Self::suggest_range_for_rule(rule, base_params.get(&rule.name))
                    .map(|range| (rule.name.clone(), range))
            })
            .collect()
    }

    /// Fixed search space for well-known parameter names, or `None`.
    ///
    /// The current value is accepted for future value-centred heuristics but
    /// is not used yet (hence the leading underscore).
    fn suggest_range_for_rule(
        rule: &ValidationRule,
        _current_value: Option<&serde_json::Value>,
    ) -> Option<ParameterRange> {
        let range = match rule.name.as_str() {
            "learning_rate" | "alpha" => ParameterRange::LogUniform(1e-6, 1e-1),
            "batch_size" => ParameterRange::Choice(
                [16, 32, 64, 128, 256]
                    .iter()
                    .map(|&size| serde_json::json!(size))
                    .collect(),
            ),
            "hidden_layer_sizes" => ParameterRange::Choice(vec![
                serde_json::json!([50]),
                serde_json::json!([100]),
                serde_json::json!([100, 50]),
                serde_json::json!([200, 100]),
                serde_json::json!([300, 200, 100]),
            ]),
            "momentum" => ParameterRange::Uniform(0.0, 1.0),
            "beta_1" => ParameterRange::Uniform(0.8, 0.999),
            "beta_2" => ParameterRange::Uniform(0.9, 0.9999),
            "dropout_rate" => ParameterRange::Uniform(0.0, 0.5),
            _ => return None,
        };
        Some(range)
    }
}
777
/// Search space proposed for a single hyperparameter.
#[derive(Debug, Clone)]
pub enum ParameterRange {
    /// Real value drawn uniformly between the two bounds.
    Uniform(f64, f64),
    /// Real value drawn on a logarithmic scale between the two bounds.
    LogUniform(f64, f64),
    /// One of an explicit list of JSON values.
    #[cfg(feature = "serde")]
    Choice(Vec<serde_json::Value>),
    /// Integer range given by its two bounds.
    IntRange(i64, i64),
}
791
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_range_constraint_validation() {
        let constraint: RangeConstraint<f64> = RangeConstraint::Range(0.0, 1.0);
        assert!(constraint.validate_f64(0.5, "test_param").is_ok());
        assert!(constraint.validate_f64(-0.1, "test_param").is_err());
        assert!(constraint.validate_f64(1.1, "test_param").is_err());

        let positive_constraint: RangeConstraint<f64> = RangeConstraint::Positive;
        assert!(positive_constraint.validate_f64(1.0, "test_param").is_ok());
        assert!(positive_constraint.validate_f64(0.0, "test_param").is_err());
        assert!(positive_constraint.validate_f64(-1.0, "test_param").is_err());
    }

    #[test]
    fn test_validation_rule() {
        let rule = ValidationRule::new(
            "learning_rate".to_string(),
            "Learning rate parameter".to_string(),
        )
        .with_numeric_constraint(RangeConstraint::Range(1e-6, 1.0))
        .with_default(json!(0.001));

        // Inside the allowed range.
        assert!(rule.validate(Some(&json!(0.01))).is_ok());
        // Above the upper bound.
        assert!(rule.validate(Some(&json!(2.0))).is_err());
        // Missing, but optional because it has a default.
        assert!(rule.validate(None).is_ok());
        // Wrong JSON type for a numeric constraint.
        assert!(rule.validate(Some(&json!("invalid"))).is_err());
    }

    #[test]
    fn test_hyperparameter_validator() {
        let validator = HyperparameterValidator::new("TestModel".to_string())
            .add_rule(
                ValidationRule::new("learning_rate".to_string(), "Learning rate".to_string())
                    .with_numeric_constraint(RangeConstraint::Positive)
                    .required(),
            )
            .add_rule(
                ValidationRule::new("batch_size".to_string(), "Batch size".to_string())
                    .with_integer_constraint(RangeConstraint::Positive)
                    .with_default(json!(32)),
            );

        let mut valid_params = HashMap::new();
        valid_params.insert("learning_rate".to_string(), json!(0.01));
        assert!(validator.validate(&valid_params).is_ok());

        // The required learning_rate is missing.
        let invalid_params = HashMap::new();
        assert!(validator.validate(&invalid_params).is_err());

        let mut params_with_defaults = HashMap::new();
        params_with_defaults.insert("learning_rate".to_string(), json!(0.01));
        validator.apply_defaults(&mut params_with_defaults).unwrap();
        assert!(params_with_defaults.contains_key("batch_size"));
    }

    #[test]
    fn test_mlp_classifier_template() {
        let validator = ConfigurationTemplates::mlp_classifier();

        let mut params = HashMap::new();
        validator.apply_defaults(&mut params).unwrap();

        assert!(params.contains_key("hidden_layer_sizes"));
        assert!(params.contains_key("activation"));
        assert!(params.contains_key("learning_rate"));

        // All defaults must satisfy their own rules.
        assert!(validator.validate(&params).is_ok());

        params.insert("activation".to_string(), json!("invalid_activation"));
        assert!(validator.validate(&params).is_err());
    }

    #[test]
    fn test_parameter_tuner() {
        let validator = ConfigurationTemplates::mlp_classifier();
        let params = HashMap::new();

        let suggestions = ParameterTuner::suggest_ranges(&validator, &params);

        assert!(suggestions.contains_key("learning_rate"));
        assert!(suggestions.contains_key("batch_size"));
        assert!(suggestions.contains_key("hidden_layer_sizes"));

        if let Some(ParameterRange::LogUniform(min, max)) = suggestions.get("learning_rate") {
            assert!(min < max);
            assert!(*min > 0.0);
        } else {
            panic!("Expected LogUniform range for learning_rate");
        }
    }

    #[test]
    fn test_validation_summary() {
        let rule = ValidationRule::new("test_param".to_string(), "Test parameter".to_string())
            .required()
            .with_numeric_constraint(RangeConstraint::Positive);

        let validator = HyperparameterValidator::new("TestModel".to_string()).add_rule(rule);

        let summary = validator.get_validation_summary();
        assert_eq!(summary.model_type, "TestModel");
        assert_eq!(summary.required_params.len(), 1);
        assert_eq!(summary.optional_params.len(), 0);
        assert_eq!(summary.required_params[0].name, "test_param");
    }
}