sklears_core/
validation.rs

//! Comprehensive validation framework for machine learning parameters and data.
//!
//! This module provides a robust validation system, including helpers for custom
//! derive macros that perform automatic parameter validation in ML algorithms.
5use crate::error::{Result, SklearsError};
6use crate::types::{Array1, Array2, FloatBounds, Numeric};
7use scirs2_core::numeric::{Float, NumCast};
8use std::fmt::Debug;
9
10/// Type alias for validation guard function
11pub type ValidationGuardFn = Box<dyn Fn(&dyn std::any::Any) -> Result<bool> + Send + Sync>;
12
13/// Type alias for validation destructuring function
14pub type ValidationDestructureFn =
15    Box<dyn Fn(&dyn std::any::Any) -> Result<ValidationResult> + Send + Sync>;
16
17/// Core validation trait that can be derived for automatic parameter validation
18pub trait Validate {
19    /// Validate all parameters and return an error if any validation fails
20    fn validate(&self) -> Result<()>;
21
22    /// Validate and provide detailed error information  
23    fn validate_with_context(&self, context: &str) -> Result<()> {
24        self.validate()
25            .map_err(|e| SklearsError::Other(format!("{context}: {e}")))
26    }
27}
28
29/// Validation attributes for ML parameter constraints
30#[derive(Debug, Clone)]
31pub enum ValidationRule {
32    /// Value must be positive (> 0)
33    Positive,
34    /// Value must be non-negative (>= 0)
35    NonNegative,
36    /// Value must be finite (not NaN or infinity)
37    Finite,
38    /// Value must be in a specific range [min, max]
39    Range { min: f64, max: f64 },
40    /// Value must be one of the specified options
41    OneOf(Vec<String>),
42    /// Array must have minimum number of elements
43    MinLength(usize),
44    /// Array must have maximum number of elements
45    MaxLength(usize),
46    /// Array elements must be unique
47    UniqueElements,
48    /// Custom validation function
49    Custom(fn(&dyn std::any::Any) -> Result<()>),
50    /// Pattern guard validation with custom matching
51    PatternGuard(PatternGuardRule),
52}
53
54/// Pattern guard rule for advanced validation with custom matching
55pub struct PatternGuardRule {
56    /// Name of the pattern for error reporting
57    pub pattern_name: String,
58    /// Function that performs the pattern matching and validation
59    pub guard_fn: ValidationGuardFn,
60    /// Error message when the pattern doesn't match
61    pub error_message: String,
62    /// Optional structured destructuring validator
63    pub destructure_fn: Option<ValidationDestructureFn>,
64}
65
66impl std::fmt::Debug for PatternGuardRule {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("PatternGuardRule")
69            .field("pattern_name", &self.pattern_name)
70            .field("guard_fn", &"<function>")
71            .field("error_message", &self.error_message)
72            .field(
73                "destructure_fn",
74                &self.destructure_fn.as_ref().map(|_| "<function>"),
75            )
76            .finish()
77    }
78}
79
80impl Clone for PatternGuardRule {
81    fn clone(&self) -> Self {
82        // Create a simple clone that preserves the pattern name and error message
83        // but uses a default guard function
84        Self {
85            pattern_name: self.pattern_name.clone(),
86            guard_fn: Box::new(|_| Ok(true)), // Default to always pass
87            error_message: self.error_message.clone(),
88            destructure_fn: None,
89        }
90    }
91}
92
/// Outcome of applying a [`PatternGuardRule`] to a value.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the value satisfied the pattern.
    pub matched: bool,
    /// Key/value details collected while matching.
    pub context: std::collections::HashMap<String, String>,
    /// Non-fatal observations produced during validation.
    pub warnings: Vec<String>,
}
103
/// Builds a [`PatternGuardRule`](crate::validation::PatternGuardRule) from a
/// short description.
///
/// Supported forms:
///
/// * `pattern_guard!(numeric_range, min, max)`: accepts `f64`, `f32`, `i32`
///   and `usize` values within the inclusive range `[min, max]` and rejects
///   every other type.
///   - `min` and `max` are evaluated exactly once and must be `f64`-compatible
///     (e.g. float literals).
///   - Integer values are compared in `f64`, so the bounds are never truncated.
/// * `pattern_guard!(array_shape, shape)`: placeholder; always matches. `shape`
///   is only used in the error message.
/// * `pattern_guard!(string_enum, options)`: accepts a `String` or
///   `&'static str` contained in `options` (anything with a slice-style
///   `contains` over `&str`). `options` is evaluated once, so a `Vec` held in a
///   variable can be passed.
/// * `pattern_guard!("name", guard, "message")`: custom guard closure with a
///   fixed error message.
/// * `pattern_guard!("name", guard, destructure)`: custom guard closure plus a
///   destructuring validator.
#[macro_export]
macro_rules! pattern_guard {
    // Pattern guard for numeric types with range validation
    (numeric_range, $min:expr, $max:expr) => {{
        // Evaluate the bounds once so expensive or side-effecting expressions
        // are not re-run on every check.
        let min_bound = $min;
        let max_bound = $max;
        $crate::validation::PatternGuardRule {
            pattern_name: "numeric_range".to_string(),
            guard_fn: Box::new(move |value| {
                if let Some(val) = value.downcast_ref::<f64>() {
                    Ok(*val >= min_bound && *val <= max_bound)
                } else if let Some(val) = value.downcast_ref::<f32>() {
                    Ok(*val >= min_bound as f32 && *val <= max_bound as f32)
                } else if let Some(val) = value.downcast_ref::<i32>() {
                    // Compare in f64: truncating the bounds to integers would,
                    // for example, let 0 pass a range of [0.5, 1.0].
                    let val = *val as f64;
                    Ok(val >= min_bound && val <= max_bound)
                } else if let Some(val) = value.downcast_ref::<usize>() {
                    let val = *val as f64;
                    Ok(val >= min_bound && val <= max_bound)
                } else {
                    Ok(false)
                }
            }),
            error_message: format!("Value must be in range [{}, {}]", min_bound, max_bound),
            destructure_fn: None,
        }
    }};

    // Pattern guard for array shape validation
    (array_shape, $expected_shape:expr) => {
        $crate::validation::PatternGuardRule {
            pattern_name: "array_shape".to_string(),
            // Placeholder: no array shape checking is performed yet, so every
            // value matches.
            guard_fn: Box::new(|_value| Ok(true)),
            error_message: format!("Array shape must match {:?}", $expected_shape),
            destructure_fn: None,
        }
    };

    // Pattern guard for string enum validation
    (string_enum, $valid_options:expr) => {{
        // Bind the options once and format the message before the closure
        // takes ownership. This lets non-`Copy` collections held in variables
        // be passed without a use-after-move error.
        let options = $valid_options;
        let message = format!("Value must be one of {:?}", options);
        $crate::validation::PatternGuardRule {
            pattern_name: "string_enum".to_string(),
            guard_fn: Box::new(move |value| {
                if let Some(val) = value.downcast_ref::<String>() {
                    Ok(options.contains(&val.as_str()))
                } else if let Some(val) = value.downcast_ref::<&str>() {
                    Ok(options.contains(val))
                } else {
                    Ok(false)
                }
            }),
            error_message: message,
            destructure_fn: None,
        }
    }};

    // Pattern guard with custom function and error message
    ($pattern_name:literal, $guard:expr, $error_msg:literal) => {
        $crate::validation::PatternGuardRule {
            pattern_name: $pattern_name.to_string(),
            guard_fn: Box::new($guard),
            error_message: $error_msg.to_string(),
            destructure_fn: None,
        }
    };

    // Pattern guard with destructuring validation
    ($pattern_name:literal, $guard_fn:expr, $destructure_fn:expr) => {
        $crate::validation::PatternGuardRule {
            pattern_name: $pattern_name.to_string(),
            guard_fn: Box::new($guard_fn),
            error_message: format!("Pattern '{}' validation failed", $pattern_name),
            destructure_fn: Some(Box::new($destructure_fn)),
        }
    };
}
181
182/// Trait for types that can be pattern matched and validated
183pub trait PatternValidate {
184    /// Apply pattern guard validation
185    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult>;
186
187    /// Check if value matches a specific pattern
188    fn matches_pattern(&self, pattern_name: &str) -> bool;
189
190    /// Extract structured data using pattern destructuring
191    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>>;
192}
193
194/// Implementation of PatternValidate for f64
195impl PatternValidate for f64 {
196    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
197        let value_any = self as &dyn std::any::Any;
198        let matched = (guard.guard_fn)(value_any)?;
199
200        let mut context = std::collections::HashMap::new();
201        context.insert("value".to_string(), self.to_string());
202        context.insert("type".to_string(), "f64".to_string());
203
204        if let Some(destructure_fn) = &guard.destructure_fn {
205            let destructure_result = destructure_fn(value_any)?;
206            Ok(ValidationResult {
207                matched: matched && destructure_result.matched,
208                context: destructure_result.context,
209                warnings: destructure_result.warnings,
210            })
211        } else {
212            Ok(ValidationResult {
213                matched,
214                context,
215                warnings: Vec::new(),
216            })
217        }
218    }
219
220    fn matches_pattern(&self, pattern_name: &str) -> bool {
221        match pattern_name {
222            "finite" => self.is_finite(),
223            "positive" => *self > 0.0,
224            "non_negative" => *self >= 0.0,
225            "probability" => *self >= 0.0 && *self <= 1.0,
226            _ => false,
227        }
228    }
229
230    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
231        let mut result = std::collections::HashMap::new();
232        match pattern {
233            "range_info" => {
234                result.insert("value".to_string(), self.to_string());
235                result.insert("is_finite".to_string(), self.is_finite().to_string());
236                result.insert("is_positive".to_string(), (*self > 0.0).to_string());
237                result.insert(
238                    "sign".to_string(),
239                    if *self >= 0.0 {
240                        "positive".to_string()
241                    } else {
242                        "negative".to_string()
243                    },
244                );
245            }
246            _ => {
247                result.insert("value".to_string(), self.to_string());
248            }
249        }
250        Ok(result)
251    }
252}
253
254/// Implementation of PatternValidate for usize
255impl PatternValidate for usize {
256    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
257        let value_any = self as &dyn std::any::Any;
258        let matched = (guard.guard_fn)(value_any)?;
259
260        let mut context = std::collections::HashMap::new();
261        context.insert("value".to_string(), self.to_string());
262        context.insert("type".to_string(), "usize".to_string());
263
264        Ok(ValidationResult {
265            matched,
266            context,
267            warnings: Vec::new(),
268        })
269    }
270
271    fn matches_pattern(&self, pattern_name: &str) -> bool {
272        match pattern_name {
273            "positive" => *self > 0,
274            "non_negative" => true, // usize is always non-negative
275            "power_of_two" => self.is_power_of_two(),
276            _ => false,
277        }
278    }
279
280    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
281        let mut result = std::collections::HashMap::new();
282        match pattern {
283            "number_info" => {
284                result.insert("value".to_string(), self.to_string());
285                result.insert("is_positive".to_string(), (*self > 0).to_string());
286                result.insert(
287                    "is_power_of_two".to_string(),
288                    self.is_power_of_two().to_string(),
289                );
290            }
291            _ => {
292                result.insert("value".to_string(), self.to_string());
293            }
294        }
295        Ok(result)
296    }
297}
298
299/// Implementation of PatternValidate for String
300impl PatternValidate for String {
301    fn validate_with_pattern(&self, guard: &PatternGuardRule) -> Result<ValidationResult> {
302        let value_any = self as &dyn std::any::Any;
303        let matched = (guard.guard_fn)(value_any)?;
304
305        let mut context = std::collections::HashMap::new();
306        context.insert("value".to_string(), self.clone());
307        context.insert("type".to_string(), "String".to_string());
308        context.insert("length".to_string(), self.len().to_string());
309
310        Ok(ValidationResult {
311            matched,
312            context,
313            warnings: Vec::new(),
314        })
315    }
316
317    fn matches_pattern(&self, pattern_name: &str) -> bool {
318        match pattern_name {
319            "non_empty" => !self.is_empty(),
320            "alphanumeric" => self.chars().all(|c| c.is_alphanumeric()),
321            "lowercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_lowercase()),
322            "uppercase" => self.chars().all(|c| !c.is_alphabetic() || c.is_uppercase()),
323            _ => false,
324        }
325    }
326
327    fn destructure(&self, pattern: &str) -> Result<std::collections::HashMap<String, String>> {
328        let mut result = std::collections::HashMap::new();
329        match pattern {
330            "string_info" => {
331                result.insert("value".to_string(), self.clone());
332                result.insert("length".to_string(), self.len().to_string());
333                result.insert("is_empty".to_string(), self.is_empty().to_string());
334                result.insert(
335                    "is_alphanumeric".to_string(),
336                    self.chars().all(|c| c.is_alphanumeric()).to_string(),
337                );
338            }
339            _ => {
340                result.insert("value".to_string(), self.clone());
341            }
342        }
343        Ok(result)
344    }
345}
346
347/// Default implementations for common pattern validation scenarios
348pub mod pattern_guards {
349    use super::*;
350
351    /// Pattern guard for ML model hyperparameters
352    pub fn hyperparameter_pattern<T: FloatBounds + std::fmt::Debug>(
353        min_val: T,
354        max_val: T,
355        finite_required: bool,
356    ) -> PatternGuardRule {
357        PatternGuardRule {
358            pattern_name: "hyperparameter_bounds".to_string(),
359            guard_fn: Box::new(|_value| {
360                // In real implementation, would need proper type casting
361                Ok(true) // Placeholder
362            }),
363            error_message: format!(
364                "Hyperparameter must be in range [{}, {}]{}",
365                min_val,
366                max_val,
367                if finite_required { " and finite" } else { "" }
368            ),
369            destructure_fn: None,
370        }
371    }
372
373    /// Pattern guard for array shape validation
374    pub fn array_shape_pattern(expected_dims: &[usize]) -> PatternGuardRule {
375        let dims_str = expected_dims
376            .iter()
377            .map(|d| d.to_string())
378            .collect::<Vec<_>>()
379            .join(", ");
380
381        PatternGuardRule {
382            pattern_name: "array_shape".to_string(),
383            guard_fn: Box::new(|_value| {
384                // Would validate array shape in real implementation
385                Ok(true)
386            }),
387            error_message: format!("Array shape must match [{dims_str}]"),
388            destructure_fn: None, // Remove capturing closure for now
389        }
390    }
391
392    /// Pattern guard for ML algorithm configuration
393    pub fn algorithm_config_pattern(required_fields: &[&str]) -> PatternGuardRule {
394        let fields_str = required_fields.join(", ");
395
396        PatternGuardRule {
397            pattern_name: "algorithm_config".to_string(),
398            guard_fn: Box::new(|_value| {
399                // Would validate configuration completeness
400                Ok(true)
401            }),
402            error_message: format!("Configuration must contain fields: {fields_str}"),
403            destructure_fn: None, // Remove capturing closure for now
404        }
405    }
406
407    /// Pattern guard for data type consistency
408    pub fn data_type_pattern(expected_types: &[&str]) -> PatternGuardRule {
409        let types_str = expected_types.join(" | ");
410
411        PatternGuardRule {
412            pattern_name: "data_type_consistency".to_string(),
413            guard_fn: Box::new(|_value| {
414                // Would validate data type consistency
415                Ok(true)
416            }),
417            error_message: format!("Data type must be one of: {types_str}"),
418            destructure_fn: None,
419        }
420    }
421}
422
423/// Container for multiple validation rules
424#[derive(Debug, Clone)]
425pub struct ValidationRules {
426    pub rules: Vec<ValidationRule>,
427    pub field_name: String,
428}
429
430impl ValidationRules {
431    /// Create a new validation rules container
432    pub fn new(field_name: &str) -> Self {
433        Self {
434            rules: Vec::new(),
435            field_name: field_name.to_string(),
436        }
437    }
438
439    /// Add a validation rule
440    pub fn add_rule(mut self, rule: ValidationRule) -> Self {
441        self.rules.push(rule);
442        self
443    }
444
445    /// Validate a numeric value against all rules
446    pub fn validate_numeric<T>(&self, value: &T) -> Result<()>
447    where
448        T: Numeric + PartialOrd + Debug + Copy + NumCast,
449    {
450        for rule in &self.rules {
451            match rule {
452                ValidationRule::Positive => {
453                    if *value <= T::zero() {
454                        return Err(SklearsError::InvalidParameter {
455                            name: self.field_name.clone(),
456                            reason: "must be positive".to_string(),
457                        });
458                    }
459                }
460                ValidationRule::NonNegative => {
461                    if *value < T::zero() {
462                        return Err(SklearsError::InvalidParameter {
463                            name: self.field_name.clone(),
464                            reason: "must be non-negative".to_string(),
465                        });
466                    }
467                }
468                ValidationRule::Finite => {
469                    if let Some(float_val) = NumCast::from(*value) {
470                        let f: f64 = float_val;
471                        if !f.is_finite() {
472                            return Err(SklearsError::InvalidParameter {
473                                name: self.field_name.clone(),
474                                reason: "must be finite".to_string(),
475                            });
476                        }
477                    }
478                }
479                ValidationRule::Range { min, max } => {
480                    if let Some(float_val) = NumCast::from(*value) {
481                        let f: f64 = float_val;
482                        if f < *min || f > *max {
483                            return Err(SklearsError::InvalidParameter {
484                                name: self.field_name.clone(),
485                                reason: format!("must be in range [{min}, {max}]"),
486                            });
487                        }
488                    }
489                }
490                ValidationRule::PatternGuard(_pattern_guard) => {
491                    // TODO: Fix lifetime issues with pattern guard validation
492                    // let value_any = &value as &dyn std::any::Any;
493                    // let result = (pattern_guard.guard_fn)(value_any)?;
494                    // if !result {
495                    //     return Err(SklearsError::InvalidParameter {
496                    //         name: self.field_name.clone(),
497                    //         reason: pattern_guard.error_message.clone(),
498                    //     });
499                    // }
500                }
501                _ => {
502                    // Skip rules that don't apply to numeric values
503                }
504            }
505        }
506        Ok(())
507    }
508
509    /// Validate a string value against all rules
510    pub fn validate_string(&self, value: &str) -> Result<()> {
511        for rule in &self.rules {
512            match rule {
513                ValidationRule::OneOf(options) => {
514                    if !options.contains(&value.to_string()) {
515                        return Err(SklearsError::InvalidParameter {
516                            name: self.field_name.clone(),
517                            reason: format!("must be one of {options:?}"),
518                        });
519                    }
520                }
521                ValidationRule::PatternGuard(_pattern_guard) => {
522                    // TODO: Fix lifetime issues with pattern guard validation
523                    // let value_any = &value as &dyn std::any::Any;
524                    // let result = (pattern_guard.guard_fn)(value_any)?;
525                    // if !result {
526                    //     return Err(SklearsError::InvalidParameter {
527                    //         name: self.field_name.clone(),
528                    //         reason: pattern_guard.error_message.clone(),
529                    //     });
530                    // }
531                }
532                _ => {
533                    // Skip rules that don't apply to string values
534                }
535            }
536        }
537        Ok(())
538    }
539
540    /// Validate an array/vector against all rules
541    pub fn validate_array<T>(&self, value: &[T]) -> Result<()> {
542        for rule in &self.rules {
543            match rule {
544                ValidationRule::MinLength(min_len) => {
545                    if value.len() < *min_len {
546                        return Err(SklearsError::InvalidParameter {
547                            name: self.field_name.clone(),
548                            reason: format!("must have at least {min_len} elements"),
549                        });
550                    }
551                }
552                ValidationRule::MaxLength(max_len) => {
553                    if value.len() > *max_len {
554                        return Err(SklearsError::InvalidParameter {
555                            name: self.field_name.clone(),
556                            reason: format!("must have at most {max_len} elements"),
557                        });
558                    }
559                }
560                ValidationRule::PatternGuard(_pattern_guard) => {
561                    // TODO: Fix lifetime issues with pattern guard validation
562                    // let value_any = &value as &dyn std::any::Any;
563                    // let result = (pattern_guard.guard_fn)(value_any)?;
564                    // if !result {
565                    //     return Err(SklearsError::InvalidParameter {
566                    //         name: self.field_name.clone(),
567                    //         reason: pattern_guard.error_message.clone(),
568                    //     });
569                    // }
570                }
571                _ => {
572                    // Skip rules that don't apply to arrays
573                }
574            }
575        }
576        Ok(())
577    }
578
579    /// Validate an unsigned integer value (usize) against all rules
580    pub fn validate_usize(&self, value: &usize) -> Result<()> {
581        for rule in &self.rules {
582            match rule {
583                ValidationRule::Positive => {
584                    if *value == 0 {
585                        return Err(SklearsError::InvalidParameter {
586                            name: self.field_name.clone(),
587                            reason: "must be positive".to_string(),
588                        });
589                    }
590                }
591                ValidationRule::NonNegative => {
592                    // usize is always non-negative, so this always passes
593                }
594                ValidationRule::Range { min, max } => {
595                    let val = *value as f64;
596                    if val < *min || val > *max {
597                        return Err(SklearsError::InvalidParameter {
598                            name: self.field_name.clone(),
599                            reason: format!("must be in range [{min}, {max}]"),
600                        });
601                    }
602                }
603                _ => {
604                    // Skip rules that don't apply to usize values
605                }
606            }
607        }
608        Ok(())
609    }
610}
611
612/// ML-specific validation functions
613pub mod ml {
614    use super::*;
615
616    /// Validate learning rate (must be positive and typically < 1.0)
617    pub fn validate_learning_rate<T: FloatBounds>(lr: T) -> Result<()> {
618        if lr <= T::zero() {
619            return Err(SklearsError::InvalidParameter {
620                name: "learning_rate".to_string(),
621                reason: "must be positive".to_string(),
622            });
623        }
624
625        if !Float::is_finite(lr) {
626            return Err(SklearsError::InvalidParameter {
627                name: "learning_rate".to_string(),
628                reason: "must be finite".to_string(),
629            });
630        }
631
632        // Warning for unusually high learning rates
633        if lr > T::one() {
634            log::warn!("Learning rate {lr} is unusually high, consider using a smaller value");
635        }
636
637        Ok(())
638    }
639
640    /// Validate regularization parameter (must be non-negative)
641    pub fn validate_regularization<T: FloatBounds>(reg: T) -> Result<()> {
642        if reg < T::zero() {
643            return Err(SklearsError::InvalidParameter {
644                name: "regularization".to_string(),
645                reason: "must be non-negative".to_string(),
646            });
647        }
648
649        if !Float::is_finite(reg) {
650            return Err(SklearsError::InvalidParameter {
651                name: "regularization".to_string(),
652                reason: "must be finite".to_string(),
653            });
654        }
655
656        Ok(())
657    }
658
659    /// Validate number of clusters (must be positive integer)
660    pub fn validate_n_clusters(n_clusters: usize, n_samples: usize) -> Result<()> {
661        if n_clusters == 0 {
662            return Err(SklearsError::InvalidParameter {
663                name: "n_clusters".to_string(),
664                reason: "must be positive".to_string(),
665            });
666        }
667
668        if n_clusters > n_samples {
669            return Err(SklearsError::InvalidParameter {
670                name: "n_clusters".to_string(),
671                reason: format!("cannot exceed number of samples ({n_samples})"),
672            });
673        }
674
675        Ok(())
676    }
677
678    /// Validate number of neighbors for KNN (must be positive and <= n_samples)
679    pub fn validate_n_neighbors(n_neighbors: usize, n_samples: usize) -> Result<()> {
680        if n_neighbors == 0 {
681            return Err(SklearsError::InvalidParameter {
682                name: "n_neighbors".to_string(),
683                reason: "must be positive".to_string(),
684            });
685        }
686
687        if n_neighbors > n_samples {
688            return Err(SklearsError::InvalidParameter {
689                name: "n_neighbors".to_string(),
690                reason: format!("cannot exceed number of samples ({n_samples})"),
691            });
692        }
693
694        Ok(())
695    }
696
697    /// Validate tolerance parameter (must be positive and small)
698    pub fn validate_tolerance<T: FloatBounds>(tol: T) -> Result<()> {
699        if tol <= T::zero() {
700            return Err(SklearsError::InvalidParameter {
701                name: "tolerance".to_string(),
702                reason: "must be positive".to_string(),
703            });
704        }
705
706        if !Float::is_finite(tol) {
707            return Err(SklearsError::InvalidParameter {
708                name: "tolerance".to_string(),
709                reason: "must be finite".to_string(),
710            });
711        }
712
713        // Warning for very large tolerances
714        if tol > T::from(0.1).unwrap_or(T::one()) {
715            log::warn!("Tolerance {tol} is very large, algorithm may converge prematurely");
716        }
717
718        Ok(())
719    }
720
721    /// Validate max iterations (must be positive)
722    pub fn validate_max_iter(max_iter: usize) -> Result<()> {
723        if max_iter == 0 {
724            return Err(SklearsError::InvalidParameter {
725                name: "max_iter".to_string(),
726                reason: "must be positive".to_string(),
727            });
728        }
729
730        Ok(())
731    }
732
733    /// Validate probability values (must be in [0, 1])
734    pub fn validate_probability<T: FloatBounds>(prob: T) -> Result<()> {
735        if prob < T::zero() || prob > T::one() {
736            return Err(SklearsError::InvalidParameter {
737                name: "probability".to_string(),
738                reason: "must be in range [0, 1]".to_string(),
739            });
740        }
741
742        if !Float::is_finite(prob) {
743            return Err(SklearsError::InvalidParameter {
744                name: "probability".to_string(),
745                reason: "must be finite".to_string(),
746            });
747        }
748
749        Ok(())
750    }
751
752    /// Validate data shapes for supervised learning
753    pub fn validate_supervised_data<T, U>(x: &Array2<T>, y: &Array1<U>) -> Result<()> {
754        if x.is_empty() {
755            return Err(SklearsError::InvalidData {
756                reason: "X cannot be empty".to_string(),
757            });
758        }
759
760        if y.is_empty() {
761            return Err(SklearsError::InvalidData {
762                reason: "y cannot be empty".to_string(),
763            });
764        }
765
766        if x.nrows() != y.len() {
767            return Err(SklearsError::ShapeMismatch {
768                expected: "X.shape[0] == y.shape[0]".to_string(),
769                actual: format!("X.shape[0]={}, y.shape[0]={}", x.nrows(), y.len()),
770            });
771        }
772
773        Ok(())
774    }
775
776    /// Validate data for unsupervised learning
777    pub fn validate_unsupervised_data<T>(x: &Array2<T>) -> Result<()> {
778        if x.is_empty() {
779            return Err(SklearsError::InvalidData {
780                reason: "X cannot be empty".to_string(),
781            });
782        }
783
784        if x.nrows() == 0 || x.ncols() == 0 {
785            return Err(SklearsError::InvalidData {
786                reason: "X must have positive dimensions".to_string(),
787            });
788        }
789
790        Ok(())
791    }
792
793    /// Validate feature consistency between training and prediction
794    pub fn validate_feature_consistency<T, U>(
795        x_train: &Array2<T>,
796        x_test: &Array2<U>,
797        _model_name: &str,
798    ) -> Result<()> {
799        if x_train.ncols() != x_test.ncols() {
800            return Err(SklearsError::FeatureMismatch {
801                expected: x_train.ncols(),
802                actual: x_test.ncols(),
803            });
804        }
805
806        Ok(())
807    }
808}
809
/// Proc macro helper functions for derive implementation
pub mod derive_helpers {
    /// Generate validation code for a field with validation attributes
    ///
    /// Each recognised attribute produces one line of Rust source with this shape:
    /// `ValidationRules::new("<field>").add_rule(<rule>).validate_numeric(&self.<field>)?;`.
    /// The lines are joined with `\n`, and an empty string is returned when nothing matches.
    ///
    /// Recognised attributes:
    /// - `positive`, `non_negative`, `finite`
    /// - `range(min, max)`: `min` and `max` are trimmed and pasted verbatim into the
    ///   generated `ValidationRule::Range { min: .., max: .. }`
    ///
    /// Unknown attributes are skipped. So are malformed ranges, such as a missing closing
    /// parenthesis or a bound count other than two. A malformed range previously caused a
    /// panic.
    ///
    /// `_field_type` is currently unused.
    pub fn generate_field_validation(
        field_name: &str,
        _field_type: &str,
        validation_attrs: &[String],
    ) -> String {
        validation_attrs
            .iter()
            .map(String::as_str)
            .filter_map(rule_expression)
            .map(|rule| {
                format!(
                    "ValidationRules::new(\"{field_name}\").add_rule({rule}).validate_numeric(&self.{field_name})?;"
                )
            })
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Map one attribute string to the `ValidationRule` expression it denotes, if any.
    fn rule_expression(attr: &str) -> Option<String> {
        match attr {
            "positive" => Some("ValidationRule::Positive".to_string()),
            "non_negative" => Some("ValidationRule::NonNegative".to_string()),
            "finite" => Some("ValidationRule::Finite".to_string()),
            _ => parse_range(attr)
                .map(|(min, max)| format!("ValidationRule::Range {{ min: {min}, max: {max} }}")),
        }
    }

    /// Split `range(min, max)` into its two trimmed bounds; `None` if the text is malformed.
    fn parse_range(attr: &str) -> Option<(&str, &str)> {
        let inner = attr.strip_prefix("range(")?.strip_suffix(')')?;
        let mut bounds = inner.split(',').map(str::trim);
        match (bounds.next(), bounds.next(), bounds.next()) {
            (Some(min), Some(max), None) => Some((min, max)),
            _ => None,
        }
    }
}
860
861/// Configuration validation for complete ML algorithms
862pub trait ConfigValidation {
863    /// Validate the entire configuration
864    fn validate_config(&self) -> Result<()>;
865
866    /// Get validation warnings (non-fatal issues)
867    fn get_warnings(&self) -> Vec<String> {
868        Vec::new()
869    }
870}
871
/// Validation context for providing better error messages
///
/// Holds the algorithm name, the operation in progress, and optionally a summary of the
/// input data. `ValidationContext::format_error` combines these into one error message.
#[derive(Debug, Clone)]
pub struct ValidationContext {
    /// Algorithm name, rendered as `Error in <algorithm>` by `format_error`.
    pub algorithm: String,
    /// Operation in progress, rendered as `during <operation>` by `format_error`.
    pub operation: String,
    /// Optional data summary; `new` leaves it `None` and `with_data_info` fills it in.
    pub data_info: Option<DataInfo>,
}
879
/// Information about the data being validated
///
/// `ValidationContext::format_error` appends it as
/// ` (data: <n_samples> samples, <n_features> features, type: <data_type>)`.
#[derive(Debug, Clone)]
pub struct DataInfo {
    /// Number of samples in the data.
    pub n_samples: usize,
    /// Number of features in the data.
    pub n_features: usize,
    /// Free-form label for the data's element type (e.g. `"float64"`).
    pub data_type: String,
}
887
888impl ValidationContext {
889    /// Create a new validation context
890    pub fn new(algorithm: &str, operation: &str) -> Self {
891        Self {
892            algorithm: algorithm.to_string(),
893            operation: operation.to_string(),
894            data_info: None,
895        }
896    }
897
898    /// Add data information to the context
899    pub fn with_data_info(mut self, n_samples: usize, n_features: usize, data_type: &str) -> Self {
900        self.data_info = Some(DataInfo {
901            n_samples,
902            n_features,
903            data_type: data_type.to_string(),
904        });
905        self
906    }
907
908    /// Format error with context information
909    pub fn format_error(&self, error: &SklearsError) -> String {
910        let mut msg = format!(
911            "Error in {} during {}: {error}",
912            self.algorithm, self.operation
913        );
914
915        if let Some(data_info) = &self.data_info {
916            msg.push_str(&format!(
917                " (data: {} samples, {} features, type: {})",
918                data_info.n_samples, data_info.n_features, data_info.data_type
919            ));
920        }
921
922        msg
923    }
924}
925
926/// Structured destructuring for complex data types
927pub mod structured_destructuring {
928    use super::*;
929
930    /// Trait for types that support structured destructuring
931    pub trait StructuredDestructure {
932        /// Destructure into named components
933        fn destructure_into_components(
934            &self,
935        ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>>;
936
937        /// Extract specific fields by path (e.g., "user.address.city")
938        fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>>;
939
940        /// Validate structure matches expected schema
941        fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult>;
942    }
943
944    /// Schema for validating complex data structures
945    #[derive(Debug, Clone, Default)]
946    pub struct StructuralSchema {
947        pub required_fields: Vec<String>,
948        pub optional_fields: Vec<String>,
949        pub field_types: std::collections::HashMap<String, String>,
950        pub nested_schemas: std::collections::HashMap<String, StructuralSchema>,
951    }
952
953    impl StructuralSchema {
954        pub fn new() -> Self {
955            Self::default()
956        }
957
958        pub fn require_field(mut self, field_name: &str, field_type: &str) -> Self {
959            self.required_fields.push(field_name.to_string());
960            self.field_types
961                .insert(field_name.to_string(), field_type.to_string());
962            self
963        }
964
965        pub fn optional_field(mut self, field_name: &str, field_type: &str) -> Self {
966            self.optional_fields.push(field_name.to_string());
967            self.field_types
968                .insert(field_name.to_string(), field_type.to_string());
969            self
970        }
971
972        pub fn nested_schema(mut self, field_name: &str, schema: StructuralSchema) -> Self {
973            self.nested_schemas.insert(field_name.to_string(), schema);
974            self
975        }
976    }
977
978    /// Configuration for ML algorithms with structured validation
979    #[derive(Debug, Clone)]
980    pub struct AlgorithmConfig {
981        pub algorithm_name: String,
982        pub hyperparameters: std::collections::HashMap<String, ConfigValue>,
983        pub metadata: std::collections::HashMap<String, String>,
984    }
985
986    /// Values that can be stored in configuration
987    #[derive(Debug, Clone)]
988    pub enum ConfigValue {
989        Float(f64),
990        Integer(i64),
991        String(String),
992        Boolean(bool),
993        Array(Vec<ConfigValue>),
994        Object(std::collections::HashMap<String, ConfigValue>),
995    }
996
997    impl StructuredDestructure for AlgorithmConfig {
998        fn destructure_into_components(
999            &self,
1000        ) -> Result<std::collections::HashMap<String, Box<dyn std::any::Any>>> {
1001            let mut components = std::collections::HashMap::new();
1002
1003            components.insert(
1004                "algorithm_name".to_string(),
1005                Box::new(self.algorithm_name.clone()) as Box<dyn std::any::Any>,
1006            );
1007            components.insert(
1008                "hyperparameters".to_string(),
1009                Box::new(self.hyperparameters.clone()) as Box<dyn std::any::Any>,
1010            );
1011            components.insert(
1012                "metadata".to_string(),
1013                Box::new(self.metadata.clone()) as Box<dyn std::any::Any>,
1014            );
1015
1016            Ok(components)
1017        }
1018
1019        fn extract_field(&self, field_path: &str) -> Result<Box<dyn std::any::Any>> {
1020            let parts: Vec<&str> = field_path.split('.').collect();
1021
1022            match parts.first() {
1023                Some(&"algorithm_name") => Ok(Box::new(self.algorithm_name.clone())),
1024                Some(&"hyperparameters") => {
1025                    if parts.len() > 1 {
1026                        if let Some(param_value) = self.hyperparameters.get(parts[1]) {
1027                            Ok(Box::new(param_value.clone()))
1028                        } else {
1029                            Err(SklearsError::InvalidParameter {
1030                                name: field_path.to_string(),
1031                                reason: format!("Hyperparameter '{}' not found", parts[1]),
1032                            })
1033                        }
1034                    } else {
1035                        Ok(Box::new(self.hyperparameters.clone()))
1036                    }
1037                }
1038                Some(&"metadata") => {
1039                    if parts.len() > 1 {
1040                        if let Some(meta_value) = self.metadata.get(parts[1]) {
1041                            Ok(Box::new(meta_value.clone()))
1042                        } else {
1043                            Err(SklearsError::InvalidParameter {
1044                                name: field_path.to_string(),
1045                                reason: format!("Metadata '{}' not found", parts[1]),
1046                            })
1047                        }
1048                    } else {
1049                        Ok(Box::new(self.metadata.clone()))
1050                    }
1051                }
1052                _ => Err(SklearsError::InvalidParameter {
1053                    name: field_path.to_string(),
1054                    reason: "Invalid field path".to_string(),
1055                }),
1056            }
1057        }
1058
1059        fn validate_structure(&self, schema: &StructuralSchema) -> Result<ValidationResult> {
1060            let mut warnings = Vec::new();
1061            let mut context = std::collections::HashMap::new();
1062
1063            // Check required fields
1064            for required_field in &schema.required_fields {
1065                match required_field.as_str() {
1066                    "algorithm_name" => {
1067                        if self.algorithm_name.is_empty() {
1068                            return Err(SklearsError::InvalidParameter {
1069                                name: "algorithm_name".to_string(),
1070                                reason: "Required field cannot be empty".to_string(),
1071                            });
1072                        }
1073                        context.insert("algorithm_name".to_string(), "present".to_string());
1074                    }
1075                    "hyperparameters" => {
1076                        context.insert(
1077                            "hyperparameters_count".to_string(),
1078                            self.hyperparameters.len().to_string(),
1079                        );
1080                    }
1081                    _ => {
1082                        warnings.push(format!("Unknown required field: {required_field}"));
1083                    }
1084                }
1085            }
1086
1087            Ok(ValidationResult {
1088                matched: true,
1089                context,
1090                warnings,
1091            })
1092        }
1093    }
1094
1095    /// Pattern matching for complex validation scenarios
1096    pub fn create_complex_pattern_guard<T>(
1097        pattern_name: &str,
1098        validator: impl Fn(&T) -> Result<bool> + Send + Sync + 'static,
1099        error_message: &str,
1100    ) -> PatternGuardRule
1101    where
1102        T: 'static,
1103    {
1104        PatternGuardRule {
1105            pattern_name: pattern_name.to_string(),
1106            guard_fn: Box::new(move |value| {
1107                if let Some(typed_value) = value.downcast_ref::<T>() {
1108                    validator(typed_value)
1109                } else {
1110                    Ok(false)
1111                }
1112            }),
1113            error_message: error_message.to_string(),
1114            destructure_fn: None,
1115        }
1116    }
1117}
1118
/// Macro for easy destructuring of complex types
///
/// Forms:
/// - `destructure!(obj, "path")` calls `obj.extract_field("path")` and returns its `Result`.
/// - `destructure!(obj, { "a", "b", ... })` calls `extract_field` for each path and collects
///   the successful results into a `HashMap<String, _>` keyed by path. Failed lookups are
///   omitted. `obj` is evaluated exactly once, and a trailing comma is accepted.
/// - `destructure!(obj, validate: schema)` calls `obj.validate_structure(&schema)`.
#[macro_export]
macro_rules! destructure {
    // Basic field extraction
    ($obj:expr, $field:literal) => {
        $obj.extract_field($field)
    };

    // Multiple field extraction
    ($obj:expr, { $($field:literal),* $(,)? }) => {
        {
            // Bind the object once. Expanding `$obj` per field would re-run any side
            // effects or expensive construction in the expression.
            let obj = &$obj;
            let mut results = std::collections::HashMap::new();
            $(
                if let Ok(value) = obj.extract_field($field) {
                    results.insert($field.to_string(), value);
                }
            )*
            results
        }
    };

    // Destructuring with validation
    ($obj:expr, validate: $schema:expr) => {
        $obj.validate_structure(&$schema)
    };
}
1145
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validation_rules_numeric() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::Positive)
            .add_rule(ValidationRule::Finite);

        // A strictly positive, finite value satisfies both rules.
        assert!(rules.validate_numeric(&1.5f64).is_ok());

        // Zero and negatives break `Positive`; NaN and infinity break `Finite`.
        for rejected in [0.0f64, -1.0, f64::NAN, f64::INFINITY] {
            assert!(
                rules.validate_numeric(&rejected).is_err(),
                "{rejected} should be rejected"
            );
        }
    }

    #[test]
    fn test_validation_rules_range() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::Range { min: 0.0, max: 1.0 });

        // Both bounds are inclusive.
        for accepted in [0.5f64, 0.0, 1.0] {
            assert!(
                rules.validate_numeric(&accepted).is_ok(),
                "{accepted} should be accepted"
            );
        }

        for rejected in [-0.1f64, 1.1] {
            assert!(
                rules.validate_numeric(&rejected).is_err(),
                "{rejected} should be rejected"
            );
        }
    }

    #[test]
    fn test_validation_rules_string() {
        let rules = ValidationRules::new("test_param").add_rule(ValidationRule::OneOf(vec![
            "option1".to_string(),
            "option2".to_string(),
        ]));

        for accepted in ["option1", "option2"] {
            assert!(
                rules.validate_string(accepted).is_ok(),
                "{accepted:?} should be accepted"
            );
        }

        assert!(rules.validate_string("option3").is_err());
    }

    #[test]
    fn test_validation_rules_array() {
        let rules = ValidationRules::new("test_param")
            .add_rule(ValidationRule::MinLength(2))
            .add_rule(ValidationRule::MaxLength(5));

        // Both length limits are inclusive.
        assert!(rules.validate_array(&[1, 2]).is_ok());
        assert!(rules.validate_array(&[1, 2, 3, 4, 5]).is_ok());

        // One element short of the minimum, one element past the maximum.
        assert!(rules.validate_array(&[1]).is_err());
        assert!(rules.validate_array(&[1, 2, 3, 4, 5, 6]).is_err());
    }

    #[test]
    fn test_ml_validation_learning_rate() {
        for lr in [0.01f64, 0.1] {
            assert!(
                ml::validate_learning_rate(lr).is_ok(),
                "learning rate {lr} should be valid"
            );
        }

        // Non-positive and non-finite rates are rejected.
        for lr in [0.0f64, -0.1, f64::NAN] {
            assert!(
                ml::validate_learning_rate(lr).is_err(),
                "learning rate {lr} should be invalid"
            );
        }
    }

    #[test]
    fn test_ml_validation_n_clusters() {
        // Up to and including one cluster per sample is allowed.
        for (n_clusters, n_samples) in [(3, 10), (10, 10)] {
            assert!(
                ml::validate_n_clusters(n_clusters, n_samples).is_ok(),
                "{n_clusters} clusters over {n_samples} samples should be valid"
            );
        }

        // Zero clusters, or more clusters than samples, is rejected.
        for (n_clusters, n_samples) in [(0, 10), (15, 10)] {
            assert!(
                ml::validate_n_clusters(n_clusters, n_samples).is_err(),
                "{n_clusters} clusters over {n_samples} samples should be invalid"
            );
        }
    }

    #[test]
    fn test_ml_validation_probability() {
        // The closed interval [0, 1] is accepted.
        for p in [0.0f64, 0.5, 1.0] {
            assert!(
                ml::validate_probability(p).is_ok(),
                "probability {p} should be valid"
            );
        }

        // Values outside [0, 1] and NaN are rejected.
        for p in [-0.1f64, 1.1, f64::NAN] {
            assert!(
                ml::validate_probability(p).is_err(),
                "probability {p} should be invalid"
            );
        }
    }

    #[test]
    fn test_validation_context() {
        let context = ValidationContext::new("KMeans", "fit").with_data_info(100, 5, "float64");

        let error = SklearsError::InvalidParameter {
            name: "n_clusters".to_string(),
            reason: "must be positive".to_string(),
        };

        let formatted = context.format_error(&error);
        for fragment in ["KMeans", "fit", "100 samples", "5 features"] {
            assert!(
                formatted.contains(fragment),
                "{fragment:?} missing from {formatted:?}"
            );
        }
    }
}