1use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds, Numeric};
7use scirs2_core::numeric::{Float, NumCast};
8use std::fmt::Debug;
9
10pub type ValidationGuardFn = Box<dyn Fn(&dyn std::any::Any) -> Result<bool> + Send + Sync>;
12
13pub type ValidationDestructureFn =
15 Box<dyn Fn(&dyn std::any::Any) -> Result<ValidationResult> + Send + Sync>;
16
17pub trait Validate {
19 fn validate(&self) -> Result<()>;
21
22 fn validate_with_context(&self, context: &str) -> Result<()> {
24 self.validate()
25 .map_err(|e| SklearsError::Other(format!("{context}: {e}")))
26 }
27}
28
29#[derive(Debug, Clone)]
31pub enum ValidationRule {
32 Positive,
34 NonNegative,
36 Finite,
38 Range { min: f64, max: f64 },
40 OneOf(Vec<String>),
42 MinLength(usize),
44 MaxLength(usize),
46 UniqueElements,
48 Custom(fn(&dyn std::any::Any) -> Result<()>),
50 PatternGuard(PatternGuardRule),
52}
53
54pub struct PatternGuardRule {
56 pub pattern_name: String,
58 pub guard_fn: ValidationGuardFn,
60 pub error_message: String,
62 pub destructure_fn: Option<ValidationDestructureFn>,
64}
65
66impl std::fmt::Debug for PatternGuardRule {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("PatternGuardRule")
69 .field("pattern_name", &self.pattern_name)
70 .field("guard_fn", &"<function>")
71 .field("error_message", &self.error_message)
72 .field(
73 "destructure_fn",
74 &self.destructure_fn.as_ref().map(|_| "<function>"),
75 )
76 .finish()
77 }
78}
79
80impl Clone for PatternGuardRule {
81 fn clone(&self) -> Self {
82 Self {
85 pattern_name: self.pattern_name.clone(),
86 guard_fn: Box::new(|_| Ok(true)), error_message: self.error_message.clone(),
88 destructure_fn: None,
89 }
90 }
91}
92
/// Outcome of a pattern-based validation.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether the value matched the pattern.
    pub matched: bool,
    /// Free-form key/value details collected while validating.
    pub context: std::collections::HashMap<String, String>,
    /// Non-fatal observations produced during validation.
    pub warnings: Vec<String>,
}
103
/// Builds a `PatternGuardRule` from one of several shorthand forms.
///
/// * `pattern_guard!(numeric_range, min, max)` — accepts `f64`, `f32`, `i32`
///   and `usize` values inside the inclusive range; any other type yields
///   `Ok(false)`.
/// * `pattern_guard!(array_shape, shape)` — placeholder guard that accepts
///   every value; only the error message mentions `shape`.
/// * `pattern_guard!(string_enum, options)` — accepts `String` / `&'static str`
///   values contained in `options`.
/// * `pattern_guard!("name", guard, "message")` — custom guard with a literal
///   error message.
/// * `pattern_guard!("name", guard, destructure)` — custom guard plus a
///   destructuring function.
#[macro_export]
macro_rules! pattern_guard {
    (numeric_range, $min:expr, $max:expr) => {{
        // Evaluate each bound exactly once; the message is formatted from the
        // caller's original values so its text does not depend on the cast.
        let (lower, upper) = ($min, $max);
        let error_message = format!("Value must be in range [{}, {}]", lower, upper);
        let (lower, upper) = (lower as f64, upper as f64);

        $crate::validation::PatternGuardRule {
            pattern_name: "numeric_range".to_string(),
            guard_fn: Box::new(move |value| {
                if let Some(val) = value.downcast_ref::<f64>() {
                    Ok(*val >= lower && *val <= upper)
                } else if let Some(val) = value.downcast_ref::<f32>() {
                    Ok(*val >= lower as f32 && *val <= upper as f32)
                } else if let Some(val) = value.downcast_ref::<i32>() {
                    // Compare in floating point: casting the bounds to an
                    // integer would truncate them (0.5 -> 0) and wrongly
                    // accept values just outside the range.
                    let val = f64::from(*val);
                    Ok(val >= lower && val <= upper)
                } else if let Some(val) = value.downcast_ref::<usize>() {
                    // Same reasoning as for `i32`. Very large `usize` values
                    // lose precision in `f64`, which is acceptable for a
                    // range check.
                    let val = *val as f64;
                    Ok(val >= lower && val <= upper)
                } else {
                    Ok(false)
                }
            }),
            error_message,
            destructure_fn: None,
        }
    }};

    (array_shape, $expected_shape:expr) => {
        $crate::validation::PatternGuardRule {
            pattern_name: "array_shape".to_string(),
            // Placeholder: no shape inspection is performed, every value passes.
            guard_fn: Box::new(|_value| Ok(true)),
            error_message: format!("Array shape must match {:?}", $expected_shape),
            destructure_fn: None,
        }
    };

    (string_enum, $valid_options:expr) => {{
        // Bind the options once and format the message before the closure
        // takes ownership, so owned collections (e.g. a `Vec`) work too.
        let options = $valid_options;
        let error_message = format!("Value must be one of {:?}", options);

        $crate::validation::PatternGuardRule {
            pattern_name: "string_enum".to_string(),
            guard_fn: Box::new(move |value| {
                if let Some(val) = value.downcast_ref::<String>() {
                    Ok(options.contains(&val.as_str()))
                } else if let Some(val) = value.downcast_ref::<&str>() {
                    Ok(options.contains(val))
                } else {
                    Ok(false)
                }
            }),
            error_message,
            destructure_fn: None,
        }
    }};

    ($pattern_name:literal, $guard:expr, $error_msg:literal) => {
        $crate::validation::PatternGuardRule {
            pattern_name: $pattern_name.to_string(),
            guard_fn: Box::new($guard),
            error_message: $error_msg.to_string(),
            destructure_fn: None,
        }
    };

    ($pattern_name:literal, $guard_fn:expr, $destructure_fn:expr) => {
        $crate::validation::PatternGuardRule {
            pattern_name: $pattern_name.to_string(),
            guard_fn: Box::new($guard_fn),
            error_message: format!("Pattern '{}' validation failed", $pattern_name),
            destructure_fn: Some(Box::new($destructure_fn)),
        }
    };
}
181
182pub trait PatternValidate {
184 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult>;
186
187 fn matches_pattern(&self, pattern_name: &str) -> bool;
189
190 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>>;
192}
193
194impl PatternValidate for f64 {
196 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
197 let value_any = self as &dyn std::any::Any;
198 let matched = (guard.guard_fn)(value_any)?;
199
200 let mut context = std::collections::HashMap::new();
201 context.insert("value".to_string(), self.to_string());
202 context.insert("type".to_string(), "f64".to_string());
203
204 if let Some(destructure_fn) = &guard.destructure_fn {
205 let destructure_result = destructure_fn(value_any)?;
206 Ok(ValidationResult {
207 matched: matched && destructure_result.matched,
208 context: destructure_result.context,
209 warnings: destructure_result.warnings,
210 })
211 } else {
212 Ok(ValidationResult {
213 matched,
214 context,
215 warnings: Vec::new(),
216 })
217 }
218 }
219
220 fn matches_pattern(&self, pattern_name: &str) -> bool {
221 match pattern_name {
222 "finite" => self.is_finite(),
223 "positive" => *self > 0.0,
224 "non_negative" => *self >= 0.0,
225 "probability" => *self >= 0.0 && *self <= 1.0,
226 _ => false,
227 }
228 }
229
230 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
231 let mut result = std::collections::HashMap::new();
232 match pattern {
233 "range_info" => {
234 result.insert("value".to_string(), self.to_string());
235 result.insert("is_finite".to_string(), self.is_finite().to_string());
236 result.insert("is_positive".to_string(), (*self > 0.0).to_string());
237 result.insert(
238 "sign".to_string(),
239 if *self >= 0.0 {
240 "positive".to_string()
241 } else {
242 "negative".to_string()
243 },
244 );
245 }
246 _ => {
247 result.insert("value".to_string(), self.to_string());
248 }
249 }
250 Ok(result)
251 }
252}
253
254impl PatternValidate for usize {
256 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
257 let value_any = self as &dyn std::any::Any;
258 let matched = (guard.guard_fn)(value_any)?;
259
260 let mut context = std::collections::HashMap::new();
261 context.insert("value".to_string(), self.to_string());
262 context.insert("type".to_string(), "usize".to_string());
263
264 Ok(ValidationResult {
265 matched,
266 context,
267 warnings: Vec::new(),
268 })
269 }
270
271 fn matches_pattern(&self, pattern_name: &str) -> bool {
272 match pattern_name {
273 "positive" => *self > 0,
274 "non_negative" => true, "power_of_two" => self.is_power_of_two(),
276 _ => false,
277 }
278 }
279
280 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
281 let mut result = std::collections::HashMap::new();
282 match pattern {
283 "number_info" => {
284 result.insert("value".to_string(), self.to_string());
285 result.insert("is_positive".to_string(), (*self > 0).to_string());
286 result.insert(
287 "is_power_of_two".to_string(),
288 self.is_power_of_two().to_string(),
289 );
290 }
291 _ => {
292 result.insert("value".to_string(), self.to_string());
293 }
294 }
295 Ok(result)
296 }
297}
298
299impl PatternValidate for String {
301 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
302 let value_any = self as &dyn std::any::Any;
303 let matched = (guard.guard_fn)(value_any)?;
304
305 let mut context = std::collections::HashMap::new();
306 context.insert("value".to_string(), self.clone());
307 context.insert("type".to_string(), "String".to_string());
308 context.insert("length".to_string(), self.len().to_string());
309
310 Ok(ValidationResult {
311 matched,
312 context,
313 warnings: Vec::new(),
314 })
315 }
316
317 fn matches_pattern(&self, pattern_name: &str) -> bool {
318 match pattern_name {
319 "non_empty" => !self.is_empty(),
320 "alphanumeric" => self.chars().all(|c| c.is_alphanumeric()),
321 "lowercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_lowercase()),
322 "uppercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_uppercase()),
323 _ => false,
324 }
325 }
326
327 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
328 let mut result = std::collections::HashMap::new();
329 match pattern {
330 "string_info" => {
331 result.insert("value".to_string(), self.clone());
332 result.insert("length".to_string(), self.len().to_string());
333 result.insert("is_empty".to_string(), self.is_empty().to_string());
334 result.insert(
335 "is_alphanumeric".to_string(),
336 self.chars().all(|c| c.is_alphanumeric()).to_string(),
337 );
338 }
339 _ => {
340 result.insert("value".to_string(), self.clone());
341 }
342 }
343 Ok(result)
344 }
345}
346
347pub mod pattern_guards {
349 use super::*;
350
351 pub fn hyperparameter_pattern<T: FloatBounds + std::fmt::Debug>(
353 min_val: T,
354 max_val: T,
355 finite_required: bool,
356 ) -> PatternGuardRule {
357 PatternGuardRule {
358 pattern_name: "hyperparameter_bounds".to_string(),
359 guard_fn: Box::new(|_value| {
360 Ok(true) }),
363 error_message: format!(
364 "Hyperparameter must be in range [{}, {}]{}",
365 min_val,
366 max_val,
367 if finite_required { " and finite" } else { "" }
368 ),
369 destructure_fn: None,
370 }
371 }
372
373 pub fn array_shape_pattern(expected_dims: &[usize]) -> PatternGuardRule {
375 let dims_str = expected_dims
376 .iter()
377 .map(|d| d.to_string())
378 .collect::<Vec<_>>()
379 .join(", ");
380
381 PatternGuardRule {
382 pattern_name: "array_shape".to_string(),
383 guard_fn: Box::new(|_value| {
384 Ok(true)
386 }),
387 error_message: format!("Array shape must match [{dims_str}]"),
388 destructure_fn: None, }
390 }
391
392 pub fn algorithm_config_pattern(required_fields: &[&str]) -> PatternGuardRule {
394 let fields_str = required_fields.join(", ");
395
396 PatternGuardRule {
397 pattern_name: "algorithm_config".to_string(),
398 guard_fn: Box::new(|_value| {
399 Ok(true)
401 }),
402 error_message: format!("Configuration must contain fields: {fields_str}"),
403 destructure_fn: None, }
405 }
406
407 pub fn data_type_pattern(expected_types: &[&str]) -> PatternGuardRule {
409 let types_str = expected_types.join(" | ");
410
411 PatternGuardRule {
412 pattern_name: "data_type_consistency".to_string(),
413 guard_fn: Box::new(|_value| {
414 Ok(true)
416 }),
417 error_message: format!("Data type must be one of: {types_str}"),
418 destructure_fn: None,
419 }
420 }
421}
422
423#[derive(Debug, Clone)]
425pub struct ValidationRules {
426 pub rules: Vec<ValidationRule>,
427 pub field_name: String,
428}
429
430impl ValidationRules {
431 pub fn new(field_name: &str) -> Self {
433 Self {
434 rules: Vec::new(),
435 field_name: field_name.to_string(),
436 }
437 }
438
439 pub fn add_rule(mut self, rule: ValidationRule) -> Self {
441 self.rules.push(rule);
442 self
443 }
444
445 pub fn validate_numeric<T>(&self, value: &T) -> Result<()>
447 where
448 T: Numeric + PartialOrd + Debug + Copy + NumCast,
449 {
450 for rule in &self.rules {
451 match rule {
452 ValidationRule::Positive => {
453 if *value <= T::zero() {
454 return Err(SklearsError::InvalidParameter {
455 name: self.field_name.clone(),
456 reason: "must be positive".to_string(),
457 });
458 }
459 }
460 ValidationRule::NonNegative => {
461 if *value < T::zero() {
462 return Err(SklearsError::InvalidParameter {
463 name: self.field_name.clone(),
464 reason: "must be non-negative".to_string(),
465 });
466 }
467 }
468 ValidationRule::Finite => {
469 if let Some(float_val) = NumCast::from(*value) {
470 let f: f64 = float_val;
471 if !f.is_finite() {
472 return Err(SklearsError::InvalidParameter {
473 name: self.field_name.clone(),
474 reason: "must be finite".to_string(),
475 });
476 }
477 }
478 }
479 ValidationRule::Range { min, max } => {
480 if let Some(float_val) = NumCast::from(*value) {
481 let f: f64 = float_val;
482 if f < *min || f > *max {
483 return Err(SklearsError::InvalidParameter {
484 name: self.field_name.clone(),
485 reason: format!("must be in range [{min}, {max}]"),
486 });
487 }
488 }
489 }
490 ValidationRule::PatternGuard(_pattern_guard) => {
491 }
501 _ => {
502 }
504 }
505 }
506 Ok(())
507 }
508
509 pub fn validate_string(&self, value: &str) -> Result<()> {
511 for rule in &self.rules {
512 match rule {
513 ValidationRule::OneOf(options) => {
514 if !options.contains(&value.to_string()) {
515 return Err(SklearsError::InvalidParameter {
516 name: self.field_name.clone(),
517 reason: format!("must be one of {options:?}"),
518 });
519 }
520 }
521 ValidationRule::PatternGuard(_pattern_guard) => {
522 }
532 _ => {
533 }
535 }
536 }
537 Ok(())
538 }
539
540 pub fn validate_array<T>(&self, value: &[T]) -> Result<()> {
542 for rule in &self.rules {
543 match rule {
544 ValidationRule::MinLength(min_len) => {
545 if value.len() < *min_len {
546 return Err(SklearsError::InvalidParameter {
547 name: self.field_name.clone(),
548 reason: format!("must have at least {min_len} elements"),
549 });
550 }
551 }
552 ValidationRule::MaxLength(max_len) => {
553 if value.len() > *max_len {
554 return Err(SklearsError::InvalidParameter {
555 name: self.field_name.clone(),
556 reason: format!("must have at most {max_len} elements"),
557 });
558 }
559 }
560 ValidationRule::PatternGuard(_pattern_guard) => {
561 }
571 _ => {
572 }
574 }
575 }
576 Ok(())
577 }
578
579 pub fn validate_usize(&self, value: &usize) -> Result<()> {
581 for rule in &self.rules {
582 match rule {
583 ValidationRule::Positive => {
584 if *value == 0 {
585 return Err(SklearsError::InvalidParameter {
586 name: self.field_name.clone(),
587 reason: "must be positive".to_string(),
588 });
589 }
590 }
591 ValidationRule::NonNegative => {
592 }
594 ValidationRule::Range { min, max } => {
595 let val = *value as f64;
596 if val < *min || val > *max {
597 return Err(SklearsError::InvalidParameter {
598 name: self.field_name.clone(),
599 reason: format!("must be in range [{min}, {max}]"),
600 });
601 }
602 }
603 _ => {
604 }
606 }
607 }
608 Ok(())
609 }
610}
611
612pub mod ml {
614 use super::*;
615
616 pub fn validate_learning_rate<T: FloatBounds>(lr: T) -> Result<()> {
618 if lr <= T::zero() {
619 return Err(SklearsError::InvalidParameter {
620 name: "learning_rate".to_string(),
621 reason: "must be positive".to_string(),
622 });
623 }
624
625 if !Float::is_finite(lr) {
626 return Err(SklearsError::InvalidParameter {
627 name: "learning_rate".to_string(),
628 reason: "must be finite".to_string(),
629 });
630 }
631
632 if lr > T::one() {
634 log::warn!("Learning rate {lr} is unusually high, consider using a smaller value");
635 }
636
637 Ok(())
638 }
639
640 pub fn validate_regularization<T: FloatBounds>(reg: T) -> Result<()> {
642 if reg < T::zero() {
643 return Err(SklearsError::InvalidParameter {
644 name: "regularization".to_string(),
645 reason: "must be non-negative".to_string(),
646 });
647 }
648
649 if !Float::is_finite(reg) {
650 return Err(SklearsError::InvalidParameter {
651 name: "regularization".to_string(),
652 reason: "must be finite".to_string(),
653 });
654 }
655
656 Ok(())
657 }
658
659 pub fn validate_n_clusters(n_clusters: usize, n_samples: usize) -> Result<()> {
661 if n_clusters == 0 {
662 return Err(SklearsError::InvalidParameter {
663 name: "n_clusters".to_string(),
664 reason: "must be positive".to_string(),
665 });
666 }
667
668 if n_clusters > n_samples {
669 return Err(SklearsError::InvalidParameter {
670 name: "n_clusters".to_string(),
671 reason: format!("cannot exceed number of samples ({n_samples})"),
672 });
673 }
674
675 Ok(())
676 }
677
678 pub fn validate_n_neighbors(n_neighbors: usize, n_samples: usize) -> Result<()> {
680 if n_neighbors == 0 {
681 return Err(SklearsError::InvalidParameter {
682 name: "n_neighbors".to_string(),
683 reason: "must be positive".to_string(),
684 });
685 }
686
687 if n_neighbors > n_samples {
688 return Err(SklearsError::InvalidParameter {
689 name: "n_neighbors".to_string(),
690 reason: format!("cannot exceed number of samples ({n_samples})"),
691 });
692 }
693
694 Ok(())
695 }
696
697 pub fn validate_tolerance<T: FloatBounds>(tol: T) -> Result<()> {
699 if tol <= T::zero() {
700 return Err(SklearsError::InvalidParameter {
701 name: "tolerance".to_string(),
702 reason: "must be positive".to_string(),
703 });
704 }
705
706 if !Float::is_finite(tol) {
707 return Err(SklearsError::InvalidParameter {
708 name: "tolerance".to_string(),
709 reason: "must be finite".to_string(),
710 });
711 }
712
713 if tol > T::from(0.1).unwrap_or(T::one()) {
715 log::warn!("Tolerance {tol} is very large, algorithm may converge prematurely");
716 }
717
718 Ok(())
719 }
720
721 pub fn validate_max_iter(max_iter: usize) -> Result<()> {
723 if max_iter == 0 {
724 return Err(SklearsError::InvalidParameter {
725 name: "max_iter".to_string(),
726 reason: "must be positive".to_string(),
727 });
728 }
729
730 Ok(())
731 }
732
733 pub fn validate_probability<T: FloatBounds>(prob: T) -> Result<()> {
735 if prob < T::zero() || prob > T::one() {
736 return Err(SklearsError::InvalidParameter {
737 name: "probability".to_string(),
738 reason: "must be in range [0, 1]".to_string(),
739 });
740 }
741
742 if !Float::is_finite(prob) {
743 return Err(SklearsError::InvalidParameter {
744 name: "probability".to_string(),
745 reason: "must be finite".to_string(),
746 });
747 }
748
749 Ok(())
750 }
751
752 pub fn validate_supervised_data<T, U>(x: &Array2<T>, y: &Array1<U>) -> Result<()> {
754 if x.is_empty() {
755 return Err(SklearsError::InvalidData {
756 reason: "X cannot be empty".to_string(),
757 });
758 }
759
760 if y.is_empty() {
761 return Err(SklearsError::InvalidData {
762 reason: "y cannot be empty".to_string(),
763 });
764 }
765
766 if x.nrows() != y.len() {
767 return Err(SklearsError::ShapeMismatch {
768 expected: "X.shape[0] == y.shape[0]".to_string(),
769 actual: format!("X.shape[0]={}, y.shape[0]={}", x.nrows(), y.len()),
770 });
771 }
772
773 Ok(())
774 }
775
776 pub fn validate_unsupervised_data<T>(x: &Array2<T>) -> Result<()> {
778 if x.is_empty() {
779 return Err(SklearsError::InvalidData {
780 reason: "X cannot be empty".to_string(),
781 });
782 }
783
784 if x.nrows() == 0 || x.ncols() == 0 {
785 return Err(SklearsError::InvalidData {
786 reason: "X must have positive dimensions".to_string(),
787 });
788 }
789
790 Ok(())
791 }
792
793 pub fn validate_feature_consistency<T, U>(
795 x_train: &Array2<T>,
796 x_test: &Array2<U>,
797 _model_name: &str,
798 ) -> Result<()> {
799 if x_train.ncols() != x_test.ncols() {
800 return Err(SklearsError::FeatureMismatch {
801 expected: x_train.ncols(),
802 actual: x_test.ncols(),
803 });
804 }
805
806 Ok(())
807 }
808}
809
/// Helpers that emit validation source code for derive-style code generation.
pub mod derive_helpers {
    /// Generates one line of validation code per recognised attribute.
    ///
    /// Recognised attributes are `positive`, `non_negative`, `finite` and
    /// `range(<min>, <max>)`. Unknown or malformed attributes (for example a
    /// `range(` without the closing parenthesis, with a missing bound, or with
    /// the wrong number of bounds) are skipped. The generated lines are joined
    /// with `\n`.
    ///
    /// `_field_type` is accepted but not used.
    pub fn generate_field_validation(
        field_name: &str,
        _field_type: &str,
        validation_attrs: &[String],
    ) -> String {
        // All single-variant rules share the same code shape.
        let simple_rule = |variant: &str| {
            format!(
                "ValidationRules::new(\"{field_name}\").add_rule(ValidationRule::{variant}).validate_numeric(&self.{field_name})?;"
            )
        };

        let mut validations = Vec::new();

        for attr in validation_attrs {
            match attr.as_str() {
                "positive" => validations.push(simple_rule("Positive")),
                "non_negative" => validations.push(simple_rule("NonNegative")),
                "finite" => validations.push(simple_rule("Finite")),
                other => {
                    if let Some((min_val, max_val)) = parse_range(other) {
                        validations.push(format!(
                            "ValidationRules::new(\"{field_name}\").add_rule(ValidationRule::Range {{ min: {min_val}, max: {max_val} }}).validate_numeric(&self.{field_name})?;"
                        ));
                    }
                }
            }
        }

        validations.join("\n")
    }

    /// Splits `range(<min>, <max>)` into its two trimmed, non-empty bounds.
    ///
    /// Returns `None` instead of panicking when the attribute is not a
    /// well-formed two-argument range.
    fn parse_range(attr: &str) -> Option<(&str, &str)> {
        let inner = attr.strip_prefix("range(")?.strip_suffix(')')?;
        let mut parts = inner.split(',').map(str::trim);

        match (parts.next(), parts.next(), parts.next()) {
            (Some(min_val), Some(max_val), None)
                if !min_val.is_empty() && !max_val.is_empty() =>
            {
                Some((min_val, max_val))
            }
            _ => None,
        }
    }
}
860
861pub trait ConfigValidation {
863 fn validate_config(&self) -> Result<()>;
865
866 fn get_warnings(&self) -> Vec<String> {
868 Vec::new()
869 }
870}
871
872#[derive(Debug, Clone)]
874pub struct ValidationContext {
875 pub algorithm: String,
876 pub operation: String,
877 pub data_info: Option<DataInfo>,
878}
879
/// Summary of a dataset, used when formatting validation errors.
#[derive(Debug, Clone)]
pub struct DataInfo {
    /// Number of samples.
    pub n_samples: usize,
    /// Number of features.
    pub n_features: usize,
    /// Free-form description of the element type (e.g. `"float64"`).
    pub data_type: String,
}
887
888impl ValidationContext {
889 pub fn new(algorithm: &str, operation: &str) -> Self {
891 Self {
892 algorithm: algorithm.to_string(),
893 operation: operation.to_string(),
894 data_info: None,
895 }
896 }
897
898 pub fn with_data_info(mut self, n_samples: usize, n_features: usize, data_type: &str) -> Self {
900 self.data_info = Some(DataInfo {
901 n_samples,
902 n_features,
903 data_type: data_type.to_string(),
904 });
905 self
906 }
907
908 pub fn format_error(&self, error: &SklearsError) -> String {
910 let mut msg = format!(
911 "Error in {} during {}: {error}",
912 self.algorithm, self.operation
913 );
914
915 if let Some(data_info) = &self.data_info {
916 msg.push_str(&format!(
917 " (data: {} samples, {} features, type: {})",
918 data_info.n_samples, data_info.n_features, data_info.data_type
919 ));
920 }
921
922 msg
923 }
924}
925
926pub mod structured_destructuring {
928 use super::*;
929
930 pub trait StructuredDestructure {
932 fn destructure_into_components(
934 &self,
935 ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>>;
936
937 fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>>;
939
940 fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult>;
942 }
943
944 #[derive(Debug, Clone, Default)]
946 pub struct StructuralSchema {
947 pub required_fields: Vec<String>,
948 pub optional_fields: Vec<String>,
949 pub field_types: std::collections::HashMap<String, String>,
950 pub nested_schemas: std::collections::HashMap<String, StructuralSchema>,
951 }
952
953 impl StructuralSchema {
954 pub fn new() -> Self {
955 Self::default()
956 }
957
958 pub fn require_field(mut self, field_name: &str, field_type: &str) -> Self {
959 self.required_fields.push(field_name.to_string());
960 self.field_types
961 .insert(field_name.to_string(), field_type.to_string());
962 self
963 }
964
965 pub fn optional_field(mut self, field_name: &str, field_type: &str) -> Self {
966 self.optional_fields.push(field_name.to_string());
967 self.field_types
968 .insert(field_name.to_string(), field_type.to_string());
969 self
970 }
971
972 pub fn nested_schema(mut self, field_name: &str, schema: StructuralSchema) -> Self {
973 self.nested_schemas.insert(field_name.to_string(), schema);
974 self
975 }
976 }
977
978 #[derive(Debug, Clone)]
980 pub struct AlgorithmConfig {
981 pub algorithm_name: String,
982 pub hyperparameters: std::collections::HashMap<String, ConfigValue>,
983 pub metadata: std::collections::HashMap<String, String>,
984 }
985
986 #[derive(Debug, Clone)]
988 pub enum ConfigValue {
989 Float(f64),
990 Integer(i64),
991 String(String),
992 Boolean(bool),
993 Array(Vec<ConfigValue>),
994 Object(std::collections::HashMap<String, ConfigValue>),
995 }
996
997 impl StructuredDestructure for AlgorithmConfig {
998 fn destructure_into_components(
999 &self,
1000 ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>> {
1001 let mut components = std::collections::HashMap::new();
1002
1003 components.insert(
1004 "algorithm_name".to_string(),
1005 Box::new(self.algorithm_name.clone()) as Box<dyn std::any::Any>,
1006 );
1007 components.insert(
1008 "hyperparameters".to_string(),
1009 Box::new(self.hyperparameters.clone()) as Box<dyn std::any::Any>,
1010 );
1011 components.insert(
1012 "metadata".to_string(),
1013 Box::new(self.metadata.clone()) as Box<dyn std::any::Any>,
1014 );
1015
1016 Ok(components)
1017 }
1018
1019 fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>> {
1020 let parts: Vec<&str> = field_path.split('.').collect();
1021
1022 match parts.first() {
1023 Some(&"algorithm_name") => Ok(Box::new(self.algorithm_name.clone())),
1024 Some(&"hyperparameters") => {
1025 if parts.len() > 1 {
1026 if let Some(param_value) = self.hyperparameters.get(parts[1]) {
1027 Ok(Box::new(param_value.clone()))
1028 } else {
1029 Err(SklearsError::InvalidParameter {
1030 name: field_path.to_string(),
1031 reason: format!("Hyperparameter '{}' not found", parts[1]),
1032 })
1033 }
1034 } else {
1035 Ok(Box::new(self.hyperparameters.clone()))
1036 }
1037 }
1038 Some(&"metadata") => {
1039 if parts.len() > 1 {
1040 if let Some(meta_value) = self.metadata.get(parts[1]) {
1041 Ok(Box::new(meta_value.clone()))
1042 } else {
1043 Err(SklearsError::InvalidParameter {
1044 name: field_path.to_string(),
1045 reason: format!("Metadata '{}' not found", parts[1]),
1046 })
1047 }
1048 } else {
1049 Ok(Box::new(self.metadata.clone()))
1050 }
1051 }
1052 _ => Err(SklearsError::InvalidParameter {
1053 name: field_path.to_string(),
1054 reason: "Invalid field path".to_string(),
1055 }),
1056 }
1057 }
1058
1059 fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult> {
1060 let mut warnings = Vec::new();
1061 let mut context = std::collections::HashMap::new();
1062
1063 for required_field in &schema.required_fields {
1065 match required_field.as_str() {
1066 "algorithm_name" => {
1067 if self.algorithm_name.is_empty() {
1068 return Err(SklearsError::InvalidParameter {
1069 name: "algorithm_name".to_string(),
1070 reason: "Required field cannot be empty".to_string(),
1071 });
1072 }
1073 context.insert("algorithm_name".to_string(), "present".to_string());
1074 }
1075 "hyperparameters" => {
1076 context.insert(
1077 "hyperparameters_count".to_string(),
1078 self.hyperparameters.len().to_string(),
1079 );
1080 }
1081 _ => {
1082 warnings.push(format!("Unknown required field: {required_field}"));
1083 }
1084 }
1085 }
1086
1087 Ok(ValidationResult {
1088 matched: true,
1089 context,
1090 warnings,
1091 })
1092 }
1093 }
1094
1095 pub fn create_complex_pattern_guard<T>(
1097 pattern_name: &str,
1098 validator: impl Fn(&T) -> Result<bool> + Send + Sync + 'static,
1099 error_message: &str,
1100 ) -> PatternGuardRule
1101 where
1102 T: 'static,
1103 {
1104 PatternGuardRule {
1105 pattern_name: pattern_name.to_string(),
1106 guard_fn: Box::new(move |value| {
1107 if let Some(typed_value) = value.downcast_ref::<T>() {
1108 validator(typed_value)
1109 } else {
1110 Ok(false)
1111 }
1112 }),
1113 error_message: error_message.to_string(),
1114 destructure_fn: None,
1115 }
1116 }
1117}
1118
/// Convenience front-end for values that provide `extract_field` /
/// `validate_structure`.
///
/// * `destructure!(obj, "path")` — forwards to `obj.extract_field("path")`.
/// * `destructure!(obj, { "a", "b" })` — extracts each listed field and
///   returns a `HashMap` from field path to extracted value. Fields whose
///   extraction fails are left out of the map. `obj` is evaluated once.
/// * `destructure!(obj, validate: schema)` — forwards to
///   `obj.validate_structure(&schema)`.
#[macro_export]
macro_rules! destructure {
    ($obj:expr, $field:literal) => {
        $obj.extract_field($field)
    };

    ($obj:expr, { $($field:literal),* }) => {
        {
            // Bind the object once so an expression with side effects (or an
            // expensive one) is not re-evaluated for every field.
            let source = &$obj;
            let mut results = ::std::collections::HashMap::new();
            $(
                if let Ok(value) = source.extract_field($field) {
                    results.insert($field.to_string(), value);
                }
            )*
            results
        }
    };

    ($obj:expr, validate: $schema:expr) => {
        $obj.validate_structure(&$schema)
    };
}
1145
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// `Positive` + `Finite` accept a positive finite number and reject zero,
    /// negatives, NaN and infinity.
    #[test]
    fn test_validation_rules_numeric() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::Positive)
            .add_rule(ValidationRule::Finite);

        assert!(rules.validate_numeric(&1.5f64).is_ok());

        let rejected = [0.0f64, -1.0f64, f64::NAN, f64::INFINITY];
        for value in &rejected {
            assert!(rules.validate_numeric(value).is_err());
        }
    }

    /// `Range` is inclusive at both ends.
    #[test]
    fn test_validation_rules_range() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::Range { min: 0.0, max: 1.0 });

        let accepted = [0.5f64, 0.0f64, 1.0f64];
        for value in &accepted {
            assert!(rules.validate_numeric(value).is_ok());
        }

        let rejected = [-0.1f64, 1.1f64];
        for value in &rejected {
            assert!(rules.validate_numeric(value).is_err());
        }
    }

    /// `OneOf` accepts only the listed options.
    #[test]
    fn test_validation_rules_string() {
        let options = vec!["option1".to_string(), "option2".to_string()];
        let rules = ValidationRules::new("test_param").add_rule(ValidationRule::OneOf(options));

        for accepted in &["option1", "option2"] {
            assert!(rules.validate_string(accepted).is_ok());
        }

        assert!(rules.validate_string("option3").is_err());
    }

    /// `MinLength` / `MaxLength` bound the slice length inclusively.
    #[test]
    fn test_validation_rules_array() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::MinLength(2))
            .add_rule(ValidationRule::MaxLength(5));

        let accepted: [&[i32]; 2] = [&[1, 2], &[1, 2, 3, 4, 5]];
        for values in &accepted {
            assert!(rules.validate_array(values).is_ok());
        }

        let rejected: [&[i32]; 2] = [&[1], &[1, 2, 3, 4, 5, 6]];
        for values in &rejected {
            assert!(rules.validate_array(values).is_err());
        }
    }

    /// Learning rates must be positive and finite.
    #[test]
    fn test_ml_validation_learning_rate() {
        for accepted in &[0.01f64, 0.1f64] {
            assert!(ml::validate_learning_rate(*accepted).is_ok());
        }

        for rejected in &[0.0f64, -0.1f64, f64::NAN] {
            assert!(ml::validate_learning_rate(*rejected).is_err());
        }
    }

    /// Cluster counts must be in `1..=n_samples`.
    #[test]
    fn test_ml_validation_n_clusters() {
        let n_samples = 10;

        for accepted in &[3usize, 10usize] {
            assert!(ml::validate_n_clusters(*accepted, n_samples).is_ok());
        }

        for rejected in &[0usize, 15usize] {
            assert!(ml::validate_n_clusters(*rejected, n_samples).is_err());
        }
    }

    /// Probabilities must be in `[0, 1]` and not NaN.
    #[test]
    fn test_ml_validation_probability() {
        for accepted in &[0.0f64, 0.5f64, 1.0f64] {
            assert!(ml::validate_probability(*accepted).is_ok());
        }

        for rejected in &[-0.1f64, 1.1f64, f64::NAN] {
            assert!(ml::validate_probability(*rejected).is_err());
        }
    }

    /// The formatted error mentions the algorithm, operation and data summary.
    #[test]
    fn test_validation_context() {
        let context = ValidationContext::new("KMeans", "fit").with_data_info(100, 5, "float64");

        let error = SklearsError::InvalidParameter {
            name: "n_clusters".to_string(),
            reason: "must be positive".to_string(),
        };

        let formatted = context.format_error(&error);
        for expected in &["KMeans", "fit", "100 samples", "5 features"] {
            assert!(formatted.contains(expected));
        }
    }
}