Skip to main content

sklears_core/
validation.rs

//! Comprehensive validation framework for machine learning parameters and data.
//!
//! This module provides a robust validation system with custom derive macros
//! for automatic parameter validation in ML algorithms.
5use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds, Numeric};
7use scirs2_core::numeric::{Float, NumCast};
8use std::fmt::Debug;
9
10/// Type alias for validation guard function
11pub type ValidationGuardFn = Box<dyn Fn(&dyn std::any::Any) -> Result<bool> + Send + Sync>;
12
13/// Type alias for validation destructuring function
14pub type ValidationDestructureFn =
15    Box<dyn Fn(&dyn std::any::Any) -> Result<ValidationResult> + Send + Sync>;
16
17/// Core validation trait that can be derived for automatic parameter validation
18pub trait Validate {
19    /// Validate all parameters and return an error if any validation fails
20    fn validate(&self) -> Result<()>;
21
22    /// Validate and provide detailed error information  
23    fn validate_with_context(&self, context: &str) -> Result<()> {
24        self.validate()
25            .map_err(|e| SklearsError::Other(format!("{context}: {e}")))
26    }
27}
28
29/// Validation attributes for ML parameter constraints
30#[derive(Debug, Clone)]
31pub enum ValidationRule {
32    /// Value must be positive (> 0)
33    Positive,
34    /// Value must be non-negative (>= 0)
35    NonNegative,
36    /// Value must be finite (not NaN or infinity)
37    Finite,
38    /// Value must be in a specific range [min, max]
39    Range { min: f64, max: f64 },
40    /// Value must be one of the specified options
41    OneOf(Vec<String>),
42    /// Array must have minimum number of elements
43    MinLength(usize),
44    /// Array must have maximum number of elements
45    MaxLength(usize),
46    /// Array elements must be unique
47    UniqueElements,
48    /// Custom validation function
49    Custom(fn(&dyn std::any::Any) -> Result<()>),
50    /// Pattern guard validation with custom matching
51    PatternGuard(PatternGuardRule),
52}
53
54/// Pattern guard rule for advanced validation with custom matching
55pub struct PatternGuardRule {
56    /// Name of the pattern for error reporting
57    pub pattern_name: String,
58    /// Function that performs the pattern matching and validation
59    pub guard_fn: ValidationGuardFn,
60    /// Error message when the pattern doesn't match
61    pub error_message: String,
62    /// Optional structured destructuring validator
63    pub destructure_fn: Option<ValidationDestructureFn>,
64}
65
66impl std::fmt::Debug for PatternGuardRule {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("PatternGuardRule")
69            .field("pattern_name", &self.pattern_name)
70            .field("guard_fn", &"<function>")
71            .field("error_message", &self.error_message)
72            .field(
73                "destructure_fn",
74                &self.destructure_fn.as_ref().map(|_| "<function>"),
75            )
76            .finish()
77    }
78}
79
/// Lossy `Clone` for [`PatternGuardRule`].
///
/// Boxed closures cannot be cloned, so only the textual parts of the rule
/// survive a clone.
impl Clone for PatternGuardRule {
    /// Produces a *degraded* copy of the rule.
    ///
    /// `pattern_name` and `error_message` are copied. The copy's `guard_fn` is
    /// replaced by a closure that always returns `Ok(true)`, and its
    /// `destructure_fn` is `None`.
    ///
    /// NOTE(review): `ValidationRule` and `ValidationRules` both derive
    /// `Clone`. Cloning either therefore silently turns every `PatternGuard`
    /// rule into a no-op that accepts any value. Consider storing the closures
    /// behind `Arc` so clones keep their validation logic.
    fn clone(&self) -> Self {
        // Create a simple clone that preserves the pattern name and error message
        // but uses a default guard function
        Self {
            pattern_name: self.pattern_name.clone(),
            guard_fn: Box::new(|_| Ok(true)), // Default to always pass
            error_message: self.error_message.clone(),
            destructure_fn: None,
        }
    }
}
92
/// Outcome of checking a value against a [`PatternGuardRule`].
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the value satisfied the pattern.
    pub matched: bool,
    /// Free-form key/value details gathered while matching (for example the
    /// value and its type name).
    pub context: std::collections::HashMap<String, String>,
    /// Non-fatal issues noticed during validation.
    pub warnings: Vec<String>,
}
103
/// Builds a `PatternGuardRule` from a built-in pattern or a custom closure.
///
/// Supported forms:
///
/// * `pattern_guard!(numeric_range, min, max)`: accepts `f64`, `f32`, `i32`
///   and `usize` values in the inclusive range `[min, max]`.
///   - `min` and `max` must be `f64` expressions; each is evaluated exactly once.
///   - Values of any other type are rejected.
/// * `pattern_guard!(array_shape, shape)`: placeholder that currently accepts
///   every value; `shape` is only used in the error message.
/// * `pattern_guard!(string_enum, options)`: accepts `String` / `&'static str`
///   values contained in `options`.
///   - `options` is evaluated once, so arrays and `Vec`s both work.
///   - It must be `'static + Send + Sync`, since it is captured by the guard.
/// * `pattern_guard!("name", guard, "error message")`: custom guard closure
///   with a custom error message.
/// * `pattern_guard!("name", guard, destructure)`: custom guard plus a
///   destructuring validator. The third argument must not be a literal,
///   otherwise the previous form is selected.
#[macro_export]
macro_rules! pattern_guard {
    // Pattern guard for numeric types with range validation
    (numeric_range, $min:expr, $max:expr) => {{
        // Evaluate each bound exactly once instead of on every check.
        let min: f64 = $min;
        let max: f64 = $max;
        $crate::validation::PatternGuardRule {
            pattern_name: "numeric_range".to_string(),
            guard_fn: Box::new(move |value| {
                if let Some(val) = value.downcast_ref::<f64>() {
                    Ok(*val >= min && *val <= max)
                } else if let Some(val) = value.downcast_ref::<f32>() {
                    // Compare in f32 space so a bound such as 0.1 is not missed
                    // because of f32 -> f64 widening round-off.
                    Ok(*val >= min as f32 && *val <= max as f32)
                } else if let Some(val) = value.downcast_ref::<i32>() {
                    // Widen the integer rather than truncating the float bounds:
                    // `0.5 as i32 == 0` would wrongly admit 0 into [0.5, ..].
                    let val = f64::from(*val);
                    Ok(val >= min && val <= max)
                } else if let Some(val) = value.downcast_ref::<usize>() {
                    let val = *val as f64;
                    Ok(val >= min && val <= max)
                } else {
                    Ok(false)
                }
            }),
            error_message: format!("Value must be in range [{}, {}]", min, max),
            destructure_fn: None,
        }
    }};

    // Pattern guard for array shape validation
    (array_shape, $expected_shape:expr) => {
        $crate::validation::PatternGuardRule {
            pattern_name: "array_shape".to_string(),
            // NOTE(review): placeholder; shape checking is not implemented, so
            // every value is accepted.
            guard_fn: Box::new(|_value| Ok(true)),
            error_message: format!("Array shape must match {:?}", $expected_shape),
            destructure_fn: None,
        }
    };

    // Pattern guard for string enum validation
    (string_enum, $valid_options:expr) => {{
        // Bind the options once and build the message *before* moving them
        // into the closure. Formatting after the move does not compile for
        // non-`Copy` collections such as `Vec`.
        let options = $valid_options;
        let error_message = format!("Value must be one of {:?}", options);
        $crate::validation::PatternGuardRule {
            pattern_name: "string_enum".to_string(),
            guard_fn: Box::new(move |value| {
                if let Some(val) = value.downcast_ref::<String>() {
                    Ok(options.contains(&val.as_str()))
                } else if let Some(val) = value.downcast_ref::<&str>() {
                    Ok(options.contains(val))
                } else {
                    Ok(false)
                }
            }),
            error_message,
            destructure_fn: None,
        }
    }};

    // Pattern guard with custom function and error message
    ($pattern_name:literal, $guard:expr, $error_msg:literal) => {
        $crate::validation::PatternGuardRule {
            pattern_name: $pattern_name.to_string(),
            guard_fn: Box::new($guard),
            error_message: $error_msg.to_string(),
            destructure_fn: None,
        }
    };

    // Pattern guard with destructuring validation (third argument is a
    // non-literal expression, so the arm above does not match it)
    ($pattern_name:literal, $guard_fn:expr, $destructure_fn:expr) => {
        $crate::validation::PatternGuardRule {
            pattern_name: $pattern_name.to_string(),
            guard_fn: Box::new($guard_fn),
            error_message: format!("Pattern '{}' validation failed", $pattern_name),
            destructure_fn: Some(Box::new($destructure_fn)),
        }
    };
}
181
182/// Trait for types that can be pattern matched and validated
183pub trait PatternValidate {
184    /// Apply pattern guard validation
185    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult>;
186
187    /// Check if value matches a specific pattern
188    fn matches_pattern(&self, pattern_name: &str) -> bool;
189
190    /// Extract structured data using pattern destructuring
191    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>>;
192}
193
194/// Implementation of PatternValidate for f64
195impl PatternValidate for f64 {
196    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
197        let value_any = self as &dyn std::any::Any;
198        let matched = (guard.guard_fn)(value_any)?;
199
200        let mut context = std::collections::HashMap::new();
201        context.insert("value".to_string(), self.to_string());
202        context.insert("type".to_string(), "f64".to_string());
203
204        if let Some(destructure_fn) = &guard.destructure_fn {
205            let destructure_result = destructure_fn(value_any)?;
206            Ok(ValidationResult {
207                matched: matched && destructure_result.matched,
208                context: destructure_result.context,
209                warnings: destructure_result.warnings,
210            })
211        } else {
212            Ok(ValidationResult {
213                matched,
214                context,
215                warnings: Vec::new(),
216            })
217        }
218    }
219
220    fn matches_pattern(&self, pattern_name: &str) -> bool {
221        match pattern_name {
222            "finite" => self.is_finite(),
223            "positive" => *self > 0.0,
224            "non_negative" => *self >= 0.0,
225            "probability" => *self >= 0.0 && *self <= 1.0,
226            _ => false,
227        }
228    }
229
230    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
231        let mut result = std::collections::HashMap::new();
232        match pattern {
233            "range_info" => {
234                result.insert("value".to_string(), self.to_string());
235                result.insert("is_finite".to_string(), self.is_finite().to_string());
236                result.insert("is_positive".to_string(), (*self > 0.0).to_string());
237                result.insert(
238                    "sign".to_string(),
239                    if *self >= 0.0 {
240                        "positive".to_string()
241                    } else {
242                        "negative".to_string()
243                    },
244                );
245            }
246            _ => {
247                result.insert("value".to_string(), self.to_string());
248            }
249        }
250        Ok(result)
251    }
252}
253
254/// Implementation of PatternValidate for usize
255impl PatternValidate for usize {
256    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
257        let value_any = self as &dyn std::any::Any;
258        let matched = (guard.guard_fn)(value_any)?;
259
260        let mut context = std::collections::HashMap::new();
261        context.insert("value".to_string(), self.to_string());
262        context.insert("type".to_string(), "usize".to_string());
263
264        Ok(ValidationResult {
265            matched,
266            context,
267            warnings: Vec::new(),
268        })
269    }
270
271    fn matches_pattern(&self, pattern_name: &str) -> bool {
272        match pattern_name {
273            "positive" => *self > 0,
274            "non_negative" => true, // usize is always non-negative
275            "power_of_two" => self.is_power_of_two(),
276            _ => false,
277        }
278    }
279
280    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
281        let mut result = std::collections::HashMap::new();
282        match pattern {
283            "number_info" => {
284                result.insert("value".to_string(), self.to_string());
285                result.insert("is_positive".to_string(), (*self > 0).to_string());
286                result.insert(
287                    "is_power_of_two".to_string(),
288                    self.is_power_of_two().to_string(),
289                );
290            }
291            _ => {
292                result.insert("value".to_string(), self.to_string());
293            }
294        }
295        Ok(result)
296    }
297}
298
299/// Implementation of PatternValidate for String
300impl PatternValidate for String {
301    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
302        let value_any = self as &dyn std::any::Any;
303        let matched = (guard.guard_fn)(value_any)?;
304
305        let mut context = std::collections::HashMap::new();
306        context.insert("value".to_string(), self.clone());
307        context.insert("type".to_string(), "String".to_string());
308        context.insert("length".to_string(), self.len().to_string());
309
310        Ok(ValidationResult {
311            matched,
312            context,
313            warnings: Vec::new(),
314        })
315    }
316
317    fn matches_pattern(&self, pattern_name: &str) -> bool {
318        match pattern_name {
319            "non_empty" => !self.is_empty(),
320            "alphanumeric" => self.chars().all(|c| c.is_alphanumeric()),
321            "lowercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_lowercase()),
322            "uppercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_uppercase()),
323            _ => false,
324        }
325    }
326
327    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
328        let mut result = std::collections::HashMap::new();
329        match pattern {
330            "string_info" => {
331                result.insert("value".to_string(), self.clone());
332                result.insert("length".to_string(), self.len().to_string());
333                result.insert("is_empty".to_string(), self.is_empty().to_string());
334                result.insert(
335                    "is_alphanumeric".to_string(),
336                    self.chars().all(|c| c.is_alphanumeric()).to_string(),
337                );
338            }
339            _ => {
340                result.insert("value".to_string(), self.clone());
341            }
342        }
343        Ok(result)
344    }
345}
346
347/// Default implementations for common pattern validation scenarios
348pub mod pattern_guards {
349    use super::*;
350
351    /// Pattern guard for ML model hyperparameters
352    pub fn hyperparameter_pattern<T: FloatBounds + std::fmt::Debug>(
353        min_val: T,
354        max_val: T,
355        finite_required: bool,
356    ) -> PatternGuardRule {
357        PatternGuardRule {
358            pattern_name: "hyperparameter_bounds".to_string(),
359            guard_fn: Box::new(|_value| {
360                // In real implementation, would need proper type casting
361                Ok(true) // Placeholder
362            }),
363            error_message: format!(
364                "Hyperparameter must be in range [{}, {}]{}",
365                min_val,
366                max_val,
367                if finite_required { " and finite" } else { "" }
368            ),
369            destructure_fn: None,
370        }
371    }
372
373    /// Pattern guard for array shape validation
374    pub fn array_shape_pattern(expected_dims: &[usize]) -> PatternGuardRule {
375        let dims_str = expected_dims
376            .iter()
377            .map(|d| d.to_string())
378            .collect::<Vec<_>>()
379            .join(", ");
380
381        PatternGuardRule {
382            pattern_name: "array_shape".to_string(),
383            guard_fn: Box::new(|_value| {
384                // Would validate array shape in real implementation
385                Ok(true)
386            }),
387            error_message: format!("Array shape must match [{dims_str}]"),
388            destructure_fn: None, // Remove capturing closure for now
389        }
390    }
391
392    /// Pattern guard for ML algorithm configuration
393    pub fn algorithm_config_pattern(required_fields: &[&str]) -> PatternGuardRule {
394        let fields_str = required_fields.join(", ");
395
396        PatternGuardRule {
397            pattern_name: "algorithm_config".to_string(),
398            guard_fn: Box::new(|_value| {
399                // Would validate configuration completeness
400                Ok(true)
401            }),
402            error_message: format!("Configuration must contain fields: {fields_str}"),
403            destructure_fn: None, // Remove capturing closure for now
404        }
405    }
406
407    /// Pattern guard for data type consistency
408    pub fn data_type_pattern(expected_types: &[&str]) -> PatternGuardRule {
409        let types_str = expected_types.join(" | ");
410
411        PatternGuardRule {
412            pattern_name: "data_type_consistency".to_string(),
413            guard_fn: Box::new(|_value| {
414                // Would validate data type consistency
415                Ok(true)
416            }),
417            error_message: format!("Data type must be one of: {types_str}"),
418            destructure_fn: None,
419        }
420    }
421}
422
423/// Container for multiple validation rules
424#[derive(Debug, Clone)]
425pub struct ValidationRules {
426    pub rules: Vec<ValidationRule>,
427    pub field_name: String,
428}
429
430impl ValidationRules {
431    /// Create a new validation rules container
432    pub fn new(field_name: &str) -> Self {
433        Self {
434            rules: Vec::new(),
435            field_name: field_name.to_string(),
436        }
437    }
438
439    /// Add a validation rule
440    pub fn add_rule(mut self, rule: ValidationRule) -> Self {
441        self.rules.push(rule);
442        self
443    }
444
445    /// Validate a numeric value against all rules
446    pub fn validate_numeric<T>(&self, value: &T) -> Result<()>
447    where
448        T: Numeric + PartialOrd + Debug + Copy + NumCast,
449    {
450        for rule in &self.rules {
451            match rule {
452                ValidationRule::Positive if *value <= T::zero() => {
453                    return Err(SklearsError::InvalidParameter {
454                        name: self.field_name.clone(),
455                        reason: "must be positive".to_string(),
456                    });
457                }
458                ValidationRule::Positive => {}
459                ValidationRule::NonNegative if *value < T::zero() => {
460                    return Err(SklearsError::InvalidParameter {
461                        name: self.field_name.clone(),
462                        reason: "must be non-negative".to_string(),
463                    });
464                }
465                ValidationRule::NonNegative => {}
466                ValidationRule::Finite => {
467                    if let Some(float_val) = NumCast::from(*value) {
468                        let f: f64 = float_val;
469                        if !f.is_finite() {
470                            return Err(SklearsError::InvalidParameter {
471                                name: self.field_name.clone(),
472                                reason: "must be finite".to_string(),
473                            });
474                        }
475                    }
476                }
477                ValidationRule::Range { min, max } => {
478                    if let Some(float_val) = NumCast::from(*value) {
479                        let f: f64 = float_val;
480                        if f < *min || f > *max {
481                            return Err(SklearsError::InvalidParameter {
482                                name: self.field_name.clone(),
483                                reason: format!("must be in range [{min}, {max}]"),
484                            });
485                        }
486                    }
487                }
488                ValidationRule::PatternGuard(pattern_guard) => {
489                    // Cast the numeric value to Any so the guard function can inspect it.
490                    // The guard_fn is Box<dyn Fn(&dyn Any) -> Result<bool>> which takes
491                    // a shared reference; the local `value` binding is the concrete `T`,
492                    // so we take a reference to it and coerce through &dyn Any.
493                    let value_any: &dyn std::any::Any = value;
494                    let passes = (pattern_guard.guard_fn)(value_any)?;
495                    if !passes {
496                        return Err(SklearsError::InvalidParameter {
497                            name: self.field_name.clone(),
498                            reason: pattern_guard.error_message.clone(),
499                        });
500                    }
501                }
502                _ => {
503                    // Skip rules that don't apply to numeric values
504                }
505            }
506        }
507        Ok(())
508    }
509
510    /// Validate a string value against all rules
511    pub fn validate_string(&self, value: &str) -> Result<()> {
512        for rule in &self.rules {
513            match rule {
514                ValidationRule::OneOf(options) if !options.contains(&value.to_string()) => {
515                    return Err(SklearsError::InvalidParameter {
516                        name: self.field_name.clone(),
517                        reason: format!("must be one of {options:?}"),
518                    });
519                }
520                ValidationRule::OneOf(_) => {}
521                ValidationRule::PatternGuard(pattern_guard) => {
522                    // Convert to owned String so it is 'static and can be cast to &dyn Any.
523                    let owned: String = value.to_owned();
524                    let value_any: &dyn std::any::Any = &owned;
525                    let passes = (pattern_guard.guard_fn)(value_any)?;
526                    if !passes {
527                        return Err(SklearsError::InvalidParameter {
528                            name: self.field_name.clone(),
529                            reason: pattern_guard.error_message.clone(),
530                        });
531                    }
532                }
533                _ => {
534                    // Skip rules that don't apply to string values
535                }
536            }
537        }
538        Ok(())
539    }
540
541    /// Validate an array/vector against all rules
542    pub fn validate_array<T>(&self, value: &[T]) -> Result<()> {
543        for rule in &self.rules {
544            match rule {
545                ValidationRule::MinLength(min_len) if value.len() < *min_len => {
546                    return Err(SklearsError::InvalidParameter {
547                        name: self.field_name.clone(),
548                        reason: format!("must have at least {min_len} elements"),
549                    });
550                }
551                ValidationRule::MinLength(_) => {}
552                ValidationRule::MaxLength(max_len) if value.len() > *max_len => {
553                    return Err(SklearsError::InvalidParameter {
554                        name: self.field_name.clone(),
555                        reason: format!("must have at most {max_len} elements"),
556                    });
557                }
558                ValidationRule::MaxLength(_) => {}
559                ValidationRule::PatternGuard(pattern_guard) => {
560                    // Pass the length as a 'static usize so it can be cast to &dyn Any.
561                    let len: usize = value.len();
562                    let value_any: &dyn std::any::Any = &len;
563                    let passes = (pattern_guard.guard_fn)(value_any)?;
564                    if !passes {
565                        return Err(SklearsError::InvalidParameter {
566                            name: self.field_name.clone(),
567                            reason: pattern_guard.error_message.clone(),
568                        });
569                    }
570                }
571                _ => {
572                    // Skip rules that don't apply to arrays
573                }
574            }
575        }
576        Ok(())
577    }
578
579    /// Validate an unsigned integer value (usize) against all rules
580    pub fn validate_usize(&self, value: &usize) -> Result<()> {
581        for rule in &self.rules {
582            match rule {
583                ValidationRule::Positive if *value == 0 => {
584                    return Err(SklearsError::InvalidParameter {
585                        name: self.field_name.clone(),
586                        reason: "must be positive".to_string(),
587                    });
588                }
589                ValidationRule::Positive => {}
590                ValidationRule::NonNegative => {
591                    // usize is always non-negative, so this always passes
592                }
593                ValidationRule::Range { min, max } => {
594                    let val = *value as f64;
595                    if val < *min || val > *max {
596                        return Err(SklearsError::InvalidParameter {
597                            name: self.field_name.clone(),
598                            reason: format!("must be in range [{min}, {max}]"),
599                        });
600                    }
601                }
602                _ => {
603                    // Skip rules that don't apply to usize values
604                }
605            }
606        }
607        Ok(())
608    }
609}
610
611/// ML-specific validation functions
612pub mod ml {
613    use super::*;
614
615    /// Validate learning rate (must be positive and typically < 1.0)
616    pub fn validate_learning_rate<T: FloatBounds>(lr: T) -> Result<()> {
617        if lr <= T::zero() {
618            return Err(SklearsError::InvalidParameter {
619                name: "learning_rate".to_string(),
620                reason: "must be positive".to_string(),
621            });
622        }
623
624        if !Float::is_finite(lr) {
625            return Err(SklearsError::InvalidParameter {
626                name: "learning_rate".to_string(),
627                reason: "must be finite".to_string(),
628            });
629        }
630
631        // Warning for unusually high learning rates
632        if lr > T::one() {
633            log::warn!("Learning rate {lr} is unusually high, consider using a smaller value");
634        }
635
636        Ok(())
637    }
638
639    /// Validate regularization parameter (must be non-negative)
640    pub fn validate_regularization<T: FloatBounds>(reg: T) -> Result<()> {
641        if reg < T::zero() {
642            return Err(SklearsError::InvalidParameter {
643                name: "regularization".to_string(),
644                reason: "must be non-negative".to_string(),
645            });
646        }
647
648        if !Float::is_finite(reg) {
649            return Err(SklearsError::InvalidParameter {
650                name: "regularization".to_string(),
651                reason: "must be finite".to_string(),
652            });
653        }
654
655        Ok(())
656    }
657
658    /// Validate number of clusters (must be positive integer)
659    pub fn validate_n_clusters(n_clusters: usize, n_samples: usize) -> Result<()> {
660        if n_clusters == 0 {
661            return Err(SklearsError::InvalidParameter {
662                name: "n_clusters".to_string(),
663                reason: "must be positive".to_string(),
664            });
665        }
666
667        if n_clusters > n_samples {
668            return Err(SklearsError::InvalidParameter {
669                name: "n_clusters".to_string(),
670                reason: format!("cannot exceed number of samples ({n_samples})"),
671            });
672        }
673
674        Ok(())
675    }
676
677    /// Validate number of neighbors for KNN (must be positive and <= n_samples)
678    pub fn validate_n_neighbors(n_neighbors: usize, n_samples: usize) -> Result<()> {
679        if n_neighbors == 0 {
680            return Err(SklearsError::InvalidParameter {
681                name: "n_neighbors".to_string(),
682                reason: "must be positive".to_string(),
683            });
684        }
685
686        if n_neighbors > n_samples {
687            return Err(SklearsError::InvalidParameter {
688                name: "n_neighbors".to_string(),
689                reason: format!("cannot exceed number of samples ({n_samples})"),
690            });
691        }
692
693        Ok(())
694    }
695
696    /// Validate tolerance parameter (must be positive and small)
697    pub fn validate_tolerance<T: FloatBounds>(tol: T) -> Result<()> {
698        if tol <= T::zero() {
699            return Err(SklearsError::InvalidParameter {
700                name: "tolerance".to_string(),
701                reason: "must be positive".to_string(),
702            });
703        }
704
705        if !Float::is_finite(tol) {
706            return Err(SklearsError::InvalidParameter {
707                name: "tolerance".to_string(),
708                reason: "must be finite".to_string(),
709            });
710        }
711
712        // Warning for very large tolerances
713        if tol > T::from(0.1).unwrap_or(T::one()) {
714            log::warn!("Tolerance {tol} is very large, algorithm may converge prematurely");
715        }
716
717        Ok(())
718    }
719
720    /// Validate max iterations (must be positive)
721    pub fn validate_max_iter(max_iter: usize) -> Result<()> {
722        if max_iter == 0 {
723            return Err(SklearsError::InvalidParameter {
724                name: "max_iter".to_string(),
725                reason: "must be positive".to_string(),
726            });
727        }
728
729        Ok(())
730    }
731
732    /// Validate probability values (must be in [0, 1])
733    pub fn validate_probability<T: FloatBounds>(prob: T) -> Result<()> {
734        if prob < T::zero() || prob > T::one() {
735            return Err(SklearsError::InvalidParameter {
736                name: "probability".to_string(),
737                reason: "must be in range [0, 1]".to_string(),
738            });
739        }
740
741        if !Float::is_finite(prob) {
742            return Err(SklearsError::InvalidParameter {
743                name: "probability".to_string(),
744                reason: "must be finite".to_string(),
745            });
746        }
747
748        Ok(())
749    }
750
751    /// Validate data shapes for supervised learning
752    pub fn validate_supervised_data<T, U>(x: &Array2<T>, y: &Array1<U>) -> Result<()> {
753        if x.is_empty() {
754            return Err(SklearsError::InvalidData {
755                reason: "X cannot be empty".to_string(),
756            });
757        }
758
759        if y.is_empty() {
760            return Err(SklearsError::InvalidData {
761                reason: "y cannot be empty".to_string(),
762            });
763        }
764
765        if x.nrows() != y.len() {
766            return Err(SklearsError::ShapeMismatch {
767                expected: "X.shape[0] == y.shape[0]".to_string(),
768                actual: format!("X.shape[0]={}, y.shape[0]={}", x.nrows(), y.len()),
769            });
770        }
771
772        Ok(())
773    }
774
775    /// Validate data for unsupervised learning
776    pub fn validate_unsupervised_data<T>(x: &Array2<T>) -> Result<()> {
777        if x.is_empty() {
778            return Err(SklearsError::InvalidData {
779                reason: "X cannot be empty".to_string(),
780            });
781        }
782
783        if x.nrows() == 0 || x.ncols() == 0 {
784            return Err(SklearsError::InvalidData {
785                reason: "X must have positive dimensions".to_string(),
786            });
787        }
788
789        Ok(())
790    }
791
792    /// Validate feature consistency between training and prediction
793    pub fn validate_feature_consistency<T, U>(
794        x_train: &Array2<T>,
795        x_test: &Array2<U>,
796        _model_name: &str,
797    ) -> Result<()> {
798        if x_train.ncols() != x_test.ncols() {
799            return Err(SklearsError::FeatureMismatch {
800                expected: x_train.ncols(),
801                actual: x_test.ncols(),
802            });
803        }
804
805        Ok(())
806    }
807}
808
/// Proc macro helper functions for derive implementation
pub mod derive_helpers {
    /// Generate validation code for a field with validation attributes.
    ///
    /// Each recognised entry of `validation_attrs` becomes one line of Rust source of the form
    /// `ValidationRules::new("<field>").add_rule(<rule>).validate_numeric(&self.<field>)?;`,
    /// and the lines are joined with `'\n'`.
    ///
    /// Recognised attributes are `positive`, `non_negative`, `finite` and `range(min, max)`.
    /// Unrecognised attributes are skipped. Malformed `range(...)` attributes (missing `)`,
    /// wrong number of bounds, empty bounds) are also skipped rather than panicking.
    ///
    /// `ValidationRule::Range` stores `f64` bounds, and Rust does not coerce integer
    /// literals to floats. Integer bounds such as `range(0, 10)` are therefore emitted as
    /// `0.0` and `10.0` so the generated code type-checks.
    ///
    /// `_field_type` is currently unused.
    pub fn generate_field_validation(
        field_name: &str,
        _field_type: &str,
        validation_attrs: &[String],
    ) -> String {
        validation_attrs
            .iter()
            .map(String::as_str)
            .filter_map(rule_expression)
            .map(|rule| {
                format!(
                    "ValidationRules::new(\"{field_name}\").add_rule({rule}).validate_numeric(&self.{field_name})?;"
                )
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Translate one attribute string into the `ValidationRule` expression it denotes.
    fn rule_expression(attr: &str) -> Option<String> {
        match attr {
            "positive" => Some("ValidationRule::Positive".to_string()),
            "non_negative" => Some("ValidationRule::NonNegative".to_string()),
            "finite" => Some("ValidationRule::Finite".to_string()),
            _ => range_rule_expression(attr),
        }
    }

    /// Parse `range(min, max)`; `None` unless the text has exactly that shape with two
    /// non-empty, comma-separated bounds.
    fn range_rule_expression(attr: &str) -> Option<String> {
        let inner = attr.strip_prefix("range(")?.strip_suffix(')')?;
        let bounds: Vec<&str> = inner.split(',').map(str::trim).collect();
        match bounds.as_slice() {
            [min, max] if !min.is_empty() && !max.is_empty() => {
                let (min, max) = (float_bound(min), float_bound(max));
                Some(format!("ValidationRule::Range {{ min: {min}, max: {max} }}"))
            }
            _ => None,
        }
    }

    /// Render a bound so it is usable as an `f64`.
    ///
    /// Plain integer literals get a `.0` suffix. Anything else (float literals, constants
    /// such as `f64::MAX`) is passed through unchanged.
    fn float_bound(bound: &str) -> String {
        match bound.parse::<i64>() {
            Ok(n) => format!("{n}.0"),
            Err(_) => bound.to_string(),
        }
    }
}
859
860/// Configuration validation for complete ML algorithms
861pub trait ConfigValidation {
862    /// Validate the entire configuration
863    fn validate_config(&self) -> Result<()>;
864
865    /// Get validation warnings (non-fatal issues)
866    fn get_warnings(&self) -> Vec<String> {
867        Vec::new()
868    }
869}
870
/// Validation context for providing better error messages
///
/// Records which algorithm and operation were running, plus optional information about
/// the data involved, so that errors can be rendered with that context.
#[derive(Debug, Clone)]
pub struct ValidationContext {
    /// Name of the algorithm, e.g. `"KMeans"`.
    pub algorithm: String,
    /// Operation being performed, e.g. `"fit"`.
    pub operation: String,
    /// Optional description of the data; `new` initialises this to `None` and
    /// `with_data_info` sets it.
    pub data_info: Option<DataInfo>,
}
878
/// Information about the data being validated
#[derive(Debug, Clone)]
pub struct DataInfo {
    /// Number of samples in the data.
    pub n_samples: usize,
    /// Number of features in the data.
    pub n_features: usize,
    /// Free-form label for the element type, e.g. `"float64"`.
    pub data_type: String,
}
886
887impl ValidationContext {
888    /// Create a new validation context
889    pub fn new(algorithm: &str, operation: &str) -> Self {
890        Self {
891            algorithm: algorithm.to_string(),
892            operation: operation.to_string(),
893            data_info: None,
894        }
895    }
896
897    /// Add data information to the context
898    pub fn with_data_info(mut self, n_samples: usize, n_features: usize, data_type: &str) -> Self {
899        self.data_info = Some(DataInfo {
900            n_samples,
901            n_features,
902            data_type: data_type.to_string(),
903        });
904        self
905    }
906
907    /// Format error with context information
908    pub fn format_error(&self, error: &SklearsError) -> String {
909        let mut msg = format!(
910            "Error in {} during {}: {error}",
911            self.algorithm, self.operation
912        );
913
914        if let Some(data_info) = &self.data_info {
915            msg.push_str(&format!(
916                " (data: {} samples, {} features, type: {})",
917                data_info.n_samples, data_info.n_features, data_info.data_type
918            ));
919        }
920
921        msg
922    }
923}
924
925/// Structured destructuring for complex data types
926pub mod structured_destructuring {
927    use super::*;
928
929    /// Trait for types that support structured destructuring
930    pub trait StructuredDestructure {
931        /// Destructure into named components
932        fn destructure_into_components(
933            &self,
934        ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>>;
935
936        /// Extract specific fields by path (e.g., "user.address.city")
937        fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>>;
938
939        /// Validate structure matches expected schema
940        fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult>;
941    }
942
943    /// Schema for validating complex data structures
944    #[derive(Debug, Clone, Default)]
945    pub struct StructuralSchema {
946        pub required_fields: Vec<String>,
947        pub optional_fields: Vec<String>,
948        pub field_types: std::collections::HashMap<String, String>,
949        pub nested_schemas: std::collections::HashMap<String, StructuralSchema>,
950    }
951
952    impl StructuralSchema {
953        pub fn new() -> Self {
954            Self::default()
955        }
956
957        pub fn require_field(mut self, field_name: &str, field_type: &str) -> Self {
958            self.required_fields.push(field_name.to_string());
959            self.field_types
960                .insert(field_name.to_string(), field_type.to_string());
961            self
962        }
963
964        pub fn optional_field(mut self, field_name: &str, field_type: &str) -> Self {
965            self.optional_fields.push(field_name.to_string());
966            self.field_types
967                .insert(field_name.to_string(), field_type.to_string());
968            self
969        }
970
971        pub fn nested_schema(mut self, field_name: &str, schema: StructuralSchema) -> Self {
972            self.nested_schemas.insert(field_name.to_string(), schema);
973            self
974        }
975    }
976
977    /// Configuration for ML algorithms with structured validation
978    #[derive(Debug, Clone)]
979    pub struct AlgorithmConfig {
980        pub algorithm_name: String,
981        pub hyperparameters: std::collections::HashMap<String, ConfigValue>,
982        pub metadata: std::collections::HashMap<String, String>,
983    }
984
985    /// Values that can be stored in configuration
986    #[derive(Debug, Clone)]
987    pub enum ConfigValue {
988        Float(f64),
989        Integer(i64),
990        String(String),
991        Boolean(bool),
992        Array(Vec<ConfigValue>),
993        Object(std::collections::HashMap<String, ConfigValue>),
994    }
995
996    impl StructuredDestructure for AlgorithmConfig {
997        fn destructure_into_components(
998            &self,
999        ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>> {
1000            let mut components = std::collections::HashMap::new();
1001
1002            components.insert(
1003                "algorithm_name".to_string(),
1004                Box::new(self.algorithm_name.clone()) as Box<dyn std::any::Any>,
1005            );
1006            components.insert(
1007                "hyperparameters".to_string(),
1008                Box::new(self.hyperparameters.clone()) as Box<dyn std::any::Any>,
1009            );
1010            components.insert(
1011                "metadata".to_string(),
1012                Box::new(self.metadata.clone()) as Box<dyn std::any::Any>,
1013            );
1014
1015            Ok(components)
1016        }
1017
1018        fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>> {
1019            let parts: Vec<&str> = field_path.split('.').collect();
1020
1021            match parts.first() {
1022                Some(&"algorithm_name") => Ok(Box::new(self.algorithm_name.clone())),
1023                Some(&"hyperparameters") => {
1024                    if parts.len() > 1 {
1025                        if let Some(param_value) = self.hyperparameters.get(parts[1]) {
1026                            Ok(Box::new(param_value.clone()))
1027                        } else {
1028                            Err(SklearsError::InvalidParameter {
1029                                name: field_path.to_string(),
1030                                reason: format!("Hyperparameter '{}' not found", parts[1]),
1031                            })
1032                        }
1033                    } else {
1034                        Ok(Box::new(self.hyperparameters.clone()))
1035                    }
1036                }
1037                Some(&"metadata") => {
1038                    if parts.len() > 1 {
1039                        if let Some(meta_value) = self.metadata.get(parts[1]) {
1040                            Ok(Box::new(meta_value.clone()))
1041                        } else {
1042                            Err(SklearsError::InvalidParameter {
1043                                name: field_path.to_string(),
1044                                reason: format!("Metadata '{}' not found", parts[1]),
1045                            })
1046                        }
1047                    } else {
1048                        Ok(Box::new(self.metadata.clone()))
1049                    }
1050                }
1051                _ => Err(SklearsError::InvalidParameter {
1052                    name: field_path.to_string(),
1053                    reason: "Invalid field path".to_string(),
1054                }),
1055            }
1056        }
1057
1058        fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult> {
1059            let mut warnings = Vec::new();
1060            let mut context = std::collections::HashMap::new();
1061
1062            // Check required fields
1063            for required_field in &schema.required_fields {
1064                match required_field.as_str() {
1065                    "algorithm_name" => {
1066                        if self.algorithm_name.is_empty() {
1067                            return Err(SklearsError::InvalidParameter {
1068                                name: "algorithm_name".to_string(),
1069                                reason: "Required field cannot be empty".to_string(),
1070                            });
1071                        }
1072                        context.insert("algorithm_name".to_string(), "present".to_string());
1073                    }
1074                    "hyperparameters" => {
1075                        context.insert(
1076                            "hyperparameters_count".to_string(),
1077                            self.hyperparameters.len().to_string(),
1078                        );
1079                    }
1080                    _ => {
1081                        warnings.push(format!("Unknown required field: {required_field}"));
1082                    }
1083                }
1084            }
1085
1086            Ok(ValidationResult {
1087                matched: true,
1088                context,
1089                warnings,
1090            })
1091        }
1092    }
1093
1094    /// Pattern matching for complex validation scenarios
1095    pub fn create_complex_pattern_guard<T>(
1096        pattern_name: &str,
1097        validator: impl Fn(&T) -> Result<bool> + Send + Sync + 'static,
1098        error_message: &str,
1099    ) -> PatternGuardRule
1100    where
1101        T: 'static,
1102    {
1103        PatternGuardRule {
1104            pattern_name: pattern_name.to_string(),
1105            guard_fn: Box::new(move |value| {
1106                if let Some(typed_value) = value.downcast_ref::<T>() {
1107                    validator(typed_value)
1108                } else {
1109                    Ok(false)
1110                }
1111            }),
1112            error_message: error_message.to_string(),
1113            destructure_fn: None,
1114        }
1115    }
1116}
1117
/// Macro for easy destructuring of complex types.
///
/// Each form expands to plain method calls on the object, so the methods (typically from
/// `StructuredDestructure`) must be callable at the use site.
///
/// - `destructure!(obj, "path")` expands to `obj.extract_field("path")`.
/// - `destructure!(obj, { "a", "b" })` calls `extract_field` for each path and collects the
///   successful results into a `HashMap` keyed by path. Failed extractions are skipped.
///   `obj` is evaluated exactly once, and a trailing comma is accepted.
/// - `destructure!(obj, validate: schema)` expands to `obj.validate_structure(&schema)`.
#[macro_export]
macro_rules! destructure {
    // Basic field extraction
    ($obj:expr, $field:literal) => {
        $obj.extract_field($field)
    };

    // Multiple field extraction. The object is bound once so an expression with side
    // effects (or an expensive constructor) is not re-evaluated for every field.
    ($obj:expr, { $($field:literal),* $(,)? }) => {{
        let obj = &$obj;
        let mut results = ::std::collections::HashMap::new();
        $(
            if let Ok(value) = obj.extract_field($field) {
                results.insert($field.to_string(), value);
            }
        )*
        results
    }};

    // Destructuring with validation
    ($obj:expr, validate: $schema:expr) => {
        $obj.validate_structure(&$schema)
    };
}
1144
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble a `PatternGuardRule` (with no destructuring step) from its parts.
    fn guard_rule(
        pattern_name: &str,
        error_message: &str,
        guard: impl Fn(&dyn std::any::Any) -> Result<bool> + Send + Sync + 'static,
    ) -> PatternGuardRule {
        PatternGuardRule {
            pattern_name: pattern_name.to_string(),
            guard_fn: Box::new(guard),
            error_message: error_message.to_string(),
            destructure_fn: None,
        }
    }

    #[test]
    fn test_validation_rules_numeric() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::Positive)
            .add_rule(ValidationRule::Finite);

        assert!(rules.validate_numeric(&1.5f64).is_ok());

        // Non-positive and non-finite values are all rejected.
        for rejected in [0.0f64, -1.0, f64::NAN, f64::INFINITY] {
            assert!(rules.validate_numeric(&rejected).is_err());
        }
    }

    #[test]
    fn test_validation_rules_range() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::Range { min: 0.0, max: 1.0 });

        // Both endpoints are inside the range.
        for accepted in [0.5f64, 0.0, 1.0] {
            assert!(rules.validate_numeric(&accepted).is_ok());
        }
        for rejected in [-0.1f64, 1.1] {
            assert!(rules.validate_numeric(&rejected).is_err());
        }
    }

    #[test]
    fn test_validation_rules_string() {
        let options = vec!["option1".to_string(), "option2".to_string()];
        let rules = ValidationRules::new("test_param").add_rule(ValidationRule::OneOf(options));

        for accepted in ["option1", "option2"] {
            assert!(rules.validate_string(accepted).is_ok());
        }
        assert!(rules.validate_string("option3").is_err());
    }

    #[test]
    fn test_validation_rules_array() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::MinLength(2))
            .add_rule(ValidationRule::MaxLength(5));

        // Lengths at either bound are accepted.
        assert!(rules.validate_array(&[1, 2]).is_ok());
        assert!(rules.validate_array(&[1, 2, 3, 4, 5]).is_ok());

        // One below the minimum and one above the maximum are rejected.
        assert!(rules.validate_array(&[1]).is_err());
        assert!(rules.validate_array(&[1, 2, 3, 4, 5, 6]).is_err());
    }

    #[test]
    fn test_ml_validation_learning_rate() {
        for accepted in [0.01f64, 0.1] {
            assert!(ml::validate_learning_rate(accepted).is_ok());
        }
        // Zero, negative and NaN learning rates are rejected.
        for rejected in [0.0f64, -0.1, f64::NAN] {
            assert!(ml::validate_learning_rate(rejected).is_err());
        }
    }

    #[test]
    fn test_ml_validation_n_clusters() {
        for (n_clusters, n_samples) in [(3, 10), (10, 10)] {
            assert!(ml::validate_n_clusters(n_clusters, n_samples).is_ok());
        }
        // Zero clusters, and more clusters than samples, are rejected.
        for (n_clusters, n_samples) in [(0, 10), (15, 10)] {
            assert!(ml::validate_n_clusters(n_clusters, n_samples).is_err());
        }
    }

    #[test]
    fn test_ml_validation_probability() {
        for accepted in [0.0f64, 0.5, 1.0] {
            assert!(ml::validate_probability(accepted).is_ok());
        }
        for rejected in [-0.1f64, 1.1, f64::NAN] {
            assert!(ml::validate_probability(rejected).is_err());
        }
    }

    #[test]
    fn test_validation_context() {
        let context = ValidationContext::new("KMeans", "fit").with_data_info(100, 5, "float64");
        let error = SklearsError::InvalidParameter {
            name: "n_clusters".to_string(),
            reason: "must be positive".to_string(),
        };

        let formatted = context.format_error(&error);
        for fragment in ["KMeans", "fit", "100 samples", "5 features"] {
            assert!(formatted.contains(fragment));
        }
    }

    // -----------------------------------------------------------------
    // PatternGuard validation tests
    // -----------------------------------------------------------------

    #[test]
    fn test_pattern_guard_numeric_passes() {
        // Accept only values whose integer truncation is even.
        let guard = guard_rule(
            "even_number",
            "must be an even number",
            |value: &dyn std::any::Any| {
                Ok(matches!(value.downcast_ref::<f64>(), Some(v) if *v as i64 % 2 == 0))
            },
        );
        let rules =
            ValidationRules::new("even_param").add_rule(ValidationRule::PatternGuard(guard));

        assert!(rules.validate_numeric(&4.0f64).is_ok());
        assert!(rules.validate_numeric(&3.0f64).is_err());
    }

    #[test]
    fn test_pattern_guard_string_passes() {
        // Reject strings whose first character is numeric.
        let guard = guard_rule(
            "no_leading_digit",
            "must not start with a digit",
            |value: &dyn std::any::Any| {
                Ok(matches!(
                    value.downcast_ref::<String>(),
                    Some(s) if !s.starts_with(char::is_numeric)
                ))
            },
        );
        let rules =
            ValidationRules::new("identifier").add_rule(ValidationRule::PatternGuard(guard));

        assert!(rules.validate_string("alpha_param").is_ok());
        assert!(rules.validate_string("1_bad").is_err());
    }

    #[test]
    fn test_pattern_guard_array_length() {
        // The guard sees the array length as a `usize` and requires it to be odd.
        let guard = guard_rule(
            "odd_length",
            "array must have an odd number of elements",
            |value: &dyn std::any::Any| {
                Ok(matches!(value.downcast_ref::<usize>(), Some(len) if len % 2 == 1))
            },
        );
        let rules = ValidationRules::new("odd_array").add_rule(ValidationRule::PatternGuard(guard));

        assert!(rules.validate_array(&[1, 2, 3]).is_ok());
        assert!(rules.validate_array(&[1, 2, 3, 4]).is_err());
    }

    #[test]
    fn test_pattern_guard_error_message_propagated() {
        let expected_reason = "value must be the answer to everything";
        let guard = guard_rule("answer", expected_reason, |value: &dyn std::any::Any| {
            Ok(matches!(
                value.downcast_ref::<f64>(),
                Some(v) if (*v - 42.0).abs() < f64::EPSILON
            ))
        });
        let rules =
            ValidationRules::new("cosmic_number").add_rule(ValidationRule::PatternGuard(guard));

        assert!(rules.validate_numeric(&42.0f64).is_ok());

        // A rejected value surfaces the guard's error message.
        let err = match rules.validate_numeric(&7.0f64) {
            Err(e) => e,
            Ok(_) => panic!("7 is not 42"),
        };
        assert!(err.to_string().contains(expected_reason));
    }
}