scirs2_core/
validation.rs

1//! Validation utilities for ``SciRS2``
2//!
3//! This module provides utilities for validating data and parameters, including
4//! production-level security hardening and comprehensive input validation.
5
6use ndarray::{ArrayBase, Dimension, ScalarOperand};
7use num_traits::{Float, One, Zero};
8
9use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
10
11/// Checks if a value is within bounds (inclusive)
12///
13/// # Arguments
14///
15/// * `value` - The value to check
16/// * `min` - The minimum allowed value (inclusive)
17/// * `max` - The maximum allowed value (inclusive)
18/// * `name` - The name of the parameter being checked
19///
20/// # Returns
21///
22/// * `Ok(value)` if the value is within bounds
23/// * `Err(CoreError::ValueError)` if the value is out of bounds
24///
25/// # Errors
26///
27/// Returns `CoreError::ValueError` if the value is outside the specified bounds.
28pub fn check_in_bounds<T, S>(value: T, min: T, max: T, name: S) -> CoreResult<T>
29where
30    T: PartialOrd + std::fmt::Display + Copy,
31    S: Into<String>,
32{
33    if value < min || value > max {
34        return Err(CoreError::ValueError(
35            ErrorContext::new(format!(
36                "{} must be between {min} and {max}, got {value}",
37                name.into()
38            ))
39            .with_location(ErrorLocation::new(file!(), line!())),
40        ));
41    }
42    Ok(value)
43}
44
45/// Checks if a value is positive
46///
47/// # Arguments
48///
49/// * `value` - The value to check
50/// * `name` - The name of the parameter being checked
51///
52/// # Returns
53///
54/// * `Ok(value)` if the value is positive
55/// * `Err(CoreError::ValueError)` if the value is not positive
56///
57/// # Errors
58///
59/// Returns `CoreError::ValueError` if the value is not positive.
60pub fn check_positive<T, S>(value: T, name: S) -> CoreResult<T>
61where
62    T: PartialOrd + std::fmt::Display + Copy + Zero,
63    S: Into<String>,
64{
65    if value <= T::zero() {
66        return Err(CoreError::ValueError(
67            ErrorContext::new({
68                let name_str = name.into();
69                format!("{name_str} must be positive, got {value}")
70            })
71            .with_location(ErrorLocation::new(file!(), line!())),
72        ));
73    }
74    Ok(value)
75}
76
77/// Checks if a value is non-negative
78///
79/// # Arguments
80///
81/// * `value` - The value to check
82/// * `name` - The name of the parameter being checked
83///
84/// # Returns
85///
86/// * `Ok(value)` if the value is non-negative
87/// * `Err(CoreError::ValueError)` if the value is negative
88///
89/// # Errors
90///
91/// Returns `CoreError::ValueError` if the value is negative.
92pub fn check_non_negative<T, S>(value: T, name: S) -> CoreResult<T>
93where
94    T: PartialOrd + std::fmt::Display + Copy + Zero,
95    S: Into<String>,
96{
97    if value < T::zero() {
98        return Err(CoreError::ValueError(
99            ErrorContext::new({
100                let name_str = name.into();
101                format!("{name_str} must be non-negative, got {value}")
102            })
103            .with_location(ErrorLocation::new(file!(), line!())),
104        ));
105    }
106    Ok(value)
107}
108
109/// Checks if a floating-point value is finite
110///
111/// # Arguments
112///
113/// * `value` - The value to check
114/// * `name` - The name of the parameter being checked
115///
116/// # Returns
117///
118/// * `Ok(value)` if the value is finite
119/// * `Err(CoreError::ValueError)` if the value is not finite
120///
121/// # Errors
122///
123/// Returns `CoreError::ValueError` if the value is not finite.
124pub fn check_finite<T, S>(value: T, name: S) -> CoreResult<T>
125where
126    T: Float + std::fmt::Display + Copy,
127    S: Into<String>,
128{
129    if !value.is_finite() {
130        return Err(CoreError::ValueError(
131            ErrorContext::new({
132                let name_str = name.into();
133                format!("{name_str} must be finite, got {value}")
134            })
135            .with_location(ErrorLocation::new(file!(), line!())),
136        ));
137    }
138    Ok(value)
139}
140
141/// Checks if all values in an array are finite
142///
143/// # Arguments
144///
145/// * `array` - The array to check
146/// * `name` - The name of the array being checked
147///
148/// # Returns
149///
150/// * `Ok(())` if all values are finite
151/// * `Err(CoreError::ValueError)` if any value is not finite
152///
153/// # Errors
154///
155/// Returns `CoreError::ValueError` if any value in the array is not finite.
156pub fn checkarray_finite<S, A, D>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
157where
158    S: ndarray::Data,
159    D: Dimension,
160    S::Elem: Float + std::fmt::Display,
161    A: Into<String>,
162{
163    let name = name.into();
164    for (idx, &value) in array.indexed_iter() {
165        if !value.is_finite() {
166            return Err(CoreError::ValueError(
167                ErrorContext::new(format!(
168                    "{name} must contain only finite values, got {value} at {idx:?}"
169                ))
170                .with_location(ErrorLocation::new(file!(), line!())),
171            ));
172        }
173    }
174    Ok(())
175}
176
177/// Checks if an array has the expected shape
178///
179/// # Arguments
180///
181/// * `array` - The array to check
182/// * `expectedshape` - The expected shape
183/// * `name` - The name of the array being checked
184///
185/// # Returns
186///
187/// * `Ok(())` if the array has the expected shape
188/// * `Err(CoreError::ShapeError)` if the array does not have the expected shape
189///
190/// # Errors
191///
192/// Returns `CoreError::ShapeError` if the array does not have the expected shape.
193pub fn checkshape<S, D, A>(
194    array: &ArrayBase<S, D>,
195    expectedshape: &[usize],
196    name: A,
197) -> CoreResult<()>
198where
199    S: ndarray::Data,
200    D: Dimension,
201    A: Into<String>,
202{
203    let actualshape = array.shape();
204    if actualshape != expectedshape {
205        return Err(CoreError::ShapeError(
206            ErrorContext::new(format!(
207                "{} has incorrect shape: expected {expectedshape:?}, got {actualshape:?}",
208                name.into()
209            ))
210            .with_location(ErrorLocation::new(file!(), line!())),
211        ));
212    }
213    Ok(())
214}
215
216/// Checks if an array is 1D
217///
218/// # Arguments
219///
220/// * `array` - The array to check
221/// * `name` - The name of the array being checked
222///
223/// # Returns
224///
225/// * `Ok(())` if the array is 1D
226/// * `Err(CoreError::ShapeError)` if the array is not 1D
227///
228/// # Errors
229///
230/// Returns `CoreError::ShapeError` if the array is not 1D.
231pub fn check_1d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
232where
233    S: ndarray::Data,
234    D: Dimension,
235    A: Into<String>,
236{
237    if array.ndim() != 1 {
238        return Err(CoreError::ShapeError(
239            ErrorContext::new({
240                let name_str = name.into();
241                let ndim = array.ndim();
242                format!("{name_str} must be 1D, got {ndim}D")
243            })
244            .with_location(ErrorLocation::new(file!(), line!())),
245        ));
246    }
247    Ok(())
248}
249
250/// Checks if an array is 2D
251///
252/// # Arguments
253///
254/// * `array` - The array to check
255/// * `name` - The name of the array being checked
256///
257/// # Returns
258///
259/// * `Ok(())` if the array is 2D
260/// * `Err(CoreError::ShapeError)` if the array is not 2D
261///
262/// # Errors
263///
264/// Returns `CoreError::ShapeError` if the array is not 2D.
265pub fn check_2d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
266where
267    S: ndarray::Data,
268    D: Dimension,
269    A: Into<String>,
270{
271    if array.ndim() != 2 {
272        return Err(CoreError::ShapeError(
273            ErrorContext::new({
274                let name_str = name.into();
275                let ndim = array.ndim();
276                format!("{name_str} must be 2D, got {ndim}D")
277            })
278            .with_location(ErrorLocation::new(file!(), line!())),
279        ));
280    }
281    Ok(())
282}
283
284/// Checks if two arrays have the same shape
285///
286/// # Arguments
287///
288/// * `a` - The first array
289/// * `a_name` - The name of the first array
290/// * `b` - The second array
291/// * `b_name` - The name of the second array
292///
293/// # Returns
294///
295/// * `Ok(())` if the arrays have the same shape
296/// * `Err(CoreError::ShapeError)` if the arrays have different shapes
297///
298/// # Errors
299///
300/// Returns `CoreError::ShapeError` if the arrays have different shapes.
301pub fn check_sameshape<S1, S2, D1, D2, A, B>(
302    a: &ArrayBase<S1, D1>,
303    a_name: A,
304    b: &ArrayBase<S2, D2>,
305    b_name: B,
306) -> CoreResult<()>
307where
308    S1: ndarray::Data,
309    S2: ndarray::Data,
310    D1: Dimension,
311    D2: Dimension,
312    A: Into<String>,
313    B: Into<String>,
314{
315    let ashape = a.shape();
316    let bshape = b.shape();
317    if ashape != bshape {
318        return Err(CoreError::ShapeError(
319            ErrorContext::new(format!(
320                "{} and {} must have the same shape, got {:?} and {:?}",
321                a_name.into(),
322                b_name.into(),
323                ashape,
324                bshape
325            ))
326            .with_location(ErrorLocation::new(file!(), line!())),
327        ));
328    }
329    Ok(())
330}
331
332/// Checks if a matrix is square
333///
334/// # Arguments
335///
336/// * `matrix` - The matrix to check
337/// * `name` - The name of the matrix being checked
338///
339/// # Returns
340///
341/// * `Ok(())` if the matrix is square
342/// * `Err(CoreError::ShapeError)` if the matrix is not square
343///
344/// # Errors
345///
346/// Returns `CoreError::ShapeError` if the matrix is not square.
347pub fn check_square<S, D, A>(matrix: &ArrayBase<S, D>, name: A) -> CoreResult<()>
348where
349    S: ndarray::Data,
350    D: Dimension,
351    A: Into<String> + std::string::ToString,
352{
353    check_2d(matrix, name.to_string())?;
354    let shape = matrix.shape();
355    if shape[0] != shape[1] {
356        return Err(CoreError::ShapeError(
357            ErrorContext::new(format!(
358                "{} must be square, got shape {:?}",
359                name.into(),
360                shape
361            ))
362            .with_location(ErrorLocation::new(file!(), line!())),
363        ));
364    }
365    Ok(())
366}
367
368/// Checks if a probability value is valid (between 0 and 1, inclusive)
369///
370/// # Arguments
371///
372/// * `p` - The probability value to check
373/// * `name` - The name of the parameter being checked
374///
375/// # Returns
376///
377/// * `Ok(p)` if the probability is valid
378/// * `Err(CoreError::ValueError)` if the probability is not valid
379///
380/// # Errors
381///
382/// Returns `CoreError::ValueError` if the probability is not between 0 and 1.
383pub fn check_probability<T, S>(p: T, name: S) -> CoreResult<T>
384where
385    T: Float + std::fmt::Display + Copy,
386    S: Into<String>,
387{
388    if p < T::zero() || p > T::one() {
389        return Err(CoreError::ValueError(
390            ErrorContext::new(format!(
391                "{} must be between 0 and 1, got {}",
392                name.into(),
393                p
394            ))
395            .with_location(ErrorLocation::new(file!(), line!())),
396        ));
397    }
398    Ok(p)
399}
400
401/// Checks if an array contains only probabilities (between 0 and 1, inclusive)
402///
403/// # Arguments
404///
405/// * `probs` - The array of probabilities to check
406/// * `name` - The name of the array being checked
407///
408/// # Returns
409///
410/// * `Ok(())` if all values are valid probabilities
411/// * `Err(CoreError::ValueError)` if any value is not a valid probability
412///
413/// # Errors
414///
415/// Returns `CoreError::ValueError` if any value is not a valid probability.
416pub fn check_probabilities<S, D, A>(probs: &ArrayBase<S, D>, name: A) -> CoreResult<()>
417where
418    S: ndarray::Data,
419    D: Dimension,
420    S::Elem: Float + std::fmt::Display,
421    A: Into<String>,
422{
423    let name = name.into();
424    for (idx, &p) in probs.indexed_iter() {
425        if p < S::Elem::zero() || p > S::Elem::one() {
426            return Err(CoreError::ValueError(
427                ErrorContext::new(format!(
428                    "{name} must contain only values between 0 and 1, got {p} at {idx:?}"
429                ))
430                .with_location(ErrorLocation::new(file!(), line!())),
431            ));
432        }
433    }
434    Ok(())
435}
436
437/// Checks if probability values sum to 1
438///
439/// # Arguments
440///
441/// * `probs` - The array of probabilities to check
442/// * `name` - The name of the array being checked
443/// * `tol` - Tolerance for the sum (default: 1e-10)
444///
445/// # Returns
446///
447/// * `Ok(())` if the probabilities sum to 1 (within tolerance)
448/// * `Err(CoreError::ValueError)` if the sum is not 1 (within tolerance)
449pub fn check_probabilities_sum_to_one<S, D, A>(
450    probs: &ArrayBase<S, D>,
451    name: A,
452    tol: Option<S::Elem>,
453) -> CoreResult<()>
454where
455    S: ndarray::Data,
456    D: Dimension,
457    S::Elem: Float + std::fmt::Display + ScalarOperand,
458    A: Into<String> + std::string::ToString,
459{
460    let tol = tol.unwrap_or_else(|| {
461        let eps: f64 = 1e-10;
462        num_traits::cast(eps).unwrap_or_else(|| {
463            // Fallback to epsilon
464            S::Elem::epsilon()
465        })
466    });
467
468    check_probabilities(probs, name.to_string())?;
469
470    let sum = probs.sum();
471    let one = S::Elem::one();
472
473    if (sum - one).abs() > tol {
474        return Err(CoreError::ValueError(
475            ErrorContext::new({
476                let name_str = name.into();
477                format!("{name_str} must sum to 1, got sum = {sum}")
478            })
479            .with_location(ErrorLocation::new(file!(), line!())),
480        ));
481    }
482
483    Ok(())
484}
485
486/// Checks if an array is not empty
487///
488/// # Arguments
489///
490/// * `array` - The array to check
491/// * `name` - The name of the array being checked
492///
493/// # Returns
494///
495/// * `Ok(())` if the array is not empty
496/// * `Err(CoreError::ValueError)` if the array is empty
497pub fn check_not_empty<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
498where
499    S: ndarray::Data,
500    D: Dimension,
501    A: Into<String>,
502{
503    if array.is_empty() {
504        return Err(CoreError::ValueError(
505            ErrorContext::new({
506                let name_str = name.into();
507                format!("{name_str} cannot be empty")
508            })
509            .with_location(ErrorLocation::new(file!(), line!())),
510        ));
511    }
512    Ok(())
513}
514
515/// Checks if an array has at least the minimum number of samples
516///
517/// # Arguments
518///
519/// * `array` - The array to check
520/// * `min_samples` - The minimum required number of samples
521/// * `name` - The name of the array being checked
522///
523/// # Returns
524///
525/// * `Ok(())` if the array has sufficient samples
526/// * `Err(CoreError::ValueError)` if the array has too few samples
527pub fn check_min_samples<S, D, A>(
528    array: &ArrayBase<S, D>,
529    min_samples: usize,
530    name: A,
531) -> CoreResult<()>
532where
533    S: ndarray::Data,
534    D: Dimension,
535    A: Into<String>,
536{
537    let n_samples = array.shape()[0];
538    if n_samples < min_samples {
539        return Err(CoreError::ValueError(
540            ErrorContext::new(format!(
541                "{} must have at least {} samples, got {}",
542                name.into(),
543                min_samples,
544                n_samples
545            ))
546            .with_location(ErrorLocation::new(file!(), line!())),
547        ));
548    }
549    Ok(())
550}
551
552/// Clustering-specific validation utilities
553pub mod clustering {
554    use super::*;
555
556    /// Validate number of clusters relative to data size
557    ///
558    /// # Arguments
559    ///
560    /// * `data` - Input data array
561    /// * `n_clusters` - Number of clusters
562    /// * `operation` - Name of the operation for error messages
563    ///
564    /// # Returns
565    ///
566    /// * `Ok(())` if n_clusters is valid
567    /// * `Err(CoreError::ValueError)` if n_clusters is invalid
568    pub fn check_n_clusters_bounds<S, D>(
569        data: &ArrayBase<S, D>,
570        n_clusters: usize,
571        operation: &str,
572    ) -> CoreResult<()>
573    where
574        S: ndarray::Data,
575        D: Dimension,
576    {
577        let n_samples = data.shape()[0];
578
579        if n_clusters == 0 {
580            return Err(CoreError::ValueError(
581                ErrorContext::new(format!(
582                    "{operation}: number of _clusters must be > 0, got {n_clusters}"
583                ))
584                .with_location(ErrorLocation::new(file!(), line!())),
585            ));
586        }
587
588        if n_clusters > n_samples {
589            return Err(CoreError::ValueError(
590                ErrorContext::new(format!(
591                    "{operation}: number of _clusters ({n_clusters}) cannot exceed number of samples ({n_samples})"
592                ))
593                .with_location(ErrorLocation::new(file!(), line!())),
594            ));
595        }
596
597        Ok(())
598    }
599
600    /// Comprehensive data validation for clustering algorithms
601    ///
602    /// # Arguments
603    ///
604    /// * `data` - Input data array
605    /// * `operation` - Name of the operation for error messages
606    /// * `check_finite` - Whether to check for finite values
607    /// * `min_samples` - Optional minimum number of samples required
608    ///
609    /// # Returns
610    ///
611    /// * `Ok(())` if data is valid
612    /// * `Err(CoreError)` if data validation fails
613    pub fn validate_clustering_data<S, D>(
614        data: &ArrayBase<S, D>,
615        _operation: &str,
616        check_finite: bool,
617        min_samples: Option<usize>,
618    ) -> CoreResult<()>
619    where
620        S: ndarray::Data,
621        D: Dimension,
622        S::Elem: Float + std::fmt::Display,
623    {
624        // Check not empty
625        check_not_empty(data, "data")?;
626
627        // Check 2D for most clustering algorithms
628        check_2d(data, "data")?;
629
630        // Check minimum _samples if specified
631        if let Some(min) = min_samples {
632            check_min_samples(data, min, "data")?;
633        }
634
635        // Check _finite if requested
636        if check_finite {
637            checkarray_finite(data, "data")?;
638        }
639
640        Ok(())
641    }
642}
643
644/// Parameter validation utilities
645pub mod parameters {
646    use super::*;
647
648    /// Validate algorithm iteration parameters
649    ///
650    /// # Arguments
651    ///
652    /// * `max_iter` - Maximum number of iterations
653    /// * `tolerance` - Convergence tolerance
654    /// * `operation` - Name of the operation for error messages
655    ///
656    /// # Returns
657    ///
658    /// * `Ok(())` if parameters are valid
659    /// * `Err(CoreError::ValueError)` if parameters are invalid
660    pub fn check_iteration_params<T>(
661        max_iter: usize,
662        tolerance: T,
663        operation: &str,
664    ) -> CoreResult<()>
665    where
666        T: Float + std::fmt::Display + Copy,
667    {
668        if max_iter == 0 {
669            return Err(CoreError::ValueError(
670                ErrorContext::new(format!("{operation}: max_iter must be > 0, got {max_iter}"))
671                    .with_location(ErrorLocation::new(file!(), line!())),
672            ));
673        }
674
675        check_positive(tolerance, format!("{operation} tolerance"))?;
676
677        Ok(())
678    }
679
680    /// Validate probability-like parameters (0 <= p <= 1)
681    ///
682    /// # Arguments
683    ///
684    /// * `value` - Value to check
685    /// * `name` - Parameter name for error messages
686    /// * `operation` - Operation name for error messages
687    ///
688    /// # Returns
689    ///
690    /// * `Ok(value)` if value is in [0, 1]
691    /// * `Err(CoreError::ValueError)` if value is out of range
692    pub fn check_unit_interval<T>(value: T, name: &str, operation: &str) -> CoreResult<T>
693    where
694        T: Float + std::fmt::Display + Copy,
695    {
696        if value < T::zero() || value > T::one() {
697            return Err(CoreError::ValueError(
698                ErrorContext::new(format!(
699                    "{operation}: {name} must be in [0, 1], got {value}"
700                ))
701                .with_location(ErrorLocation::new(file!(), line!())),
702            ));
703        }
704        Ok(value)
705    }
706
707    /// Validate bandwidth parameter for density-based clustering
708    ///
709    /// # Arguments
710    ///
711    /// * `bandwidth` - Bandwidth value
712    /// * `operation` - Operation name for error messages
713    ///
714    /// # Returns
715    ///
716    /// * `Ok(bandwidth)` if bandwidth is valid
717    /// * `Err(CoreError::ValueError)` if bandwidth is invalid
718    pub fn checkbandwidth<T>(bandwidth: T, operation: &str) -> CoreResult<T>
719    where
720        T: Float + std::fmt::Display + Copy,
721    {
722        check_positive(bandwidth, format!("{operation} bandwidth"))
723    }
724}
725
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2};

    #[test]
    fn test_check_in_bounds() {
        // Both endpoints are inclusive.
        for inside in [0, 5, 10] {
            assert!(check_in_bounds(inside, 0, 10, "param").is_ok());
        }
        for outside in [-1, 11] {
            assert!(check_in_bounds(outside, 0, 10, "param").is_err());
        }
    }

    #[test]
    fn test_check_positive() {
        assert!(check_positive(5, "param").is_ok());
        assert!(check_positive(0.1, "param").is_ok());
        // Zero itself is not positive.
        for rejected in [0, -1] {
            assert!(check_positive(rejected, "param").is_err());
        }
    }

    #[test]
    fn test_check_non_negative() {
        // Zero is accepted.
        for accepted in [5, 0] {
            assert!(check_non_negative(accepted, "param").is_ok());
        }
        assert!(check_non_negative(-0.1, "param").is_err());
        assert!(check_non_negative(-1, "param").is_err());
    }

    #[test]
    fn test_check_finite() {
        for good in [5.0, 0.0, -1.0] {
            assert!(check_finite(good, "param").is_ok());
        }
        for bad in [f64::INFINITY, f64::NEG_INFINITY, f64::NAN] {
            assert!(check_finite(bad, "param").is_err());
        }
    }

    #[test]
    fn test_checkarray_finite() {
        let clean = arr1(&[1.0, 2.0, 3.0]);
        assert!(checkarray_finite(&clean, "array").is_ok());

        for poison in [f64::INFINITY, f64::NAN] {
            let tainted = arr1(&[1.0, poison, 3.0]);
            assert!(checkarray_finite(&tainted, "array").is_err());
        }
    }

    #[test]
    fn test_checkshape() {
        let square = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(checkshape(&square, &[2, 2], "array").is_ok());
        assert!(checkshape(&square, &[2, 3], "array").is_err());
    }

    #[test]
    fn test_check_1d() {
        let vector = arr1(&[1.0, 2.0, 3.0]);
        let matrix = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(check_1d(&vector, "array").is_ok());
        assert!(check_1d(&matrix, "array").is_err());
    }

    #[test]
    fn test_check_2d() {
        let matrix = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let vector = arr1(&[1.0, 2.0, 3.0]);
        assert!(check_2d(&matrix, "array").is_ok());
        assert!(check_2d(&vector, "array").is_err());
    }

    #[test]
    fn test_check_sameshape() {
        let base = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let twin = arr2(&[[5.0, 6.0], [7.0, 8.0]]);
        let wider = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert!(check_sameshape(&base, "a", &twin, "b").is_ok());
        assert!(check_sameshape(&base, "a", &wider, "c").is_err());
    }

    #[test]
    fn test_check_square() {
        let square = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(check_square(&square, "matrix").is_ok());

        // Both a non-square 2D array and a non-2D array are rejected.
        let rectangle = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert!(check_square(&rectangle, "matrix").is_err());
        let vector = arr1(&[1.0, 2.0, 3.0]);
        assert!(check_square(&vector, "matrix").is_err());
    }

    #[test]
    fn test_check_probability() {
        for p in [0.0, 0.5, 1.0] {
            assert!(check_probability(p, "p").is_ok());
        }
        for p in [-0.1, 1.1] {
            assert!(check_probability(p, "p").is_err());
        }
    }

    #[test]
    fn test_check_probabilities() {
        let valid = arr1(&[0.0, 0.5, 1.0]);
        assert!(check_probabilities(&valid, "probs").is_ok());

        let invalid = [arr1(&[0.0, 0.5, 1.1]), arr1(&[-0.1, 0.5, 1.0])];
        for probs in &invalid {
            assert!(check_probabilities(probs, "probs").is_err());
        }
    }

    #[test]
    fn test_check_probabilities_sum_to_one() {
        let balanced = arr1(&[0.3, 0.2, 0.5]);
        assert!(check_probabilities_sum_to_one(&balanced, "probs", None).is_ok());

        let excess = arr1(&[0.3, 0.2, 0.6]);
        assert!(check_probabilities_sum_to_one(&excess, "probs", None).is_err());

        // Slightly above 1: accepted by a loose tolerance, rejected by a tight one.
        let nearly = arr1(&[0.3, 0.2, 0.501]);
        assert!(check_probabilities_sum_to_one(&nearly, "probs", Some(0.01)).is_ok());
        assert!(check_probabilities_sum_to_one(&nearly, "probs", Some(0.0001)).is_err());
    }

    #[test]
    fn test_check_not_empty() {
        let filled = arr1(&[1.0, 2.0, 3.0]);
        assert!(check_not_empty(&filled, "array").is_ok());

        let hollow = arr1(&[] as &[f64]);
        assert!(check_not_empty(&hollow, "array").is_err());
    }

    #[test]
    fn test_check_min_samples() {
        let three_rows = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        for needed in [2, 3] {
            assert!(check_min_samples(&three_rows, needed, "array").is_ok());
        }
        assert!(check_min_samples(&three_rows, 4, "array").is_err());
    }

    mod clustering_tests {
        use super::*;
        use crate::validation::clustering::*;

        #[test]
        fn test_check_n_clusters_bounds() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
            // Any cluster count from 1 up to the number of rows is allowed.
            for k in 1..=3 {
                assert!(check_n_clusters_bounds(&data, k, "test").is_ok());
            }
            for k in [0, 4] {
                assert!(check_n_clusters_bounds(&data, k, "test").is_err());
            }
        }

        #[test]
        fn test_validate_clustering_data() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
            assert!(validate_clustering_data(&data, "test", true, Some(2)).is_ok());
            assert!(validate_clustering_data(&data, "test", true, Some(4)).is_err());

            let empty_data = arr2(&[] as &[[f64; 2]; 0]);
            assert!(validate_clustering_data(&empty_data, "test", true, None).is_err());

            // Non-finite entries fail only when the finiteness check is enabled.
            let inf_data = arr2(&[[1.0, f64::INFINITY], [3.0, 4.0]]);
            assert!(validate_clustering_data(&inf_data, "test", true, None).is_err());
            assert!(validate_clustering_data(&inf_data, "test", false, None).is_ok());
        }
    }

    mod parameters_tests {
        use crate::validation::parameters::*;

        #[test]
        fn test_check_iteration_params() {
            assert!(check_iteration_params(100, 1e-6, "test").is_ok());
            assert!(check_iteration_params(0, 1e-6, "test").is_err());
            for tolerance in [0.0, -1e-6] {
                assert!(check_iteration_params(100, tolerance, "test").is_err());
            }
        }

        #[test]
        fn test_check_unit_interval() {
            for value in [0.0, 0.5, 1.0] {
                assert!(check_unit_interval(value, "param", "test").is_ok());
            }
            for value in [-0.1, 1.1] {
                assert!(check_unit_interval(value, "param", "test").is_err());
            }
        }

        #[test]
        fn test_checkbandwidth() {
            for bandwidth in [1.0, 0.1] {
                assert!(checkbandwidth(bandwidth, "test").is_ok());
            }
            for bandwidth in [0.0, -1.0] {
                assert!(checkbandwidth(bandwidth, "test").is_err());
            }
        }
    }
}
936
937/// Custom validator implementations for flexible validation logic
938pub mod custom {
939    use super::*;
940    use std::fmt;
941    use std::marker::PhantomData;
942
943    /// Trait for implementing custom validators
944    pub trait Validator<T> {
945        /// Validate the value and return Ok(()) if valid, or an error if invalid
946        fn validate(&self, value: &T, name: &str) -> CoreResult<()>;
947
948        /// Get a description of what this validator checks
949        fn description(&self) -> String;
950
951        /// Chain this validator with another validator
952        fn and<V: Validator<T>>(self, other: V) -> CompositeValidator<T, Self, V>
953        where
954            Self: Sized,
955        {
956            CompositeValidator::new(self, other)
957        }
958
959        /// Create a conditional validator that only applies when a condition is met
960        fn when<F>(self, condition: F) -> ConditionalValidator<T, Self, F>
961        where
962            Self: Sized,
963            F: Fn(&T) -> bool,
964        {
965            ConditionalValidator::new(self, condition)
966        }
967    }
968
969    /// A validator that combines two validators with AND logic
970    pub struct CompositeValidator<T, V1, V2> {
971        validator1: V1,
972        validator2: V2,
973        _phantom: PhantomData<T>,
974    }
975
976    impl<T, V1, V2> CompositeValidator<T, V1, V2> {
977        pub fn new(validator1: V1, validator2: V2) -> Self {
978            Self {
979                validator1,
980                validator2,
981                _phantom: PhantomData,
982            }
983        }
984    }
985
986    impl<T, V1, V2> Validator<T> for CompositeValidator<T, V1, V2>
987    where
988        V1: Validator<T>,
989        V2: Validator<T>,
990    {
991        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
992            self.validator1.validate(value, name)?;
993            self.validator2.validate(value, name)?;
994            Ok(())
995        }
996
997        fn description(&self) -> String {
998            format!(
999                "{} AND {}",
1000                self.validator1.description(),
1001                self.validator2.description()
1002            )
1003        }
1004    }
1005
1006    /// A validator that only applies when a condition is met
1007    pub struct ConditionalValidator<T, V, F> {
1008        validator: V,
1009        condition: F,
1010        phantom: PhantomData<T>,
1011    }
1012
1013    impl<T, V, F> ConditionalValidator<T, V, F> {
1014        pub fn new(validator: V, condition: F) -> Self {
1015            Self {
1016                validator,
1017                condition,
1018                phantom: PhantomData,
1019            }
1020        }
1021    }
1022
1023    impl<T, V, F> Validator<T> for ConditionalValidator<T, V, F>
1024    where
1025        V: Validator<T>,
1026        F: Fn(&T) -> bool,
1027    {
1028        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1029            if (self.condition)(value) {
1030                self.validator.validate(value, name)
1031            } else {
1032                Ok(())
1033            }
1034        }
1035
1036        fn description(&self) -> String {
1037            {
1038                let desc = self.validator.description();
1039                format!("IF condition THEN {desc}")
1040            }
1041        }
1042    }
1043
1044    /// Custom validator for ranges with inclusive/exclusive bounds
1045    pub struct RangeValidator<T> {
1046        min: Option<T>,
1047        max: Option<T>,
1048        min_inclusive: bool,
1049        max_inclusive: bool,
1050    }
1051
1052    impl<T> RangeValidator<T>
1053    where
1054        T: PartialOrd + Copy + fmt::Display,
1055    {
1056        pub fn new() -> Self {
1057            Self {
1058                min: None,
1059                max: None,
1060                min_inclusive: true,
1061                max_inclusive: true,
1062            }
1063        }
1064
1065        pub fn min(mut self, min: T) -> Self {
1066            self.min = Some(min);
1067            self
1068        }
1069
1070        pub fn max(mut self, max: T) -> Self {
1071            self.max = Some(max);
1072            self
1073        }
1074
1075        pub fn min_exclusive(mut self, min: T) -> Self {
1076            self.min = Some(min);
1077            self.min_inclusive = false;
1078            self
1079        }
1080
1081        pub fn max_exclusive(mut self, max: T) -> Self {
1082            self.max = Some(max);
1083            self.max_inclusive = false;
1084            self
1085        }
1086
1087        pub fn in_range(min: T, max: T) -> Self {
1088            Self::new().min(min).max(max)
1089        }
1090
1091        pub fn in_range_exclusive(min: T, max: T) -> Self {
1092            Self::new().min_exclusive(min).max_exclusive(max)
1093        }
1094    }
1095
1096    impl<T> Default for RangeValidator<T>
1097    where
1098        T: PartialOrd + Copy + fmt::Display,
1099    {
1100        fn default() -> Self {
1101            Self::new()
1102        }
1103    }
1104
1105    impl<T> Validator<T> for RangeValidator<T>
1106    where
1107        T: PartialOrd + Copy + fmt::Display,
1108    {
1109        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1110            if let Some(min) = self.min {
1111                let valid = if self.min_inclusive {
1112                    *value >= min
1113                } else {
1114                    *value > min
1115                };
1116                if !valid {
1117                    let op = if self.min_inclusive { ">=" } else { ">" };
1118                    return Err(CoreError::ValueError(
1119                        ErrorContext::new(format!("{name} must be {op} {min}, got {value}"))
1120                            .with_location(ErrorLocation::new(file!(), line!())),
1121                    ));
1122                }
1123            }
1124
1125            if let Some(max) = self.max {
1126                let valid = if self.max_inclusive {
1127                    *value <= max
1128                } else {
1129                    *value < max
1130                };
1131                if !valid {
1132                    let op = if self.max_inclusive { "<=" } else { "<" };
1133                    return Err(CoreError::ValueError(
1134                        ErrorContext::new(format!("{name} must be {op} {max}, got {value}"))
1135                            .with_location(ErrorLocation::new(file!(), line!())),
1136                    ));
1137                }
1138            }
1139
1140            Ok(())
1141        }
1142
1143        fn description(&self) -> String {
1144            match (self.min, self.max) {
1145                (Some(min), Some(max)) => {
1146                    let min_op = if self.min_inclusive { ">=" } else { ">" };
1147                    let max_op = if self.max_inclusive { "<=" } else { "<" };
1148                    format!("value {min_op} {min} and {max_op} {max}")
1149                }
1150                (Some(min), None) => {
1151                    let op = if self.min_inclusive { ">=" } else { ">" };
1152                    format!("value {op} {min}")
1153                }
1154                (None, Some(max)) => {
1155                    let op = if self.max_inclusive { "<=" } else { "<" };
1156                    format!("value {op} {max}")
1157                }
1158                (None, None) => "no range constraints".to_string(),
1159            }
1160        }
1161    }
1162
1163    /// Type alias for shape validation function to reduce complexity
1164    type ShapeValidatorFn = Box<dyn Fn(&[usize]) -> CoreResult<()>>;
1165
1166    /// Custom validator for array properties
1167    pub struct ArrayValidator<T, D>
1168    where
1169        D: Dimension,
1170    {
1171        shape_validator: Option<ShapeValidatorFn>,
1172        element_validator: Option<Box<dyn Validator<T>>>,
1173        size_validator: Option<RangeValidator<usize>>,
1174        phantom: PhantomData<D>,
1175    }
1176
1177    impl<T, D> ArrayValidator<T, D>
1178    where
1179        D: Dimension,
1180    {
1181        pub fn new() -> Self {
1182            Self {
1183                shape_validator: None,
1184                element_validator: None,
1185                size_validator: None,
1186                phantom: PhantomData,
1187            }
1188        }
1189
1190        pub fn withshape<F>(mut self, validator: F) -> Self
1191        where
1192            F: Fn(&[usize]) -> CoreResult<()> + 'static,
1193        {
1194            self.shape_validator = Some(Box::new(validator));
1195            self
1196        }
1197
1198        pub fn with_elements<V>(mut self, validator: V) -> Self
1199        where
1200            V: Validator<T> + 'static,
1201        {
1202            self.element_validator = Some(Box::new(validator));
1203            self
1204        }
1205
1206        pub fn with_size(mut self, validator: RangeValidator<usize>) -> Self {
1207            self.size_validator = Some(validator);
1208            self
1209        }
1210
1211        pub fn minsize(self, minsize: usize) -> Self {
1212            self.with_size(RangeValidator::new().min(minsize))
1213        }
1214
1215        pub fn maxsize(self, maxsize: usize) -> Self {
1216            self.with_size(RangeValidator::new().max(maxsize))
1217        }
1218
1219        pub fn exact_size(self, size: usize) -> Self {
1220            self.with_size(RangeValidator::new().min(size).max(size))
1221        }
1222    }
1223
1224    impl<T, D> Default for ArrayValidator<T, D>
1225    where
1226        D: Dimension,
1227    {
1228        fn default() -> Self {
1229            Self::new()
1230        }
1231    }
1232
1233    impl<S, T, D> Validator<ArrayBase<S, D>> for ArrayValidator<T, D>
1234    where
1235        S: ndarray::Data<Elem = T>,
1236        T: Clone,
1237        D: Dimension,
1238    {
1239        fn validate(&self, array: &ArrayBase<S, D>, name: &str) -> CoreResult<()> {
1240            // Validate shape
1241            if let Some(ref shape_validator) = self.shape_validator {
1242                shape_validator(array.shape())?;
1243            }
1244
1245            // Validate size
1246            if let Some(ref size_validator) = self.size_validator {
1247                size_validator.validate(&array.len(), &format!("{name} size"))?;
1248            }
1249
1250            // Validate elements
1251            if let Some(ref element_validator) = self.element_validator {
1252                for (idx, element) in array.indexed_iter() {
1253                    element_validator.validate(element, &format!("{name} element at {idx:?}"))?;
1254                }
1255            }
1256
1257            Ok(())
1258        }
1259
1260        fn description(&self) -> String {
1261            let mut parts = Vec::new();
1262
1263            if self.shape_validator.is_some() {
1264                parts.push("shape validation".to_string());
1265            }
1266
1267            if let Some(ref size_validator) = self.size_validator {
1268                {
1269                    let desc = size_validator.description();
1270                    parts.push(format!("size {desc}"));
1271                }
1272            }
1273
1274            if let Some(ref element_validator) = self.element_validator {
1275                {
1276                    let desc = element_validator.description();
1277                    parts.push(format!("elements {desc}"));
1278                }
1279            }
1280
1281            if parts.is_empty() {
1282                "no array constraints".to_string()
1283            } else {
1284                parts.join(" AND ")
1285            }
1286        }
1287    }
1288
1289    /// Custom validator for function-based validation
1290    pub struct FunctionValidator<T, F> {
1291        func: F,
1292        description: String,
1293        phantom: PhantomData<T>,
1294    }
1295
1296    impl<T, F> FunctionValidator<T, F>
1297    where
1298        F: Fn(&T, &str) -> CoreResult<()>,
1299    {
1300        pub fn new(func: F, description: impl Into<String>) -> Self {
1301            Self {
1302                func,
1303                description: description.into(),
1304                phantom: PhantomData,
1305            }
1306        }
1307    }
1308
1309    impl<T, F> Validator<T> for FunctionValidator<T, F>
1310    where
1311        F: Fn(&T, &str) -> CoreResult<()>,
1312    {
1313        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1314            (self.func)(value, name)
1315        }
1316
1317        fn description(&self) -> String {
1318            self.description.clone()
1319        }
1320    }
1321
1322    /// Builder for creating complex validators
1323    pub struct ValidatorBuilder<T> {
1324        validators: Vec<Box<dyn Validator<T>>>,
1325    }
1326
1327    impl<T: 'static> ValidatorBuilder<T> {
1328        pub fn new() -> Self {
1329            Self {
1330                validators: Vec::new(),
1331            }
1332        }
1333
1334        pub fn with_validator<V: Validator<T> + 'static>(mut self, validator: V) -> Self {
1335            self.validators.push(Box::new(validator));
1336            self
1337        }
1338
1339        pub fn with_function<F>(self, func: F, description: impl Into<String>) -> Self
1340        where
1341            F: Fn(&T, &str) -> CoreResult<()> + 'static,
1342        {
1343            self.with_validator(FunctionValidator::new(func, description))
1344        }
1345
1346        pub fn build(self) -> MultiValidator<T> {
1347            MultiValidator {
1348                validators: self.validators,
1349            }
1350        }
1351    }
1352
1353    impl<T: 'static> Default for ValidatorBuilder<T> {
1354        fn default() -> Self {
1355            Self::new()
1356        }
1357    }
1358
1359    /// Validator that runs multiple validators
1360    pub struct MultiValidator<T> {
1361        validators: Vec<Box<dyn Validator<T>>>,
1362    }
1363
1364    impl<T: 'static> Validator<T> for MultiValidator<T> {
1365        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1366            for validator in &self.validators {
1367                validator.validate(value, name)?;
1368            }
1369            Ok(())
1370        }
1371
1372        fn description(&self) -> String {
1373            if self.validators.is_empty() {
1374                "no validators".to_string()
1375            } else {
1376                self.validators
1377                    .iter()
1378                    .map(|v| v.description())
1379                    .collect::<Vec<_>>()
1380                    .join(" AND ")
1381            }
1382        }
1383    }
1384
1385    /// Convenience function to validate with a custom validator
1386    pub fn validate_with<T, V: Validator<T>>(
1387        value: &T,
1388        validator: &V,
1389        name: impl Into<String>,
1390    ) -> CoreResult<()> {
1391        validator.validate(value, &name.into())
1392    }
1393
1394    #[cfg(test)]
1395    mod tests {
1396        use super::*;
1397        use ndarray::arr1;
1398
1399        #[test]
1400        fn test_range_validator() {
1401            let validator = RangeValidator::in_range(0.0, 1.0);
1402
1403            assert!(validator.validate(&0.5, "value").is_ok());
1404            assert!(validator.validate(&0.0, "value").is_ok());
1405            assert!(validator.validate(&1.0, "value").is_ok());
1406            assert!(validator.validate(&-0.1, "value").is_err());
1407            assert!(validator.validate(&1.1, "value").is_err());
1408        }
1409
1410        #[test]
1411        fn test_range_validator_exclusive() {
1412            let validator = RangeValidator::in_range_exclusive(0.0, 1.0);
1413
1414            assert!(validator.validate(&0.5, "value").is_ok());
1415            assert!(validator.validate(&0.0, "value").is_err());
1416            assert!(validator.validate(&1.0, "value").is_err());
1417        }
1418
1419        #[test]
1420        fn test_composite_validator() {
1421            let positive = RangeValidator::new().min(0.0);
1422            let max_one = RangeValidator::new().max(1.0);
1423            let validator = positive.and(max_one);
1424
1425            assert!(validator.validate(&0.5, "value").is_ok());
1426            assert!(validator.validate(&-0.1, "value").is_err());
1427            assert!(validator.validate(&1.1, "value").is_err());
1428        }
1429
1430        #[test]
1431        fn test_conditional_validator() {
1432            let validator = RangeValidator::new().min(0.0).when(|x: &f64| *x > 0.0);
1433
1434            assert!(validator.validate(&0.5, "value").is_ok());
1435            assert!(validator.validate(&-0.5, "value").is_ok()); // Condition not met
1436            assert!(validator.validate(&0.0, "value").is_ok()); // Condition not met
1437        }
1438
1439        #[test]
1440        fn testarray_validator() {
1441            let element_validator = RangeValidator::in_range(0.0, 1.0);
1442            let array_validator = ArrayValidator::new()
1443                .with_elements(element_validator)
1444                .minsize(2);
1445
1446            let validarray = arr1(&[0.2, 0.8]);
1447            assert!(array_validator.validate(&validarray, "array").is_ok());
1448
1449            let invalidarray = arr1(&[0.2, 1.5]);
1450            assert!(array_validator.validate(&invalidarray, "array").is_err());
1451
1452            let too_smallarray = arr1(&[0.5]);
1453            assert!(array_validator.validate(&too_smallarray, "array").is_err());
1454        }
1455
1456        #[test]
1457        fn test_function_validator() {
1458            let validator = FunctionValidator::new(
1459                |value: &i32, name: &str| {
1460                    if *value % 2 == 0 {
1461                        Ok(())
1462                    } else {
1463                        Err(CoreError::ValueError(
1464                            ErrorContext::new(format!("{name} must be even, got {value}"))
1465                                .with_location(ErrorLocation::new(file!(), line!())),
1466                        ))
1467                    }
1468                },
1469                "value must be even",
1470            );
1471
1472            assert!(validator.validate(&4, "number").is_ok());
1473            assert!(validator.validate(&3, "number").is_err());
1474        }
1475
1476        #[test]
1477        fn test_validator_builder() {
1478            let validator = ValidatorBuilder::new()
1479                .with_validator(RangeValidator::new().min(0.0))
1480                .with_validator(RangeValidator::new().max(1.0))
1481                .with_function(
1482                    |value: &f64, name: &str| {
1483                        if *value != 0.5 {
1484                            Ok(())
1485                        } else {
1486                            Err(CoreError::ValueError(
1487                                ErrorContext::new(format!("{name} cannot be 0.5"))
1488                                    .with_location(ErrorLocation::new(file!(), line!())),
1489                            ))
1490                        }
1491                    },
1492                    "value cannot be 0.5",
1493                )
1494                .build();
1495
1496            assert!(validator.validate(&0.3, "value").is_ok());
1497            assert!(validator.validate(&0.5, "value").is_err());
1498            assert!(validator.validate(&-0.1, "value").is_err());
1499            assert!(validator.validate(&1.1, "value").is_err());
1500        }
1501    }
1502}
1503
/// Production-level validation with comprehensive security and performance features
pub mod production;

/// Cross-platform validation utilities for consistent behavior across operating systems and architectures
pub mod cross_platform;

/// Comprehensive data validation system with schema validation and constraint enforcement
///
/// Only compiled when the `data_validation` Cargo feature is enabled.
#[cfg(feature = "data_validation")]
pub mod data;