1use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds, Numeric};
7use scirs2_core::numeric::{Float, NumCast};
8use std::fmt::Debug;
9
10pub type ValidationGuardFn = Box<dyn Fn(&dyn std::any::Any) -> Result<bool> + Send + Sync>;
12
13pub type ValidationDestructureFn =
15 Box<dyn Fn(&dyn std::any::Any) -> Result<ValidationResult> + Send + Sync>;
16
17pub trait Validate {
19 fn validate(&self) -> Result<()>;
21
22 fn validate_with_context(&self, context: &str) -> Result<()> {
24 self.validate()
25 .map_err(|e| SklearsError::Other(format!("{context}: {e}")))
26 }
27}
28
29#[derive(Debug, Clone)]
31pub enum ValidationRule {
32 Positive,
34 NonNegative,
36 Finite,
38 Range { min: f64, max: f64 },
40 OneOf(Vec<String>),
42 MinLength(usize),
44 MaxLength(usize),
46 UniqueElements,
48 Custom(fn(&dyn std::any::Any) -> Result<()>),
50 PatternGuard(PatternGuardRule),
52}
53
54pub struct PatternGuardRule {
56 pub pattern_name: String,
58 pub guard_fn: ValidationGuardFn,
60 pub error_message: String,
62 pub destructure_fn: Option<ValidationDestructureFn>,
64}
65
66impl std::fmt::Debug for PatternGuardRule {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("PatternGuardRule")
69 .field("pattern_name", &self.pattern_name)
70 .field("guard_fn", &"<function>")
71 .field("error_message", &self.error_message)
72 .field(
73 "destructure_fn",
74 &self.destructure_fn.as_ref().map(|_| "<function>"),
75 )
76 .finish()
77 }
78}
79
80impl Clone for PatternGuardRule {
81 fn clone(&self) -> Self {
82 Self {
85 pattern_name: self.pattern_name.clone(),
86 guard_fn: Box::new(|_| Ok(true)), error_message: self.error_message.clone(),
88 destructure_fn: None,
89 }
90 }
91}
92
93#[derive(Debug, Clone)]
95pub struct ValidationResult {
96 pub matched: bool,
98 pub context: std::collections::HashMap<String, String>,
100 pub warnings: Vec<String>,
102}
103
104#[macro_export]
106macro_rules! pattern_guard {
107 (numeric_range, $min:expr, $max:expr) => {
109 $crate::validation::PatternGuardRule {
110 pattern_name: "numeric_range".to_string(),
111 guard_fn: Box::new(move |value| {
112 if let Some(val) = value.downcast_ref::<f64>() {
113 Ok(*val >= $min && *val <= $max)
114 } else if let Some(val) = value.downcast_ref::<f32>() {
115 Ok(*val >= $min as f32 && *val <= $max as f32)
116 } else if let Some(val) = value.downcast_ref::<i32>() {
117 Ok(*val >= $min as i32 && *val <= $max as i32)
118 } else if let Some(val) = value.downcast_ref::<usize>() {
119 Ok(*val >= $min as usize && *val <= $max as usize)
120 } else {
121 Ok(false)
122 }
123 }),
124 error_message: format!("Value must be in range [{}, {}]", $min, $max),
125 destructure_fn: None,
126 }
127 };
128
129 (array_shape, $expected_shape:expr) => {
131 $crate::validation::PatternGuardRule {
132 pattern_name: "array_shape".to_string(),
133 guard_fn: Box::new(move |value| {
134 Ok(true)
137 }),
138 error_message: format!("Array shape must match {:?}", $expected_shape),
139 destructure_fn: None,
140 }
141 };
142
143 (string_enum, $valid_options:expr) => {
145 $crate::validation::PatternGuardRule {
146 pattern_name: "string_enum".to_string(),
147 guard_fn: Box::new(move |value| {
148 if let Some(val) = value.downcast_ref::<String>() {
149 Ok($valid_options.contains(&val.as_str()))
150 } else if let Some(val) = value.downcast_ref::<&str>() {
151 Ok($valid_options.contains(val))
152 } else {
153 Ok(false)
154 }
155 }),
156 error_message: format!("Value must be one of {:?}", $valid_options),
157 destructure_fn: None,
158 }
159 };
160
161 ($pattern_name:literal, $guard:expr, $error_msg:literal) => {
163 $crate::validation::PatternGuardRule {
164 pattern_name: $pattern_name.to_string(),
165 guard_fn: Box::new($guard),
166 error_message: $error_msg.to_string(),
167 destructure_fn: None,
168 }
169 };
170
171 ($pattern_name:literal, $guard_fn:expr, $destructure_fn:expr) => {
173 $crate::validation::PatternGuardRule {
174 pattern_name: $pattern_name.to_string(),
175 guard_fn: Box::new($guard_fn),
176 error_message: format!("Pattern '{}' validation failed", $pattern_name),
177 destructure_fn: Some(Box::new($destructure_fn)),
178 }
179 };
180}
181
182pub trait PatternValidate {
184 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult>;
186
187 fn matches_pattern(&self, pattern_name: &str) -> bool;
189
190 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>>;
192}
193
194impl PatternValidate for f64 {
196 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
197 let value_any = self as &dyn std::any::Any;
198 let matched = (guard.guard_fn)(value_any)?;
199
200 let mut context = std::collections::HashMap::new();
201 context.insert("value".to_string(), self.to_string());
202 context.insert("type".to_string(), "f64".to_string());
203
204 if let Some(destructure_fn) = &guard.destructure_fn {
205 let destructure_result = destructure_fn(value_any)?;
206 Ok(ValidationResult {
207 matched: matched && destructure_result.matched,
208 context: destructure_result.context,
209 warnings: destructure_result.warnings,
210 })
211 } else {
212 Ok(ValidationResult {
213 matched,
214 context,
215 warnings: Vec::new(),
216 })
217 }
218 }
219
220 fn matches_pattern(&self, pattern_name: &str) -> bool {
221 match pattern_name {
222 "finite" => self.is_finite(),
223 "positive" => *self > 0.0,
224 "non_negative" => *self >= 0.0,
225 "probability" => *self >= 0.0 && *self <= 1.0,
226 _ => false,
227 }
228 }
229
230 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
231 let mut result = std::collections::HashMap::new();
232 match pattern {
233 "range_info" => {
234 result.insert("value".to_string(), self.to_string());
235 result.insert("is_finite".to_string(), self.is_finite().to_string());
236 result.insert("is_positive".to_string(), (*self > 0.0).to_string());
237 result.insert(
238 "sign".to_string(),
239 if *self >= 0.0 {
240 "positive".to_string()
241 } else {
242 "negative".to_string()
243 },
244 );
245 }
246 _ => {
247 result.insert("value".to_string(), self.to_string());
248 }
249 }
250 Ok(result)
251 }
252}
253
254impl PatternValidate for usize {
256 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
257 let value_any = self as &dyn std::any::Any;
258 let matched = (guard.guard_fn)(value_any)?;
259
260 let mut context = std::collections::HashMap::new();
261 context.insert("value".to_string(), self.to_string());
262 context.insert("type".to_string(), "usize".to_string());
263
264 Ok(ValidationResult {
265 matched,
266 context,
267 warnings: Vec::new(),
268 })
269 }
270
271 fn matches_pattern(&self, pattern_name: &str) -> bool {
272 match pattern_name {
273 "positive" => *self > 0,
274 "non_negative" => true, "power_of_two" => self.is_power_of_two(),
276 _ => false,
277 }
278 }
279
280 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
281 let mut result = std::collections::HashMap::new();
282 match pattern {
283 "number_info" => {
284 result.insert("value".to_string(), self.to_string());
285 result.insert("is_positive".to_string(), (*self > 0).to_string());
286 result.insert(
287 "is_power_of_two".to_string(),
288 self.is_power_of_two().to_string(),
289 );
290 }
291 _ => {
292 result.insert("value".to_string(), self.to_string());
293 }
294 }
295 Ok(result)
296 }
297}
298
299impl PatternValidate for String {
301 fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
302 let value_any = self as &dyn std::any::Any;
303 let matched = (guard.guard_fn)(value_any)?;
304
305 let mut context = std::collections::HashMap::new();
306 context.insert("value".to_string(), self.clone());
307 context.insert("type".to_string(), "String".to_string());
308 context.insert("length".to_string(), self.len().to_string());
309
310 Ok(ValidationResult {
311 matched,
312 context,
313 warnings: Vec::new(),
314 })
315 }
316
317 fn matches_pattern(&self, pattern_name: &str) -> bool {
318 match pattern_name {
319 "non_empty" => !self.is_empty(),
320 "alphanumeric" => self.chars().all(|c| c.is_alphanumeric()),
321 "lowercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_lowercase()),
322 "uppercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_uppercase()),
323 _ => false,
324 }
325 }
326
327 fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
328 let mut result = std::collections::HashMap::new();
329 match pattern {
330 "string_info" => {
331 result.insert("value".to_string(), self.clone());
332 result.insert("length".to_string(), self.len().to_string());
333 result.insert("is_empty".to_string(), self.is_empty().to_string());
334 result.insert(
335 "is_alphanumeric".to_string(),
336 self.chars().all(|c| c.is_alphanumeric()).to_string(),
337 );
338 }
339 _ => {
340 result.insert("value".to_string(), self.clone());
341 }
342 }
343 Ok(result)
344 }
345}
346
347pub mod pattern_guards {
349 use super::*;
350
351 pub fn hyperparameter_pattern<T: FloatBounds + std::fmt::Debug>(
353 min_val: T,
354 max_val: T,
355 finite_required: bool,
356 ) -> PatternGuardRule {
357 PatternGuardRule {
358 pattern_name: "hyperparameter_bounds".to_string(),
359 guard_fn: Box::new(|_value| {
360 Ok(true) }),
363 error_message: format!(
364 "Hyperparameter must be in range [{}, {}]{}",
365 min_val,
366 max_val,
367 if finite_required { " and finite" } else { "" }
368 ),
369 destructure_fn: None,
370 }
371 }
372
373 pub fn array_shape_pattern(expected_dims: &[usize]) -> PatternGuardRule {
375 let dims_str = expected_dims
376 .iter()
377 .map(|d| d.to_string())
378 .collect::<Vec<_>>()
379 .join(", ");
380
381 PatternGuardRule {
382 pattern_name: "array_shape".to_string(),
383 guard_fn: Box::new(|_value| {
384 Ok(true)
386 }),
387 error_message: format!("Array shape must match [{dims_str}]"),
388 destructure_fn: None, }
390 }
391
392 pub fn algorithm_config_pattern(required_fields: &[&str]) -> PatternGuardRule {
394 let fields_str = required_fields.join(", ");
395
396 PatternGuardRule {
397 pattern_name: "algorithm_config".to_string(),
398 guard_fn: Box::new(|_value| {
399 Ok(true)
401 }),
402 error_message: format!("Configuration must contain fields: {fields_str}"),
403 destructure_fn: None, }
405 }
406
407 pub fn data_type_pattern(expected_types: &[&str]) -> PatternGuardRule {
409 let types_str = expected_types.join(" | ");
410
411 PatternGuardRule {
412 pattern_name: "data_type_consistency".to_string(),
413 guard_fn: Box::new(|_value| {
414 Ok(true)
416 }),
417 error_message: format!("Data type must be one of: {types_str}"),
418 destructure_fn: None,
419 }
420 }
421}
422
423#[derive(Debug, Clone)]
425pub struct ValidationRules {
426 pub rules: Vec<ValidationRule>,
427 pub field_name: String,
428}
429
430impl ValidationRules {
431 pub fn new(field_name: &str) -> Self {
433 Self {
434 rules: Vec::new(),
435 field_name: field_name.to_string(),
436 }
437 }
438
439 pub fn add_rule(mut self, rule: ValidationRule) -> Self {
441 self.rules.push(rule);
442 self
443 }
444
445 pub fn validate_numeric<T>(&self, value: &T) -> Result<()>
447 where
448 T: Numeric + PartialOrd + Debug + Copy + NumCast,
449 {
450 for rule in &self.rules {
451 match rule {
452 ValidationRule::Positive if *value <= T::zero() => {
453 return Err(SklearsError::InvalidParameter {
454 name: self.field_name.clone(),
455 reason: "must be positive".to_string(),
456 });
457 }
458 ValidationRule::Positive => {}
459 ValidationRule::NonNegative if *value < T::zero() => {
460 return Err(SklearsError::InvalidParameter {
461 name: self.field_name.clone(),
462 reason: "must be non-negative".to_string(),
463 });
464 }
465 ValidationRule::NonNegative => {}
466 ValidationRule::Finite => {
467 if let Some(float_val) = NumCast::from(*value) {
468 let f: f64 = float_val;
469 if !f.is_finite() {
470 return Err(SklearsError::InvalidParameter {
471 name: self.field_name.clone(),
472 reason: "must be finite".to_string(),
473 });
474 }
475 }
476 }
477 ValidationRule::Range { min, max } => {
478 if let Some(float_val) = NumCast::from(*value) {
479 let f: f64 = float_val;
480 if f < *min || f > *max {
481 return Err(SklearsError::InvalidParameter {
482 name: self.field_name.clone(),
483 reason: format!("must be in range [{min}, {max}]"),
484 });
485 }
486 }
487 }
488 ValidationRule::PatternGuard(pattern_guard) => {
489 let value_any: &dyn std::any::Any = value;
494 let passes = (pattern_guard.guard_fn)(value_any)?;
495 if !passes {
496 return Err(SklearsError::InvalidParameter {
497 name: self.field_name.clone(),
498 reason: pattern_guard.error_message.clone(),
499 });
500 }
501 }
502 _ => {
503 }
505 }
506 }
507 Ok(())
508 }
509
510 pub fn validate_string(&self, value: &str) -> Result<()> {
512 for rule in &self.rules {
513 match rule {
514 ValidationRule::OneOf(options) if !options.contains(&value.to_string()) => {
515 return Err(SklearsError::InvalidParameter {
516 name: self.field_name.clone(),
517 reason: format!("must be one of {options:?}"),
518 });
519 }
520 ValidationRule::OneOf(_) => {}
521 ValidationRule::PatternGuard(pattern_guard) => {
522 let owned: String = value.to_owned();
524 let value_any: &dyn std::any::Any = &owned;
525 let passes = (pattern_guard.guard_fn)(value_any)?;
526 if !passes {
527 return Err(SklearsError::InvalidParameter {
528 name: self.field_name.clone(),
529 reason: pattern_guard.error_message.clone(),
530 });
531 }
532 }
533 _ => {
534 }
536 }
537 }
538 Ok(())
539 }
540
541 pub fn validate_array<T>(&self, value: &[T]) -> Result<()> {
543 for rule in &self.rules {
544 match rule {
545 ValidationRule::MinLength(min_len) if value.len() < *min_len => {
546 return Err(SklearsError::InvalidParameter {
547 name: self.field_name.clone(),
548 reason: format!("must have at least {min_len} elements"),
549 });
550 }
551 ValidationRule::MinLength(_) => {}
552 ValidationRule::MaxLength(max_len) if value.len() > *max_len => {
553 return Err(SklearsError::InvalidParameter {
554 name: self.field_name.clone(),
555 reason: format!("must have at most {max_len} elements"),
556 });
557 }
558 ValidationRule::MaxLength(_) => {}
559 ValidationRule::PatternGuard(pattern_guard) => {
560 let len: usize = value.len();
562 let value_any: &dyn std::any::Any = &len;
563 let passes = (pattern_guard.guard_fn)(value_any)?;
564 if !passes {
565 return Err(SklearsError::InvalidParameter {
566 name: self.field_name.clone(),
567 reason: pattern_guard.error_message.clone(),
568 });
569 }
570 }
571 _ => {
572 }
574 }
575 }
576 Ok(())
577 }
578
579 pub fn validate_usize(&self, value: &usize) -> Result<()> {
581 for rule in &self.rules {
582 match rule {
583 ValidationRule::Positive if *value == 0 => {
584 return Err(SklearsError::InvalidParameter {
585 name: self.field_name.clone(),
586 reason: "must be positive".to_string(),
587 });
588 }
589 ValidationRule::Positive => {}
590 ValidationRule::NonNegative => {
591 }
593 ValidationRule::Range { min, max } => {
594 let val = *value as f64;
595 if val < *min || val > *max {
596 return Err(SklearsError::InvalidParameter {
597 name: self.field_name.clone(),
598 reason: format!("must be in range [{min}, {max}]"),
599 });
600 }
601 }
602 _ => {
603 }
605 }
606 }
607 Ok(())
608 }
609}
610
611pub mod ml {
613 use super::*;
614
615 pub fn validate_learning_rate<T: FloatBounds>(lr: T) -> Result<()> {
617 if lr <= T::zero() {
618 return Err(SklearsError::InvalidParameter {
619 name: "learning_rate".to_string(),
620 reason: "must be positive".to_string(),
621 });
622 }
623
624 if !Float::is_finite(lr) {
625 return Err(SklearsError::InvalidParameter {
626 name: "learning_rate".to_string(),
627 reason: "must be finite".to_string(),
628 });
629 }
630
631 if lr > T::one() {
633 log::warn!("Learning rate {lr} is unusually high, consider using a smaller value");
634 }
635
636 Ok(())
637 }
638
639 pub fn validate_regularization<T: FloatBounds>(reg: T) -> Result<()> {
641 if reg < T::zero() {
642 return Err(SklearsError::InvalidParameter {
643 name: "regularization".to_string(),
644 reason: "must be non-negative".to_string(),
645 });
646 }
647
648 if !Float::is_finite(reg) {
649 return Err(SklearsError::InvalidParameter {
650 name: "regularization".to_string(),
651 reason: "must be finite".to_string(),
652 });
653 }
654
655 Ok(())
656 }
657
658 pub fn validate_n_clusters(n_clusters: usize, n_samples: usize) -> Result<()> {
660 if n_clusters == 0 {
661 return Err(SklearsError::InvalidParameter {
662 name: "n_clusters".to_string(),
663 reason: "must be positive".to_string(),
664 });
665 }
666
667 if n_clusters > n_samples {
668 return Err(SklearsError::InvalidParameter {
669 name: "n_clusters".to_string(),
670 reason: format!("cannot exceed number of samples ({n_samples})"),
671 });
672 }
673
674 Ok(())
675 }
676
677 pub fn validate_n_neighbors(n_neighbors: usize, n_samples: usize) -> Result<()> {
679 if n_neighbors == 0 {
680 return Err(SklearsError::InvalidParameter {
681 name: "n_neighbors".to_string(),
682 reason: "must be positive".to_string(),
683 });
684 }
685
686 if n_neighbors > n_samples {
687 return Err(SklearsError::InvalidParameter {
688 name: "n_neighbors".to_string(),
689 reason: format!("cannot exceed number of samples ({n_samples})"),
690 });
691 }
692
693 Ok(())
694 }
695
696 pub fn validate_tolerance<T: FloatBounds>(tol: T) -> Result<()> {
698 if tol <= T::zero() {
699 return Err(SklearsError::InvalidParameter {
700 name: "tolerance".to_string(),
701 reason: "must be positive".to_string(),
702 });
703 }
704
705 if !Float::is_finite(tol) {
706 return Err(SklearsError::InvalidParameter {
707 name: "tolerance".to_string(),
708 reason: "must be finite".to_string(),
709 });
710 }
711
712 if tol > T::from(0.1).unwrap_or(T::one()) {
714 log::warn!("Tolerance {tol} is very large, algorithm may converge prematurely");
715 }
716
717 Ok(())
718 }
719
720 pub fn validate_max_iter(max_iter: usize) -> Result<()> {
722 if max_iter == 0 {
723 return Err(SklearsError::InvalidParameter {
724 name: "max_iter".to_string(),
725 reason: "must be positive".to_string(),
726 });
727 }
728
729 Ok(())
730 }
731
732 pub fn validate_probability<T: FloatBounds>(prob: T) -> Result<()> {
734 if prob < T::zero() || prob > T::one() {
735 return Err(SklearsError::InvalidParameter {
736 name: "probability".to_string(),
737 reason: "must be in range [0, 1]".to_string(),
738 });
739 }
740
741 if !Float::is_finite(prob) {
742 return Err(SklearsError::InvalidParameter {
743 name: "probability".to_string(),
744 reason: "must be finite".to_string(),
745 });
746 }
747
748 Ok(())
749 }
750
751 pub fn validate_supervised_data<T, U>(x: &Array2<T>, y: &Array1<U>) -> Result<()> {
753 if x.is_empty() {
754 return Err(SklearsError::InvalidData {
755 reason: "X cannot be empty".to_string(),
756 });
757 }
758
759 if y.is_empty() {
760 return Err(SklearsError::InvalidData {
761 reason: "y cannot be empty".to_string(),
762 });
763 }
764
765 if x.nrows() != y.len() {
766 return Err(SklearsError::ShapeMismatch {
767 expected: "X.shape[0] == y.shape[0]".to_string(),
768 actual: format!("X.shape[0]={}, y.shape[0]={}", x.nrows(), y.len()),
769 });
770 }
771
772 Ok(())
773 }
774
775 pub fn validate_unsupervised_data<T>(x: &Array2<T>) -> Result<()> {
777 if x.is_empty() {
778 return Err(SklearsError::InvalidData {
779 reason: "X cannot be empty".to_string(),
780 });
781 }
782
783 if x.nrows() == 0 || x.ncols() == 0 {
784 return Err(SklearsError::InvalidData {
785 reason: "X must have positive dimensions".to_string(),
786 });
787 }
788
789 Ok(())
790 }
791
792 pub fn validate_feature_consistency<T, U>(
794 x_train: &Array2<T>,
795 x_test: &Array2<U>,
796 _model_name: &str,
797 ) -> Result<()> {
798 if x_train.ncols() != x_test.ncols() {
799 return Err(SklearsError::FeatureMismatch {
800 expected: x_train.ncols(),
801 actual: x_test.ncols(),
802 });
803 }
804
805 Ok(())
806 }
807}
808
809pub mod derive_helpers {
811 pub fn generate_field_validation(
813 field_name: &str,
814 _field_type: &str,
815 validation_attrs: &[String],
816 ) -> String {
817 let mut validations = Vec::new();
818
819 for attr in validation_attrs {
820 match attr.as_str() {
821 "positive" => {
822 validations.push(format!(
823 "ValidationRules::new(\"{field_name}\").add_rule(ValidationRule::Positive).validate_numeric(&self.{field_name})?;"
824 ));
825 }
826 "non_negative" => {
827 validations.push(format!(
828 "ValidationRules::new(\"{field_name}\").add_rule(ValidationRule::NonNegative).validate_numeric(&self.{field_name})?;"
829 ));
830 }
831 "finite" => {
832 validations.push(format!(
833 "ValidationRules::new(\"{field_name}\").add_rule(ValidationRule::Finite).validate_numeric(&self.{field_name})?;"
834 ));
835 }
836 _ if attr.starts_with("range(") => {
837 let range_str = attr
839 .strip_prefix("range(")
840 .expect("expected valid value")
841 .strip_suffix(")")
842 .expect("expected valid value");
843 let parts: Vec<&str> = range_str.split(',').map(|s| s.trim()).collect();
844 if parts.len() == 2 {
845 let min_val = parts[0];
846 let max_val = parts[1];
847 validations.push(format!(
848 "ValidationRules::new(\"{field_name}\").add_rule(ValidationRule::Range {{ min: {min_val}, max: {max_val} }}).validate_numeric(&self.{field_name})?;"
849 ));
850 }
851 }
852 _ => {}
853 }
854 }
855
856 validations.join("\n")
857 }
858}
859
860pub trait ConfigValidation {
862 fn validate_config(&self) -> Result<()>;
864
865 fn get_warnings(&self) -> Vec<String> {
867 Vec::new()
868 }
869}
870
871#[derive(Debug, Clone)]
873pub struct ValidationContext {
874 pub algorithm: String,
875 pub operation: String,
876 pub data_info: Option<DataInfo>,
877}
878
879#[derive(Debug, Clone)]
881pub struct DataInfo {
882 pub n_samples: usize,
883 pub n_features: usize,
884 pub data_type: String,
885}
886
887impl ValidationContext {
888 pub fn new(algorithm: &str, operation: &str) -> Self {
890 Self {
891 algorithm: algorithm.to_string(),
892 operation: operation.to_string(),
893 data_info: None,
894 }
895 }
896
897 pub fn with_data_info(mut self, n_samples: usize, n_features: usize, data_type: &str) -> Self {
899 self.data_info = Some(DataInfo {
900 n_samples,
901 n_features,
902 data_type: data_type.to_string(),
903 });
904 self
905 }
906
907 pub fn format_error(&self, error: &SklearsError) -> String {
909 let mut msg = format!(
910 "Error in {} during {}: {error}",
911 self.algorithm, self.operation
912 );
913
914 if let Some(data_info) = &self.data_info {
915 msg.push_str(&format!(
916 " (data: {} samples, {} features, type: {})",
917 data_info.n_samples, data_info.n_features, data_info.data_type
918 ));
919 }
920
921 msg
922 }
923}
924
925pub mod structured_destructuring {
927 use super::*;
928
929 pub trait StructuredDestructure {
931 fn destructure_into_components(
933 &self,
934 ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>>;
935
936 fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>>;
938
939 fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult>;
941 }
942
943 #[derive(Debug, Clone, Default)]
945 pub struct StructuralSchema {
946 pub required_fields: Vec<String>,
947 pub optional_fields: Vec<String>,
948 pub field_types: std::collections::HashMap<String, String>,
949 pub nested_schemas: std::collections::HashMap<String, StructuralSchema>,
950 }
951
952 impl StructuralSchema {
953 pub fn new() -> Self {
954 Self::default()
955 }
956
957 pub fn require_field(mut self, field_name: &str, field_type: &str) -> Self {
958 self.required_fields.push(field_name.to_string());
959 self.field_types
960 .insert(field_name.to_string(), field_type.to_string());
961 self
962 }
963
964 pub fn optional_field(mut self, field_name: &str, field_type: &str) -> Self {
965 self.optional_fields.push(field_name.to_string());
966 self.field_types
967 .insert(field_name.to_string(), field_type.to_string());
968 self
969 }
970
971 pub fn nested_schema(mut self, field_name: &str, schema: StructuralSchema) -> Self {
972 self.nested_schemas.insert(field_name.to_string(), schema);
973 self
974 }
975 }
976
977 #[derive(Debug, Clone)]
979 pub struct AlgorithmConfig {
980 pub algorithm_name: String,
981 pub hyperparameters: std::collections::HashMap<String, ConfigValue>,
982 pub metadata: std::collections::HashMap<String, String>,
983 }
984
985 #[derive(Debug, Clone)]
987 pub enum ConfigValue {
988 Float(f64),
989 Integer(i64),
990 String(String),
991 Boolean(bool),
992 Array(Vec<ConfigValue>),
993 Object(std::collections::HashMap<String, ConfigValue>),
994 }
995
996 impl StructuredDestructure for AlgorithmConfig {
997 fn destructure_into_components(
998 &self,
999 ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>> {
1000 let mut components = std::collections::HashMap::new();
1001
1002 components.insert(
1003 "algorithm_name".to_string(),
1004 Box::new(self.algorithm_name.clone()) as Box<dyn std::any::Any>,
1005 );
1006 components.insert(
1007 "hyperparameters".to_string(),
1008 Box::new(self.hyperparameters.clone()) as Box<dyn std::any::Any>,
1009 );
1010 components.insert(
1011 "metadata".to_string(),
1012 Box::new(self.metadata.clone()) as Box<dyn std::any::Any>,
1013 );
1014
1015 Ok(components)
1016 }
1017
1018 fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>> {
1019 let parts: Vec<&str> = field_path.split('.').collect();
1020
1021 match parts.first() {
1022 Some(&"algorithm_name") => Ok(Box::new(self.algorithm_name.clone())),
1023 Some(&"hyperparameters") => {
1024 if parts.len() > 1 {
1025 if let Some(param_value) = self.hyperparameters.get(parts[1]) {
1026 Ok(Box::new(param_value.clone()))
1027 } else {
1028 Err(SklearsError::InvalidParameter {
1029 name: field_path.to_string(),
1030 reason: format!("Hyperparameter '{}' not found", parts[1]),
1031 })
1032 }
1033 } else {
1034 Ok(Box::new(self.hyperparameters.clone()))
1035 }
1036 }
1037 Some(&"metadata") => {
1038 if parts.len() > 1 {
1039 if let Some(meta_value) = self.metadata.get(parts[1]) {
1040 Ok(Box::new(meta_value.clone()))
1041 } else {
1042 Err(SklearsError::InvalidParameter {
1043 name: field_path.to_string(),
1044 reason: format!("Metadata '{}' not found", parts[1]),
1045 })
1046 }
1047 } else {
1048 Ok(Box::new(self.metadata.clone()))
1049 }
1050 }
1051 _ => Err(SklearsError::InvalidParameter {
1052 name: field_path.to_string(),
1053 reason: "Invalid field path".to_string(),
1054 }),
1055 }
1056 }
1057
1058 fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult> {
1059 let mut warnings = Vec::new();
1060 let mut context = std::collections::HashMap::new();
1061
1062 for required_field in &schema.required_fields {
1064 match required_field.as_str() {
1065 "algorithm_name" => {
1066 if self.algorithm_name.is_empty() {
1067 return Err(SklearsError::InvalidParameter {
1068 name: "algorithm_name".to_string(),
1069 reason: "Required field cannot be empty".to_string(),
1070 });
1071 }
1072 context.insert("algorithm_name".to_string(), "present".to_string());
1073 }
1074 "hyperparameters" => {
1075 context.insert(
1076 "hyperparameters_count".to_string(),
1077 self.hyperparameters.len().to_string(),
1078 );
1079 }
1080 _ => {
1081 warnings.push(format!("Unknown required field: {required_field}"));
1082 }
1083 }
1084 }
1085
1086 Ok(ValidationResult {
1087 matched: true,
1088 context,
1089 warnings,
1090 })
1091 }
1092 }
1093
1094 pub fn create_complex_pattern_guard<T>(
1096 pattern_name: &str,
1097 validator: impl Fn(&T) -> Result<bool> + Send + Sync + 'static,
1098 error_message: &str,
1099 ) -> PatternGuardRule
1100 where
1101 T: 'static,
1102 {
1103 PatternGuardRule {
1104 pattern_name: pattern_name.to_string(),
1105 guard_fn: Box::new(move |value| {
1106 if let Some(typed_value) = value.downcast_ref::<T>() {
1107 validator(typed_value)
1108 } else {
1109 Ok(false)
1110 }
1111 }),
1112 error_message: error_message.to_string(),
1113 destructure_fn: None,
1114 }
1115 }
1116}
1117
1118#[macro_export]
1120macro_rules! destructure {
1121 ($obj:expr, $field:literal) => {
1123 $obj.extract_field($field)
1124 };
1125
1126 ($obj:expr, { $($field:literal),* }) => {
1128 {
1129 let mut results = std::collections::HashMap::new();
1130 $(
1131 if let Ok(value) = $obj.extract_field($field) {
1132 results.insert($field.to_string(), value);
1133 }
1134 )*
1135 results
1136 }
1137 };
1138
1139 ($obj:expr, validate: $schema:expr) => {
1141 $obj.validate_structure(&$schema)
1142 };
1143}
1144
1145#[allow(non_snake_case)]
1146#[cfg(test)]
1147mod tests {
1148 use super::*;
1149
1150 #[test]
1151 fn test_validation_rules_numeric() {
1152 let rules = ValidationRules::new("test_param")
1153 .add_rule(ValidationRule::Positive)
1154 .add_rule(ValidationRule::Finite);
1155
1156 assert!(rules.validate_numeric(&1.5f64).is_ok());
1158
1159 assert!(rules.validate_numeric(&0.0f64).is_err());
1161 assert!(rules.validate_numeric(&-1.0f64).is_err());
1162
1163 assert!(rules.validate_numeric(&f64::NAN).is_err());
1165 assert!(rules.validate_numeric(&f64::INFINITY).is_err());
1166 }
1167
1168 #[test]
1169 fn test_validation_rules_range() {
1170 let rules = ValidationRules::new("test_param")
1171 .add_rule(ValidationRule::Range { min: 0.0, max: 1.0 });
1172
1173 assert!(rules.validate_numeric(&0.5f64).is_ok());
1175 assert!(rules.validate_numeric(&0.0f64).is_ok());
1176 assert!(rules.validate_numeric(&1.0f64).is_ok());
1177
1178 assert!(rules.validate_numeric(&-0.1f64).is_err());
1180 assert!(rules.validate_numeric(&1.1f64).is_err());
1181 }
1182
1183 #[test]
1184 fn test_validation_rules_string() {
1185 let rules = ValidationRules::new("test_param").add_rule(ValidationRule::OneOf(vec![
1186 "option1".to_string(),
1187 "option2".to_string(),
1188 ]));
1189
1190 assert!(rules.validate_string("option1").is_ok());
1192 assert!(rules.validate_string("option2").is_ok());
1193
1194 assert!(rules.validate_string("option3").is_err());
1196 }
1197
1198 #[test]
1199 fn test_validation_rules_array() {
1200 let rules = ValidationRules::new("test_param")
1201 .add_rule(ValidationRule::MinLength(2))
1202 .add_rule(ValidationRule::MaxLength(5));
1203
1204 assert!(rules.validate_array(&[1, 2]).is_ok());
1206 assert!(rules.validate_array(&[1, 2, 3, 4, 5]).is_ok());
1207
1208 assert!(rules.validate_array(&[1]).is_err());
1210
1211 assert!(rules.validate_array(&[1, 2, 3, 4, 5, 6]).is_err());
1213 }
1214
1215 #[test]
1216 fn test_ml_validation_learning_rate() {
1217 assert!(ml::validate_learning_rate(0.01f64).is_ok());
1219 assert!(ml::validate_learning_rate(0.1f64).is_ok());
1220
1221 assert!(ml::validate_learning_rate(0.0f64).is_err());
1223 assert!(ml::validate_learning_rate(-0.1f64).is_err());
1224
1225 assert!(ml::validate_learning_rate(f64::NAN).is_err());
1227 }
1228
1229 #[test]
1230 fn test_ml_validation_n_clusters() {
1231 assert!(ml::validate_n_clusters(3, 10).is_ok());
1233 assert!(ml::validate_n_clusters(10, 10).is_ok());
1234
1235 assert!(ml::validate_n_clusters(0, 10).is_err());
1237
1238 assert!(ml::validate_n_clusters(15, 10).is_err());
1240 }
1241
1242 #[test]
1243 fn test_ml_validation_probability() {
1244 assert!(ml::validate_probability(0.0f64).is_ok());
1246 assert!(ml::validate_probability(0.5f64).is_ok());
1247 assert!(ml::validate_probability(1.0f64).is_ok());
1248
1249 assert!(ml::validate_probability(-0.1f64).is_err());
1251 assert!(ml::validate_probability(1.1f64).is_err());
1252
1253 assert!(ml::validate_probability(f64::NAN).is_err());
1255 }
1256
1257 #[test]
1258 fn test_validation_context() {
1259 let context = ValidationContext::new("KMeans", "fit").with_data_info(100, 5, "float64");
1260
1261 let error = SklearsError::InvalidParameter {
1262 name: "n_clusters".to_string(),
1263 reason: "must be positive".to_string(),
1264 };
1265
1266 let formatted = context.format_error(&error);
1267 assert!(formatted.contains("KMeans"));
1268 assert!(formatted.contains("fit"));
1269 assert!(formatted.contains("100 samples"));
1270 assert!(formatted.contains("5 features"));
1271 }
1272
1273 #[test]
1278 fn test_pattern_guard_numeric_passes() {
1279 let guard = PatternGuardRule {
1281 pattern_name: "even_number".to_string(),
1282 guard_fn: Box::new(|value: &dyn std::any::Any| {
1283 if let Some(v) = value.downcast_ref::<f64>() {
1284 Ok(*v as i64 % 2 == 0)
1285 } else {
1286 Ok(false)
1287 }
1288 }),
1289 error_message: "must be an even number".to_string(),
1290 destructure_fn: None,
1291 };
1292
1293 let rules =
1294 ValidationRules::new("even_param").add_rule(ValidationRule::PatternGuard(guard));
1295
1296 assert!(rules.validate_numeric(&4.0f64).is_ok());
1298 assert!(rules.validate_numeric(&3.0f64).is_err());
1300 }
1301
1302 #[test]
1303 fn test_pattern_guard_string_passes() {
1304 let guard = PatternGuardRule {
1306 pattern_name: "no_leading_digit".to_string(),
1307 guard_fn: Box::new(|value: &dyn std::any::Any| {
1308 if let Some(s) = value.downcast_ref::<String>() {
1309 Ok(!s.starts_with(char::is_numeric))
1310 } else {
1311 Ok(false)
1312 }
1313 }),
1314 error_message: "must not start with a digit".to_string(),
1315 destructure_fn: None,
1316 };
1317
1318 let rules =
1319 ValidationRules::new("identifier").add_rule(ValidationRule::PatternGuard(guard));
1320
1321 assert!(rules.validate_string("alpha_param").is_ok());
1323 assert!(rules.validate_string("1_bad").is_err());
1325 }
1326
1327 #[test]
1328 fn test_pattern_guard_array_length() {
1329 let guard = PatternGuardRule {
1331 pattern_name: "odd_length".to_string(),
1332 guard_fn: Box::new(|value: &dyn std::any::Any| {
1333 if let Some(len) = value.downcast_ref::<usize>() {
1335 Ok(len % 2 == 1)
1336 } else {
1337 Ok(false)
1338 }
1339 }),
1340 error_message: "array must have an odd number of elements".to_string(),
1341 destructure_fn: None,
1342 };
1343
1344 let rules = ValidationRules::new("odd_array").add_rule(ValidationRule::PatternGuard(guard));
1345
1346 assert!(rules.validate_array(&[1, 2, 3]).is_ok());
1348 assert!(rules.validate_array(&[1, 2, 3, 4]).is_err());
1350 }
1351
1352 #[test]
1353 fn test_pattern_guard_error_message_propagated() {
1354 let expected_reason = "value must be the answer to everything";
1355 let guard = PatternGuardRule {
1356 pattern_name: "answer".to_string(),
1357 guard_fn: Box::new(|value: &dyn std::any::Any| {
1358 if let Some(v) = value.downcast_ref::<f64>() {
1359 Ok((*v - 42.0).abs() < f64::EPSILON)
1360 } else {
1361 Ok(false)
1362 }
1363 }),
1364 error_message: expected_reason.to_string(),
1365 destructure_fn: None,
1366 };
1367
1368 let rules =
1369 ValidationRules::new("cosmic_number").add_rule(ValidationRule::PatternGuard(guard));
1370
1371 assert!(rules.validate_numeric(&42.0f64).is_ok());
1373
1374 let err = rules.validate_numeric(&7.0f64).expect_err("7 is not 42");
1376 assert!(err.to_string().contains(expected_reason));
1377 }
1378}