scirs2_core/
validation.rs

//! Validation utilities for `SciRS2`.
//!
//! This module provides utilities for validating data and parameters, including
//! production-level security hardening and comprehensive input validation.
5
6use ::ndarray::{ArrayBase, Dimension, ScalarOperand};
7use num_traits::{Float, One, Zero};
8
9use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
10
11/// Checks if a value is within bounds (inclusive)
12///
13/// # Arguments
14///
15/// * `value` - The value to check
16/// * `min` - The minimum allowed value (inclusive)
17/// * `max` - The maximum allowed value (inclusive)
18/// * `name` - The name of the parameter being checked
19///
20/// # Returns
21///
22/// * `Ok(value)` if the value is within bounds
23/// * `Err(CoreError::ValueError)` if the value is out of bounds
24///
25/// # Errors
26///
27/// Returns `CoreError::ValueError` if the value is outside the specified bounds.
28pub fn check_in_bounds<T, S>(value: T, min: T, max: T, name: S) -> CoreResult<T>
29where
30    T: PartialOrd + std::fmt::Display + Copy,
31    S: Into<String>,
32{
33    if value < min || value > max {
34        return Err(CoreError::ValueError(
35            ErrorContext::new(format!(
36                "{} must be between {min} and {max}, got {value}",
37                name.into()
38            ))
39            .with_location(ErrorLocation::new(file!(), line!())),
40        ));
41    }
42    Ok(value)
43}
44
45/// Checks if a value is positive
46///
47/// # Arguments
48///
49/// * `value` - The value to check
50/// * `name` - The name of the parameter being checked
51///
52/// # Returns
53///
54/// * `Ok(value)` if the value is positive
55/// * `Err(CoreError::ValueError)` if the value is not positive
56///
57/// # Errors
58///
59/// Returns `CoreError::ValueError` if the value is not positive.
60pub fn check_positive<T, S>(value: T, name: S) -> CoreResult<T>
61where
62    T: PartialOrd + std::fmt::Display + Copy + Zero,
63    S: Into<String>,
64{
65    if value <= T::zero() {
66        return Err(CoreError::ValueError(
67            ErrorContext::new({
68                let name_str = name.into();
69                format!("{name_str} must be positive, got {value}")
70            })
71            .with_location(ErrorLocation::new(file!(), line!())),
72        ));
73    }
74    Ok(value)
75}
76
77/// Checks if a value is non-negative
78///
79/// # Arguments
80///
81/// * `value` - The value to check
82/// * `name` - The name of the parameter being checked
83///
84/// # Returns
85///
86/// * `Ok(value)` if the value is non-negative
87/// * `Err(CoreError::ValueError)` if the value is negative
88///
89/// # Errors
90///
91/// Returns `CoreError::ValueError` if the value is negative.
92pub fn check_non_negative<T, S>(value: T, name: S) -> CoreResult<T>
93where
94    T: PartialOrd + std::fmt::Display + Copy + Zero,
95    S: Into<String>,
96{
97    if value < T::zero() {
98        return Err(CoreError::ValueError(
99            ErrorContext::new({
100                let name_str = name.into();
101                format!("{name_str} must be non-negative, got {value}")
102            })
103            .with_location(ErrorLocation::new(file!(), line!())),
104        ));
105    }
106    Ok(value)
107}
108
109/// Checks if a floating-point value is finite
110///
111/// # Arguments
112///
113/// * `value` - The value to check
114/// * `name` - The name of the parameter being checked
115///
116/// # Returns
117///
118/// * `Ok(value)` if the value is finite
119/// * `Err(CoreError::ValueError)` if the value is not finite
120///
121/// # Errors
122///
123/// Returns `CoreError::ValueError` if the value is not finite.
124pub fn check_finite<T, S>(value: T, name: S) -> CoreResult<T>
125where
126    T: Float + std::fmt::Display + Copy,
127    S: Into<String>,
128{
129    if !value.is_finite() {
130        return Err(CoreError::ValueError(
131            ErrorContext::new({
132                let name_str = name.into();
133                format!("{name_str} must be finite, got {value}")
134            })
135            .with_location(ErrorLocation::new(file!(), line!())),
136        ));
137    }
138    Ok(value)
139}
140
141/// Checks if all values in an array are finite
142///
143/// # Arguments
144///
145/// * `array` - The array to check
146/// * `name` - The name of the array being checked
147///
148/// # Returns
149///
150/// * `Ok(())` if all values are finite
151/// * `Err(CoreError::ValueError)` if any value is not finite
152///
153/// # Errors
154///
155/// Returns `CoreError::ValueError` if any value in the array is not finite.
156pub fn checkarray_finite<S, A, D>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
157where
158    S: crate::ndarray::Data,
159    D: Dimension,
160    S::Elem: Float + std::fmt::Display,
161    A: Into<String>,
162{
163    let name = name.into();
164    for (idx, &value) in array.indexed_iter() {
165        if !value.is_finite() {
166            return Err(CoreError::ValueError(
167                ErrorContext::new(format!(
168                    "{name} must contain only finite values, got {value} at {idx:?}"
169                ))
170                .with_location(ErrorLocation::new(file!(), line!())),
171            ));
172        }
173    }
174    Ok(())
175}
176
177/// Checks if an array has the expected shape
178///
179/// # Arguments
180///
181/// * `array` - The array to check
182/// * `expectedshape` - The expected shape
183/// * `name` - The name of the array being checked
184///
185/// # Returns
186///
187/// * `Ok(())` if the array has the expected shape
188/// * `Err(CoreError::ShapeError)` if the array does not have the expected shape
189///
190/// # Errors
191///
192/// Returns `CoreError::ShapeError` if the array does not have the expected shape.
193pub fn checkshape<S, D, A>(
194    array: &ArrayBase<S, D>,
195    expectedshape: &[usize],
196    name: A,
197) -> CoreResult<()>
198where
199    S: crate::ndarray::Data,
200    D: Dimension,
201    A: Into<String>,
202{
203    let actualshape = array.shape();
204    if actualshape != expectedshape {
205        return Err(CoreError::ShapeError(
206            ErrorContext::new(format!(
207                "{} has incorrect shape: expected {expectedshape:?}, got {actualshape:?}",
208                name.into()
209            ))
210            .with_location(ErrorLocation::new(file!(), line!())),
211        ));
212    }
213    Ok(())
214}
215
216/// Checks if an array is 1D
217///
218/// # Arguments
219///
220/// * `array` - The array to check
221/// * `name` - The name of the array being checked
222///
223/// # Returns
224///
225/// * `Ok(())` if the array is 1D
226/// * `Err(CoreError::ShapeError)` if the array is not 1D
227///
228/// # Errors
229///
230/// Returns `CoreError::ShapeError` if the array is not 1D.
231pub fn check_1d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
232where
233    S: crate::ndarray::Data,
234    D: Dimension,
235    A: Into<String>,
236{
237    if array.ndim() != 1 {
238        return Err(CoreError::ShapeError(
239            ErrorContext::new({
240                let name_str = name.into();
241                let ndim = array.ndim();
242                format!("{name_str} must be 1D, got {ndim}D")
243            })
244            .with_location(ErrorLocation::new(file!(), line!())),
245        ));
246    }
247    Ok(())
248}
249
250/// Checks if an array is 2D
251///
252/// # Arguments
253///
254/// * `array` - The array to check
255/// * `name` - The name of the array being checked
256///
257/// # Returns
258///
259/// * `Ok(())` if the array is 2D
260/// * `Err(CoreError::ShapeError)` if the array is not 2D
261///
262/// # Errors
263///
264/// Returns `CoreError::ShapeError` if the array is not 2D.
265pub fn check_2d<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
266where
267    S: crate::ndarray::Data,
268    D: Dimension,
269    A: Into<String>,
270{
271    if array.ndim() != 2 {
272        return Err(CoreError::ShapeError(
273            ErrorContext::new({
274                let name_str = name.into();
275                let ndim = array.ndim();
276                format!("{name_str} must be 2D, got {ndim}D")
277            })
278            .with_location(ErrorLocation::new(file!(), line!())),
279        ));
280    }
281    Ok(())
282}
283
284/// Checks if two arrays have the same shape
285///
286/// # Arguments
287///
288/// * `a` - The first array
289/// * `a_name` - The name of the first array
290/// * `b` - The second array
291/// * `b_name` - The name of the second array
292///
293/// # Returns
294///
295/// * `Ok(())` if the arrays have the same shape
296/// * `Err(CoreError::ShapeError)` if the arrays have different shapes
297///
298/// # Errors
299///
300/// Returns `CoreError::ShapeError` if the arrays have different shapes.
301pub fn check_sameshape<S1, S2, D1, D2, A, B>(
302    a: &ArrayBase<S1, D1>,
303    a_name: A,
304    b: &ArrayBase<S2, D2>,
305    b_name: B,
306) -> CoreResult<()>
307where
308    S1: crate::ndarray::Data,
309    S2: crate::ndarray::Data,
310    D1: Dimension,
311    D2: Dimension,
312    A: Into<String>,
313    B: Into<String>,
314{
315    let ashape = a.shape();
316    let bshape = b.shape();
317    if ashape != bshape {
318        return Err(CoreError::ShapeError(
319            ErrorContext::new(format!(
320                "{} and {} must have the same shape, got {:?} and {:?}",
321                a_name.into(),
322                b_name.into(),
323                ashape,
324                bshape
325            ))
326            .with_location(ErrorLocation::new(file!(), line!())),
327        ));
328    }
329    Ok(())
330}
331
332/// Checks if a matrix is square
333///
334/// # Arguments
335///
336/// * `matrix` - The matrix to check
337/// * `name` - The name of the matrix being checked
338///
339/// # Returns
340///
341/// * `Ok(())` if the matrix is square
342/// * `Err(CoreError::ShapeError)` if the matrix is not square
343///
344/// # Errors
345///
346/// Returns `CoreError::ShapeError` if the matrix is not square.
347pub fn check_square<S, D, A>(matrix: &ArrayBase<S, D>, name: A) -> CoreResult<()>
348where
349    S: crate::ndarray::Data,
350    D: Dimension,
351    A: Into<String> + std::string::ToString,
352{
353    check_2d(matrix, name.to_string())?;
354    let shape = matrix.shape();
355    if shape[0] != shape[1] {
356        return Err(CoreError::ShapeError(
357            ErrorContext::new(format!(
358                "{} must be square, got shape {:?}",
359                name.into(),
360                shape
361            ))
362            .with_location(ErrorLocation::new(file!(), line!())),
363        ));
364    }
365    Ok(())
366}
367
368/// Checks if a probability value is valid (between 0 and 1, inclusive)
369///
370/// # Arguments
371///
372/// * `p` - The probability value to check
373/// * `name` - The name of the parameter being checked
374///
375/// # Returns
376///
377/// * `Ok(p)` if the probability is valid
378/// * `Err(CoreError::ValueError)` if the probability is not valid
379///
380/// # Errors
381///
382/// Returns `CoreError::ValueError` if the probability is not between 0 and 1.
383pub fn check_probability<T, S>(p: T, name: S) -> CoreResult<T>
384where
385    T: Float + std::fmt::Display + Copy,
386    S: Into<String>,
387{
388    if p < T::zero() || p > T::one() {
389        return Err(CoreError::ValueError(
390            ErrorContext::new(format!(
391                "{} must be between 0 and 1, got {}",
392                name.into(),
393                p
394            ))
395            .with_location(ErrorLocation::new(file!(), line!())),
396        ));
397    }
398    Ok(p)
399}
400
401/// Checks if an array contains only probabilities (between 0 and 1, inclusive)
402///
403/// # Arguments
404///
405/// * `probs` - The array of probabilities to check
406/// * `name` - The name of the array being checked
407///
408/// # Returns
409///
410/// * `Ok(())` if all values are valid probabilities
411/// * `Err(CoreError::ValueError)` if any value is not a valid probability
412///
413/// # Errors
414///
415/// Returns `CoreError::ValueError` if any value is not a valid probability.
416pub fn check_probabilities<S, D, A>(probs: &ArrayBase<S, D>, name: A) -> CoreResult<()>
417where
418    S: crate::ndarray::Data,
419    D: Dimension,
420    S::Elem: Float + std::fmt::Display,
421    A: Into<String>,
422{
423    let name = name.into();
424    for (idx, &p) in probs.indexed_iter() {
425        if p < S::Elem::zero() || p > S::Elem::one() {
426            return Err(CoreError::ValueError(
427                ErrorContext::new(format!(
428                    "{name} must contain only values between 0 and 1, got {p} at {idx:?}"
429                ))
430                .with_location(ErrorLocation::new(file!(), line!())),
431            ));
432        }
433    }
434    Ok(())
435}
436
437/// Checks if probability values sum to 1
438///
439/// # Arguments
440///
441/// * `probs` - The array of probabilities to check
442/// * `name` - The name of the array being checked
443/// * `tol` - Tolerance for the sum (default: 1e-10)
444///
445/// # Returns
446///
447/// * `Ok(())` if the probabilities sum to 1 (within tolerance)
448/// * `Err(CoreError::ValueError)` if the sum is not 1 (within tolerance)
449pub fn check_probabilities_sum_to_one<S, D, A>(
450    probs: &ArrayBase<S, D>,
451    name: A,
452    tol: Option<S::Elem>,
453) -> CoreResult<()>
454where
455    S: crate::ndarray::Data,
456    S::Elem: Float,
457    D: Dimension,
458    S::Elem: Float + std::fmt::Display + ScalarOperand,
459    A: Into<String> + std::string::ToString,
460{
461    let tol = tol.unwrap_or_else(|| {
462        let eps: f64 = 1e-10;
463        num_traits::cast(eps).unwrap_or_else(|| {
464            // Fallback to epsilon
465            S::Elem::epsilon()
466        })
467    });
468
469    check_probabilities(probs, name.to_string())?;
470
471    let sum = probs.sum();
472    let one = S::Elem::one();
473
474    if (sum - one).abs() > tol {
475        return Err(CoreError::ValueError(
476            ErrorContext::new({
477                let name_str = name.into();
478                format!("{name_str} must sum to 1, got sum = {sum}")
479            })
480            .with_location(ErrorLocation::new(file!(), line!())),
481        ));
482    }
483
484    Ok(())
485}
486
487/// Checks if an array is not empty
488///
489/// # Arguments
490///
491/// * `array` - The array to check
492/// * `name` - The name of the array being checked
493///
494/// # Returns
495///
496/// * `Ok(())` if the array is not empty
497/// * `Err(CoreError::ValueError)` if the array is empty
498pub fn check_not_empty<S, D, A>(array: &ArrayBase<S, D>, name: A) -> CoreResult<()>
499where
500    S: crate::ndarray::Data,
501    D: Dimension,
502    A: Into<String>,
503{
504    if array.is_empty() {
505        return Err(CoreError::ValueError(
506            ErrorContext::new({
507                let name_str = name.into();
508                format!("{name_str} cannot be empty")
509            })
510            .with_location(ErrorLocation::new(file!(), line!())),
511        ));
512    }
513    Ok(())
514}
515
516/// Checks if an array has at least the minimum number of samples
517///
518/// # Arguments
519///
520/// * `array` - The array to check
521/// * `min_samples` - The minimum required number of samples
522/// * `name` - The name of the array being checked
523///
524/// # Returns
525///
526/// * `Ok(())` if the array has sufficient samples
527/// * `Err(CoreError::ValueError)` if the array has too few samples
528pub fn check_min_samples<S, D, A>(
529    array: &ArrayBase<S, D>,
530    min_samples: usize,
531    name: A,
532) -> CoreResult<()>
533where
534    S: crate::ndarray::Data,
535    D: Dimension,
536    A: Into<String>,
537{
538    let n_samples = array.shape()[0];
539    if n_samples < min_samples {
540        return Err(CoreError::ValueError(
541            ErrorContext::new(format!(
542                "{} must have at least {} samples, got {}",
543                name.into(),
544                min_samples,
545                n_samples
546            ))
547            .with_location(ErrorLocation::new(file!(), line!())),
548        ));
549    }
550    Ok(())
551}
552
553/// Clustering-specific validation utilities
554pub mod clustering {
555    use super::*;
556
557    /// Validate number of clusters relative to data size
558    ///
559    /// # Arguments
560    ///
561    /// * `data` - Input data array
562    /// * `n_clusters` - Number of clusters
563    /// * `operation` - Name of the operation for error messages
564    ///
565    /// # Returns
566    ///
567    /// * `Ok(())` if n_clusters is valid
568    /// * `Err(CoreError::ValueError)` if n_clusters is invalid
569    pub fn check_n_clusters_bounds<S, D>(
570        data: &ArrayBase<S, D>,
571        n_clusters: usize,
572        operation: &str,
573    ) -> CoreResult<()>
574    where
575        S: crate::ndarray::Data,
576        D: Dimension,
577    {
578        let n_samples = data.shape()[0];
579
580        if n_clusters == 0 {
581            return Err(CoreError::ValueError(
582                ErrorContext::new(format!(
583                    "{operation}: number of _clusters must be > 0, got {n_clusters}"
584                ))
585                .with_location(ErrorLocation::new(file!(), line!())),
586            ));
587        }
588
589        if n_clusters > n_samples {
590            return Err(CoreError::ValueError(
591                ErrorContext::new(format!(
592                    "{operation}: number of _clusters ({n_clusters}) cannot exceed number of samples ({n_samples})"
593                ))
594                .with_location(ErrorLocation::new(file!(), line!())),
595            ));
596        }
597
598        Ok(())
599    }
600
601    /// Comprehensive data validation for clustering algorithms
602    ///
603    /// # Arguments
604    ///
605    /// * `data` - Input data array
606    /// * `operation` - Name of the operation for error messages
607    /// * `check_finite` - Whether to check for finite values
608    /// * `min_samples` - Optional minimum number of samples required
609    ///
610    /// # Returns
611    ///
612    /// * `Ok(())` if data is valid
613    /// * `Err(CoreError)` if data validation fails
614    pub fn validate_clustering_data<S, D>(
615        data: &ArrayBase<S, D>,
616        _operation: &str,
617        check_finite: bool,
618        min_samples: Option<usize>,
619    ) -> CoreResult<()>
620    where
621        S: crate::ndarray::Data,
622        D: Dimension,
623        S::Elem: Float + std::fmt::Display,
624    {
625        // Check not empty
626        check_not_empty(data, "data")?;
627
628        // Check 2D for most clustering algorithms
629        check_2d(data, "data")?;
630
631        // Check minimum _samples if specified
632        if let Some(min) = min_samples {
633            check_min_samples(data, min, "data")?;
634        }
635
636        // Check _finite if requested
637        if check_finite {
638            checkarray_finite(data, "data")?;
639        }
640
641        Ok(())
642    }
643}
644
645/// Parameter validation utilities
646pub mod parameters {
647    use super::*;
648
649    /// Validate algorithm iteration parameters
650    ///
651    /// # Arguments
652    ///
653    /// * `max_iter` - Maximum number of iterations
654    /// * `tolerance` - Convergence tolerance
655    /// * `operation` - Name of the operation for error messages
656    ///
657    /// # Returns
658    ///
659    /// * `Ok(())` if parameters are valid
660    /// * `Err(CoreError::ValueError)` if parameters are invalid
661    pub fn check_iteration_params<T>(
662        max_iter: usize,
663        tolerance: T,
664        operation: &str,
665    ) -> CoreResult<()>
666    where
667        T: Float + std::fmt::Display + Copy,
668    {
669        if max_iter == 0 {
670            return Err(CoreError::ValueError(
671                ErrorContext::new(format!("{operation}: max_iter must be > 0, got {max_iter}"))
672                    .with_location(ErrorLocation::new(file!(), line!())),
673            ));
674        }
675
676        check_positive(tolerance, format!("{operation} tolerance"))?;
677
678        Ok(())
679    }
680
681    /// Validate probability-like parameters (0 <= p <= 1)
682    ///
683    /// # Arguments
684    ///
685    /// * `value` - Value to check
686    /// * `name` - Parameter name for error messages
687    /// * `operation` - Operation name for error messages
688    ///
689    /// # Returns
690    ///
691    /// * `Ok(value)` if value is in [0, 1]
692    /// * `Err(CoreError::ValueError)` if value is out of range
693    pub fn check_unit_interval<T>(value: T, name: &str, operation: &str) -> CoreResult<T>
694    where
695        T: Float + std::fmt::Display + Copy,
696    {
697        if value < T::zero() || value > T::one() {
698            return Err(CoreError::ValueError(
699                ErrorContext::new(format!(
700                    "{operation}: {name} must be in [0, 1], got {value}"
701                ))
702                .with_location(ErrorLocation::new(file!(), line!())),
703            ));
704        }
705        Ok(value)
706    }
707
708    /// Validate bandwidth parameter for density-based clustering
709    ///
710    /// # Arguments
711    ///
712    /// * `bandwidth` - Bandwidth value
713    /// * `operation` - Operation name for error messages
714    ///
715    /// # Returns
716    ///
717    /// * `Ok(bandwidth)` if bandwidth is valid
718    /// * `Err(CoreError::ValueError)` if bandwidth is invalid
719    pub fn checkbandwidth<T>(bandwidth: T, operation: &str) -> CoreResult<T>
720    where
721        T: Float + std::fmt::Display + Copy,
722    {
723        check_positive(bandwidth, format!("{operation} bandwidth"))
724    }
725}
726
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2};

    #[test]
    fn test_check_in_bounds() {
        // Both endpoints are inclusive.
        for accepted in [0, 5, 10] {
            assert!(check_in_bounds(accepted, 0, 10, "param").is_ok());
        }
        for rejected in [-1, 11] {
            assert!(check_in_bounds(rejected, 0, 10, "param").is_err());
        }
    }

    #[test]
    fn test_check_positive() {
        assert!(check_positive(5, "param").is_ok());
        assert!(check_positive(0.1, "param").is_ok());
        for rejected in [0, -1] {
            assert!(check_positive(rejected, "param").is_err());
        }
    }

    #[test]
    fn test_check_non_negative() {
        for accepted in [5, 0] {
            assert!(check_non_negative(accepted, "param").is_ok());
        }
        assert!(check_non_negative(-0.1, "param").is_err());
        assert!(check_non_negative(-1, "param").is_err());
    }

    #[test]
    fn test_check_finite() {
        for accepted in [5.0, 0.0, -1.0] {
            assert!(check_finite(accepted, "param").is_ok());
        }
        for rejected in [f64::INFINITY, f64::NEG_INFINITY, f64::NAN] {
            assert!(check_finite(rejected, "param").is_err());
        }
    }

    #[test]
    fn test_checkarray_finite() {
        assert!(checkarray_finite(&arr1(&[1.0, 2.0, 3.0]), "array").is_ok());

        for poison in [f64::INFINITY, f64::NAN] {
            let polluted = arr1(&[1.0, poison, 3.0]);
            assert!(checkarray_finite(&polluted, "array").is_err());
        }
    }

    #[test]
    fn test_checkshape() {
        let square = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(checkshape(&square, &[2, 2], "array").is_ok());
        assert!(checkshape(&square, &[2, 3], "array").is_err());
    }

    #[test]
    fn test_check_1d() {
        let vector = arr1(&[1.0, 2.0, 3.0]);
        let matrix = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        assert!(check_1d(&vector, "array").is_ok());
        assert!(check_1d(&matrix, "array").is_err());
    }

    #[test]
    fn test_check_2d() {
        let matrix = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let vector = arr1(&[1.0, 2.0, 3.0]);
        assert!(check_2d(&matrix, "array").is_ok());
        assert!(check_2d(&vector, "array").is_err());
    }

    #[test]
    fn test_check_sameshape() {
        let first = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let same = arr2(&[[5.0, 6.0], [7.0, 8.0]]);
        let wider = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        assert!(check_sameshape(&first, "a", &same, "b").is_ok());
        assert!(check_sameshape(&first, "a", &wider, "c").is_err());
    }

    #[test]
    fn test_check_square() {
        let square = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let rectangular = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let vector = arr1(&[1.0, 2.0, 3.0]);
        assert!(check_square(&square, "matrix").is_ok());
        assert!(check_square(&rectangular, "matrix").is_err());
        assert!(check_square(&vector, "matrix").is_err());
    }

    #[test]
    fn test_check_probability() {
        for accepted in [0.0, 0.5, 1.0] {
            assert!(check_probability(accepted, "p").is_ok());
        }
        for rejected in [-0.1, 1.1] {
            assert!(check_probability(rejected, "p").is_err());
        }
    }

    #[test]
    fn test_check_probabilities() {
        assert!(check_probabilities(&arr1(&[0.0, 0.5, 1.0]), "probs").is_ok());

        for bad in [[0.0, 0.5, 1.1], [-0.1, 0.5, 1.0]] {
            assert!(check_probabilities(&arr1(&bad), "probs").is_err());
        }
    }

    #[test]
    fn test_check_probabilities_sum_to_one() {
        let exact = arr1(&[0.3, 0.2, 0.5]);
        assert!(check_probabilities_sum_to_one(&exact, "probs", None).is_ok());

        let too_much = arr1(&[0.3, 0.2, 0.6]);
        assert!(check_probabilities_sum_to_one(&too_much, "probs", None).is_err());

        // Custom tolerance: an excess of 0.001 is within 0.01 but not within 0.0001.
        let slightly_over = arr1(&[0.3, 0.2, 0.501]);
        assert!(check_probabilities_sum_to_one(&slightly_over, "probs", Some(0.01)).is_ok());
        assert!(check_probabilities_sum_to_one(&slightly_over, "probs", Some(0.0001)).is_err());
    }

    #[test]
    fn test_check_not_empty() {
        let filled = arr1(&[1.0, 2.0, 3.0]);
        let empty = arr1(&[] as &[f64]);
        assert!(check_not_empty(&filled, "array").is_ok());
        assert!(check_not_empty(&empty, "array").is_err());
    }

    #[test]
    fn test_check_min_samples() {
        let three_rows = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        for enough in [2, 3] {
            assert!(check_min_samples(&three_rows, enough, "array").is_ok());
        }
        assert!(check_min_samples(&three_rows, 4, "array").is_err());
    }

    mod clustering_tests {
        use super::*;
        use crate::validation::clustering::*;

        #[test]
        fn test_check_n_clusters_bounds() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);

            for k in 1..=3 {
                assert!(check_n_clusters_bounds(&data, k, "test").is_ok());
            }
            for k in [0, 4] {
                assert!(check_n_clusters_bounds(&data, k, "test").is_err());
            }
        }

        #[test]
        fn test_validate_clustering_data() {
            let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
            assert!(validate_clustering_data(&data, "test", true, Some(2)).is_ok());
            assert!(validate_clustering_data(&data, "test", true, Some(4)).is_err());

            let empty_data = arr2(&[] as &[[f64; 2]; 0]);
            assert!(validate_clustering_data(&empty_data, "test", true, None).is_err());

            // The infinite value only matters when finiteness checking is enabled.
            let inf_data = arr2(&[[1.0, f64::INFINITY], [3.0, 4.0]]);
            assert!(validate_clustering_data(&inf_data, "test", true, None).is_err());
            assert!(validate_clustering_data(&inf_data, "test", false, None).is_ok());
        }
    }

    mod parameters_tests {
        use crate::validation::parameters::*;

        #[test]
        fn test_check_iteration_params() {
            assert!(check_iteration_params(100, 1e-6, "test").is_ok());
            assert!(check_iteration_params(0, 1e-6, "test").is_err());
            for bad_tol in [0.0, -1e-6] {
                assert!(check_iteration_params(100, bad_tol, "test").is_err());
            }
        }

        #[test]
        fn test_check_unit_interval() {
            for accepted in [0.0, 0.5, 1.0] {
                assert!(check_unit_interval(accepted, "param", "test").is_ok());
            }
            for rejected in [-0.1, 1.1] {
                assert!(check_unit_interval(rejected, "param", "test").is_err());
            }
        }

        #[test]
        fn test_checkbandwidth() {
            for accepted in [1.0, 0.1] {
                assert!(checkbandwidth(accepted, "test").is_ok());
            }
            for rejected in [0.0, -1.0] {
                assert!(checkbandwidth(rejected, "test").is_err());
            }
        }
    }
}
937
938/// Custom validator implementations for flexible validation logic
939pub mod custom {
940    use super::*;
941    use std::fmt;
942    use std::marker::PhantomData;
943
944    /// Trait for implementing custom validators
945    pub trait Validator<T> {
946        /// Validate the value and return Ok(()) if valid, or an error if invalid
947        fn validate(&self, value: &T, name: &str) -> CoreResult<()>;
948
949        /// Get a description of what this validator checks
950        fn description(&self) -> String;
951
952        /// Chain this validator with another validator
953        fn and<V: Validator<T>>(self, other: V) -> CompositeValidator<T, Self, V>
954        where
955            Self: Sized,
956        {
957            CompositeValidator::new(self, other)
958        }
959
960        /// Create a conditional validator that only applies when a condition is met
961        fn when<F>(self, condition: F) -> ConditionalValidator<T, Self, F>
962        where
963            Self: Sized,
964            F: Fn(&T) -> bool,
965        {
966            ConditionalValidator::new(self, condition)
967        }
968    }
969
970    /// A validator that combines two validators with AND logic
971    pub struct CompositeValidator<T, V1, V2> {
972        validator1: V1,
973        validator2: V2,
974        _phantom: PhantomData<T>,
975    }
976
977    impl<T, V1, V2> CompositeValidator<T, V1, V2> {
978        pub fn new(validator1: V1, validator2: V2) -> Self {
979            Self {
980                validator1,
981                validator2,
982                _phantom: PhantomData,
983            }
984        }
985    }
986
987    impl<T, V1, V2> Validator<T> for CompositeValidator<T, V1, V2>
988    where
989        V1: Validator<T>,
990        V2: Validator<T>,
991    {
992        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
993            self.validator1.validate(value, name)?;
994            self.validator2.validate(value, name)?;
995            Ok(())
996        }
997
998        fn description(&self) -> String {
999            format!(
1000                "{} AND {}",
1001                self.validator1.description(),
1002                self.validator2.description()
1003            )
1004        }
1005    }
1006
1007    /// A validator that only applies when a condition is met
1008    pub struct ConditionalValidator<T, V, F> {
1009        validator: V,
1010        condition: F,
1011        phantom: PhantomData<T>,
1012    }
1013
1014    impl<T, V, F> ConditionalValidator<T, V, F> {
1015        pub fn new(validator: V, condition: F) -> Self {
1016            Self {
1017                validator,
1018                condition,
1019                phantom: PhantomData,
1020            }
1021        }
1022    }
1023
1024    impl<T, V, F> Validator<T> for ConditionalValidator<T, V, F>
1025    where
1026        V: Validator<T>,
1027        F: Fn(&T) -> bool,
1028    {
1029        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1030            if (self.condition)(value) {
1031                self.validator.validate(value, name)
1032            } else {
1033                Ok(())
1034            }
1035        }
1036
1037        fn description(&self) -> String {
1038            {
1039                let desc = self.validator.description();
1040                format!("IF condition THEN {desc}")
1041            }
1042        }
1043    }
1044
1045    /// Custom validator for ranges with inclusive/exclusive bounds
1046    pub struct RangeValidator<T> {
1047        min: Option<T>,
1048        max: Option<T>,
1049        min_inclusive: bool,
1050        max_inclusive: bool,
1051    }
1052
1053    impl<T> RangeValidator<T>
1054    where
1055        T: PartialOrd + Copy + fmt::Display,
1056    {
1057        pub fn new() -> Self {
1058            Self {
1059                min: None,
1060                max: None,
1061                min_inclusive: true,
1062                max_inclusive: true,
1063            }
1064        }
1065
1066        pub fn min(mut self, min: T) -> Self {
1067            self.min = Some(min);
1068            self
1069        }
1070
1071        pub fn max(mut self, max: T) -> Self {
1072            self.max = Some(max);
1073            self
1074        }
1075
1076        pub fn min_exclusive(mut self, min: T) -> Self {
1077            self.min = Some(min);
1078            self.min_inclusive = false;
1079            self
1080        }
1081
1082        pub fn max_exclusive(mut self, max: T) -> Self {
1083            self.max = Some(max);
1084            self.max_inclusive = false;
1085            self
1086        }
1087
1088        pub fn in_range(min: T, max: T) -> Self {
1089            Self::new().min(min).max(max)
1090        }
1091
1092        pub fn in_range_exclusive(min: T, max: T) -> Self {
1093            Self::new().min_exclusive(min).max_exclusive(max)
1094        }
1095    }
1096
1097    impl<T> Default for RangeValidator<T>
1098    where
1099        T: PartialOrd + Copy + fmt::Display,
1100    {
1101        fn default() -> Self {
1102            Self::new()
1103        }
1104    }
1105
1106    impl<T> Validator<T> for RangeValidator<T>
1107    where
1108        T: PartialOrd + Copy + fmt::Display,
1109    {
1110        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1111            if let Some(min) = self.min {
1112                let valid = if self.min_inclusive {
1113                    *value >= min
1114                } else {
1115                    *value > min
1116                };
1117                if !valid {
1118                    let op = if self.min_inclusive { ">=" } else { ">" };
1119                    return Err(CoreError::ValueError(
1120                        ErrorContext::new(format!("{name} must be {op} {min}, got {value}"))
1121                            .with_location(ErrorLocation::new(file!(), line!())),
1122                    ));
1123                }
1124            }
1125
1126            if let Some(max) = self.max {
1127                let valid = if self.max_inclusive {
1128                    *value <= max
1129                } else {
1130                    *value < max
1131                };
1132                if !valid {
1133                    let op = if self.max_inclusive { "<=" } else { "<" };
1134                    return Err(CoreError::ValueError(
1135                        ErrorContext::new(format!("{name} must be {op} {max}, got {value}"))
1136                            .with_location(ErrorLocation::new(file!(), line!())),
1137                    ));
1138                }
1139            }
1140
1141            Ok(())
1142        }
1143
1144        fn description(&self) -> String {
1145            match (self.min, self.max) {
1146                (Some(min), Some(max)) => {
1147                    let min_op = if self.min_inclusive { ">=" } else { ">" };
1148                    let max_op = if self.max_inclusive { "<=" } else { "<" };
1149                    format!("value {min_op} {min} and {max_op} {max}")
1150                }
1151                (Some(min), None) => {
1152                    let op = if self.min_inclusive { ">=" } else { ">" };
1153                    format!("value {op} {min}")
1154                }
1155                (None, Some(max)) => {
1156                    let op = if self.max_inclusive { "<=" } else { "<" };
1157                    format!("value {op} {max}")
1158                }
1159                (None, None) => "no range constraints".to_string(),
1160            }
1161        }
1162    }
1163
1164    /// Type alias for shape validation function to reduce complexity
1165    type ShapeValidatorFn = Box<dyn Fn(&[usize]) -> CoreResult<()>>;
1166
1167    /// Custom validator for array properties
1168    pub struct ArrayValidator<T, D>
1169    where
1170        D: Dimension,
1171    {
1172        shape_validator: Option<ShapeValidatorFn>,
1173        element_validator: Option<Box<dyn Validator<T>>>,
1174        size_validator: Option<RangeValidator<usize>>,
1175        phantom: PhantomData<D>,
1176    }
1177
1178    impl<T, D> ArrayValidator<T, D>
1179    where
1180        D: Dimension,
1181    {
1182        pub fn new() -> Self {
1183            Self {
1184                shape_validator: None,
1185                element_validator: None,
1186                size_validator: None,
1187                phantom: PhantomData,
1188            }
1189        }
1190
1191        pub fn withshape<F>(mut self, validator: F) -> Self
1192        where
1193            F: Fn(&[usize]) -> CoreResult<()> + 'static,
1194        {
1195            self.shape_validator = Some(Box::new(validator));
1196            self
1197        }
1198
1199        pub fn with_elements<V>(mut self, validator: V) -> Self
1200        where
1201            V: Validator<T> + 'static,
1202        {
1203            self.element_validator = Some(Box::new(validator));
1204            self
1205        }
1206
1207        pub fn with_size(mut self, validator: RangeValidator<usize>) -> Self {
1208            self.size_validator = Some(validator);
1209            self
1210        }
1211
1212        pub fn minsize(self, minsize: usize) -> Self {
1213            self.with_size(RangeValidator::new().min(minsize))
1214        }
1215
1216        pub fn maxsize(self, maxsize: usize) -> Self {
1217            self.with_size(RangeValidator::new().max(maxsize))
1218        }
1219
1220        pub fn exact_size(self, size: usize) -> Self {
1221            self.with_size(RangeValidator::new().min(size).max(size))
1222        }
1223    }
1224
1225    impl<T, D> Default for ArrayValidator<T, D>
1226    where
1227        D: Dimension,
1228    {
1229        fn default() -> Self {
1230            Self::new()
1231        }
1232    }
1233
1234    impl<S, T, D> Validator<ArrayBase<S, D>> for ArrayValidator<T, D>
1235    where
1236        S: crate::ndarray::Data<Elem = T>,
1237        T: Clone,
1238        D: Dimension,
1239    {
1240        fn validate(&self, array: &ArrayBase<S, D>, name: &str) -> CoreResult<()> {
1241            // Validate shape
1242            if let Some(ref shape_validator) = self.shape_validator {
1243                shape_validator(array.shape())?;
1244            }
1245
1246            // Validate size
1247            if let Some(ref size_validator) = self.size_validator {
1248                size_validator.validate(&array.len(), &format!("{name} size"))?;
1249            }
1250
1251            // Validate elements
1252            if let Some(ref element_validator) = self.element_validator {
1253                for (idx, element) in array.indexed_iter() {
1254                    element_validator.validate(element, &format!("{name} element at {idx:?}"))?;
1255                }
1256            }
1257
1258            Ok(())
1259        }
1260
1261        fn description(&self) -> String {
1262            let mut parts = Vec::new();
1263
1264            if self.shape_validator.is_some() {
1265                parts.push("shape validation".to_string());
1266            }
1267
1268            if let Some(ref size_validator) = self.size_validator {
1269                {
1270                    let desc = size_validator.description();
1271                    parts.push(format!("size {desc}"));
1272                }
1273            }
1274
1275            if let Some(ref element_validator) = self.element_validator {
1276                {
1277                    let desc = element_validator.description();
1278                    parts.push(format!("elements {desc}"));
1279                }
1280            }
1281
1282            if parts.is_empty() {
1283                "no array constraints".to_string()
1284            } else {
1285                parts.join(" AND ")
1286            }
1287        }
1288    }
1289
1290    /// Custom validator for function-based validation
1291    pub struct FunctionValidator<T, F> {
1292        func: F,
1293        description: String,
1294        phantom: PhantomData<T>,
1295    }
1296
1297    impl<T, F> FunctionValidator<T, F>
1298    where
1299        F: Fn(&T, &str) -> CoreResult<()>,
1300    {
1301        pub fn new(func: F, description: impl Into<String>) -> Self {
1302            Self {
1303                func,
1304                description: description.into(),
1305                phantom: PhantomData,
1306            }
1307        }
1308    }
1309
1310    impl<T, F> Validator<T> for FunctionValidator<T, F>
1311    where
1312        F: Fn(&T, &str) -> CoreResult<()>,
1313    {
1314        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1315            (self.func)(value, name)
1316        }
1317
1318        fn description(&self) -> String {
1319            self.description.clone()
1320        }
1321    }
1322
1323    /// Builder for creating complex validators
1324    pub struct ValidatorBuilder<T> {
1325        validators: Vec<Box<dyn Validator<T>>>,
1326    }
1327
1328    impl<T: 'static> ValidatorBuilder<T> {
1329        pub fn new() -> Self {
1330            Self {
1331                validators: Vec::new(),
1332            }
1333        }
1334
1335        pub fn with_validator<V: Validator<T> + 'static>(mut self, validator: V) -> Self {
1336            self.validators.push(Box::new(validator));
1337            self
1338        }
1339
1340        pub fn with_function<F>(self, func: F, description: impl Into<String>) -> Self
1341        where
1342            F: Fn(&T, &str) -> CoreResult<()> + 'static,
1343        {
1344            self.with_validator(FunctionValidator::new(func, description))
1345        }
1346
1347        pub fn build(self) -> MultiValidator<T> {
1348            MultiValidator {
1349                validators: self.validators,
1350            }
1351        }
1352    }
1353
1354    impl<T: 'static> Default for ValidatorBuilder<T> {
1355        fn default() -> Self {
1356            Self::new()
1357        }
1358    }
1359
1360    /// Validator that runs multiple validators
1361    pub struct MultiValidator<T> {
1362        validators: Vec<Box<dyn Validator<T>>>,
1363    }
1364
1365    impl<T: 'static> Validator<T> for MultiValidator<T> {
1366        fn validate(&self, value: &T, name: &str) -> CoreResult<()> {
1367            for validator in &self.validators {
1368                validator.validate(value, name)?;
1369            }
1370            Ok(())
1371        }
1372
1373        fn description(&self) -> String {
1374            if self.validators.is_empty() {
1375                "no validators".to_string()
1376            } else {
1377                self.validators
1378                    .iter()
1379                    .map(|v| v.description())
1380                    .collect::<Vec<_>>()
1381                    .join(" AND ")
1382            }
1383        }
1384    }
1385
1386    /// Convenience function to validate with a custom validator
1387    pub fn validate_with<T, V: Validator<T>>(
1388        value: &T,
1389        validator: &V,
1390        name: impl Into<String>,
1391    ) -> CoreResult<()> {
1392        validator.validate(value, &name.into())
1393    }
1394
1395    #[cfg(test)]
1396    mod tests {
1397        use super::*;
1398        use ::ndarray::arr1;
1399
1400        #[test]
1401        fn test_range_validator() {
1402            let validator = RangeValidator::in_range(0.0, 1.0);
1403
1404            assert!(validator.validate(&0.5, "value").is_ok());
1405            assert!(validator.validate(&0.0, "value").is_ok());
1406            assert!(validator.validate(&1.0, "value").is_ok());
1407            assert!(validator.validate(&-0.1, "value").is_err());
1408            assert!(validator.validate(&1.1, "value").is_err());
1409        }
1410
1411        #[test]
1412        fn test_range_validator_exclusive() {
1413            let validator = RangeValidator::in_range_exclusive(0.0, 1.0);
1414
1415            assert!(validator.validate(&0.5, "value").is_ok());
1416            assert!(validator.validate(&0.0, "value").is_err());
1417            assert!(validator.validate(&1.0, "value").is_err());
1418        }
1419
1420        #[test]
1421        fn test_composite_validator() {
1422            let positive = RangeValidator::new().min(0.0);
1423            let max_one = RangeValidator::new().max(1.0);
1424            let validator = positive.and(max_one);
1425
1426            assert!(validator.validate(&0.5, "value").is_ok());
1427            assert!(validator.validate(&-0.1, "value").is_err());
1428            assert!(validator.validate(&1.1, "value").is_err());
1429        }
1430
1431        #[test]
1432        fn test_conditional_validator() {
1433            let validator = RangeValidator::new().min(0.0).when(|x: &f64| *x > 0.0);
1434
1435            assert!(validator.validate(&0.5, "value").is_ok());
1436            assert!(validator.validate(&-0.5, "value").is_ok()); // Condition not met
1437            assert!(validator.validate(&0.0, "value").is_ok()); // Condition not met
1438        }
1439
1440        #[test]
1441        fn testarray_validator() {
1442            let element_validator = RangeValidator::in_range(0.0, 1.0);
1443            let array_validator = ArrayValidator::new()
1444                .with_elements(element_validator)
1445                .minsize(2);
1446
1447            let validarray = arr1(&[0.2, 0.8]);
1448            assert!(array_validator.validate(&validarray, "array").is_ok());
1449
1450            let invalidarray = arr1(&[0.2, 1.5]);
1451            assert!(array_validator.validate(&invalidarray, "array").is_err());
1452
1453            let too_smallarray = arr1(&[0.5]);
1454            assert!(array_validator.validate(&too_smallarray, "array").is_err());
1455        }
1456
1457        #[test]
1458        fn test_function_validator() {
1459            let validator = FunctionValidator::new(
1460                |value: &i32, name: &str| {
1461                    if *value % 2 == 0 {
1462                        Ok(())
1463                    } else {
1464                        Err(CoreError::ValueError(
1465                            ErrorContext::new(format!("{name} must be even, got {value}"))
1466                                .with_location(ErrorLocation::new(file!(), line!())),
1467                        ))
1468                    }
1469                },
1470                "value must be even",
1471            );
1472
1473            assert!(validator.validate(&4, "number").is_ok());
1474            assert!(validator.validate(&3, "number").is_err());
1475        }
1476
1477        #[test]
1478        fn test_validator_builder() {
1479            let validator = ValidatorBuilder::new()
1480                .with_validator(RangeValidator::new().min(0.0))
1481                .with_validator(RangeValidator::new().max(1.0))
1482                .with_function(
1483                    |value: &f64, name: &str| {
1484                        if *value != 0.5 {
1485                            Ok(())
1486                        } else {
1487                            Err(CoreError::ValueError(
1488                                ErrorContext::new(format!("{name} cannot be 0.5"))
1489                                    .with_location(ErrorLocation::new(file!(), line!())),
1490                            ))
1491                        }
1492                    },
1493                    "value cannot be 0.5",
1494                )
1495                .build();
1496
1497            assert!(validator.validate(&0.3, "value").is_ok());
1498            assert!(validator.validate(&0.5, "value").is_err());
1499            assert!(validator.validate(&-0.1, "value").is_err());
1500            assert!(validator.validate(&1.1, "value").is_err());
1501        }
1502    }
1503}
1504
/// Production-level validation with comprehensive security and performance features
1506pub mod production;
1507
1508/// Cross-platform validation utilities for consistent behavior across operating systems and architectures
1509pub mod cross_platform;
1510
1511/// Comprehensive data validation system with schema validation and constraint enforcement
1512#[cfg(feature = "data_validation")]
1513pub mod data;