Skip to main content

scirs2_core/validation/
array_checks.rs

1//! # Array Validation Utilities
2//!
3//! Comprehensive validation functions for ndarray arrays, covering finiteness,
4//! sign constraints, matrix properties (symmetry, orthogonality, positive-definiteness,
5//! stochasticity), shape matching, and diagnostic summaries.
6//!
7//! All functions return `CoreResult` and avoid panics.
8
9use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
10use ::ndarray::{Array2, ArrayBase, ArrayView2, Axis, Dimension};
11use num_traits::Float;
12use std::fmt::{Debug, Display};
13
14// ---------------------------------------------------------------------------
15// Element-wise assertions
16// ---------------------------------------------------------------------------
17
18/// Assert that every element of the array is finite (not NaN, not Inf).
19pub fn assert_finite<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> CoreResult<()>
20where
21    S: ::ndarray::Data<Elem = F>,
22    D: Dimension,
23    F: Float + Display,
24{
25    for (idx, &val) in array.indexed_iter() {
26        if !val.is_finite() {
27            return Err(CoreError::ValueError(
28                ErrorContext::new(format!(
29                    "{name} contains non-finite value {val} at index {idx:?}"
30                ))
31                .with_location(ErrorLocation::new(file!(), line!())),
32            ));
33        }
34    }
35    Ok(())
36}
37
38/// Assert that every element is strictly positive (> 0).
39pub fn assert_positive<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> CoreResult<()>
40where
41    S: ::ndarray::Data<Elem = F>,
42    D: Dimension,
43    F: Float + Display,
44{
45    for (idx, &val) in array.indexed_iter() {
46        if val.partial_cmp(&F::zero()) != Some(std::cmp::Ordering::Greater) {
47            return Err(CoreError::ValueError(
48                ErrorContext::new(format!(
49                    "{name} contains non-positive value {val} at index {idx:?}"
50                ))
51                .with_location(ErrorLocation::new(file!(), line!())),
52            ));
53        }
54    }
55    Ok(())
56}
57
58/// Assert that every element is non-negative (>= 0).
59pub fn assert_non_negative<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> CoreResult<()>
60where
61    S: ::ndarray::Data<Elem = F>,
62    D: Dimension,
63    F: Float + Display,
64{
65    for (idx, &val) in array.indexed_iter() {
66        if val < F::zero() {
67            return Err(CoreError::ValueError(
68                ErrorContext::new(format!(
69                    "{name} contains negative value {val} at index {idx:?}"
70                ))
71                .with_location(ErrorLocation::new(file!(), line!())),
72            ));
73        }
74    }
75    Ok(())
76}
77
78// ---------------------------------------------------------------------------
79// Matrix property assertions
80// ---------------------------------------------------------------------------
81
82/// Assert that a 2-D matrix is symmetric within a given tolerance.
83///
84/// Checks |A\[i,j\] - A\[j,i\]| <= tolerance for all i,j.
85pub fn assert_symmetric<F>(matrix: &ArrayView2<F>, name: &str, tolerance: F) -> CoreResult<()>
86where
87    F: Float + Display,
88{
89    let shape = matrix.shape();
90    if shape[0] != shape[1] {
91        return Err(CoreError::ShapeError(
92            ErrorContext::new(format!(
93                "{name} is not square ({} x {}), cannot be symmetric",
94                shape[0], shape[1]
95            ))
96            .with_location(ErrorLocation::new(file!(), line!())),
97        ));
98    }
99    let n = shape[0];
100    for i in 0..n {
101        for j in (i + 1)..n {
102            let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
103            if diff > tolerance {
104                return Err(CoreError::ValueError(
105                    ErrorContext::new(format!(
106                        "{name} is not symmetric: |A[{i},{j}] - A[{j},{i}]| = {diff} > {tolerance}"
107                    ))
108                    .with_location(ErrorLocation::new(file!(), line!())),
109                ));
110            }
111        }
112    }
113    Ok(())
114}
115
116/// Assert that a 2-D matrix is orthogonal within tolerance.
117///
118/// Checks that A^T * A is close to the identity matrix.
119pub fn assert_orthogonal<F>(matrix: &ArrayView2<F>, name: &str, tolerance: F) -> CoreResult<()>
120where
121    F: Float + Display + std::ops::AddAssign + Debug,
122{
123    let shape = matrix.shape();
124    if shape[0] != shape[1] {
125        return Err(CoreError::ShapeError(
126            ErrorContext::new(format!(
127                "{name} is not square ({} x {}), cannot check orthogonality",
128                shape[0], shape[1]
129            ))
130            .with_location(ErrorLocation::new(file!(), line!())),
131        ));
132    }
133    let n = shape[0];
134
135    // Compute A^T * A element-by-element without external BLAS
136    for i in 0..n {
137        for j in 0..n {
138            let mut dot = F::zero();
139            for k in 0..n {
140                dot += matrix[[k, i]] * matrix[[k, j]];
141            }
142            let expected = if i == j { F::one() } else { F::zero() };
143            let diff = (dot - expected).abs();
144            if diff > tolerance {
145                return Err(CoreError::ValueError(
146                    ErrorContext::new(format!(
147                        "{name} is not orthogonal: (A^T A)[{i},{j}] = {dot}, expected {expected} (diff={diff})"
148                    ))
149                    .with_location(ErrorLocation::new(file!(), line!())),
150                ));
151            }
152        }
153    }
154    Ok(())
155}
156
157/// Assert that a 2-D symmetric matrix is positive definite.
158///
159/// Uses attempted Cholesky decomposition (pure Rust, no BLAS) to check.
160pub fn assert_positive_definite<F>(matrix: &ArrayView2<F>, name: &str) -> CoreResult<()>
161where
162    F: Float + Display + Debug,
163{
164    let shape = matrix.shape();
165    if shape[0] != shape[1] {
166        return Err(CoreError::ShapeError(
167            ErrorContext::new(format!(
168                "{name} is not square ({} x {}), cannot check positive definiteness",
169                shape[0], shape[1]
170            ))
171            .with_location(ErrorLocation::new(file!(), line!())),
172        ));
173    }
174    let n = shape[0];
175    let mut l = Array2::<F>::zeros((n, n));
176
177    for i in 0..n {
178        for j in 0..=i {
179            let mut sum = matrix[[i, j]];
180            for k in 0..j {
181                sum = sum - l[[i, k]] * l[[j, k]];
182            }
183            if i == j {
184                if sum <= F::zero() {
185                    return Err(CoreError::ValueError(
186                        ErrorContext::new(format!(
187                            "{name} is not positive definite: Cholesky failed at diagonal element [{i},{i}] with value {sum}"
188                        ))
189                        .with_location(ErrorLocation::new(file!(), line!())),
190                    ));
191                }
192                l[[i, j]] = sum.sqrt();
193            } else {
194                if l[[j, j]].is_zero() {
195                    return Err(CoreError::ValueError(
196                        ErrorContext::new(format!(
197                            "{name} is not positive definite: zero diagonal in Cholesky at [{j},{j}]"
198                        ))
199                        .with_location(ErrorLocation::new(file!(), line!())),
200                    ));
201                }
202                l[[i, j]] = sum / l[[j, j]];
203            }
204        }
205    }
206    Ok(())
207}
208
209/// Assert that a 2-D matrix is (row-)stochastic: each row sums to 1 within tolerance,
210/// and all elements are non-negative.
211pub fn assert_stochastic<F>(matrix: &ArrayView2<F>, name: &str, tolerance: F) -> CoreResult<()>
212where
213    F: Float + Display + std::iter::Sum,
214{
215    let shape = matrix.shape();
216    // Check non-negative
217    for i in 0..shape[0] {
218        for j in 0..shape[1] {
219            if matrix[[i, j]] < F::zero() {
220                return Err(CoreError::ValueError(
221                    ErrorContext::new(format!(
222                        "{name} has negative entry {val} at [{i},{j}]; not stochastic",
223                        val = matrix[[i, j]]
224                    ))
225                    .with_location(ErrorLocation::new(file!(), line!())),
226                ));
227            }
228        }
229    }
230    // Check row sums
231    for (i, row) in matrix.axis_iter(Axis(0)).enumerate() {
232        let row_sum: F = row.iter().copied().sum();
233        let diff = (row_sum - F::one()).abs();
234        if diff > tolerance {
235            return Err(CoreError::ValueError(
236                ErrorContext::new(format!(
237                    "{name} row {i} sums to {row_sum}, not 1.0 (diff={diff})"
238                ))
239                .with_location(ErrorLocation::new(file!(), line!())),
240            ));
241        }
242    }
243    Ok(())
244}
245
246// ---------------------------------------------------------------------------
247// Shape assertions
248// ---------------------------------------------------------------------------
249
250/// Assert that the array has the exact expected shape.
251pub fn assert_shape<S, D>(array: &ArrayBase<S, D>, expected: &[usize], name: &str) -> CoreResult<()>
252where
253    S: ::ndarray::Data,
254    D: Dimension,
255{
256    let actual = array.shape();
257    if actual != expected {
258        return Err(CoreError::ShapeError(
259            ErrorContext::new(format!(
260                "{name} shape mismatch: expected {expected:?}, got {actual:?}"
261            ))
262            .with_location(ErrorLocation::new(file!(), line!())),
263        ));
264    }
265    Ok(())
266}
267
268// ---------------------------------------------------------------------------
269// ArrayStats & diagnose_array
270// ---------------------------------------------------------------------------
271
272/// Summary statistics for an array, useful for diagnostics.
273#[derive(Debug, Clone)]
274pub struct ArrayStats<F: Float> {
275    /// Number of elements.
276    pub count: usize,
277    /// Minimum value.
278    pub min: F,
279    /// Maximum value.
280    pub max: F,
281    /// Arithmetic mean.
282    pub mean: F,
283    /// Standard deviation (population).
284    pub std_dev: F,
285    /// Whether any element is NaN.
286    pub has_nan: bool,
287    /// Whether any element is infinite.
288    pub has_inf: bool,
289    /// Number of zero elements.
290    pub zero_count: usize,
291    /// Number of negative elements.
292    pub negative_count: usize,
293}
294
295impl<F: Float + Display> std::fmt::Display for ArrayStats<F> {
296    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        write!(
298            f,
299            "ArrayStats(n={}, min={}, max={}, mean={}, std={}, nan={}, inf={}, zeros={}, neg={})",
300            self.count,
301            self.min,
302            self.max,
303            self.mean,
304            self.std_dev,
305            self.has_nan,
306            self.has_inf,
307            self.zero_count,
308            self.negative_count,
309        )
310    }
311}
312
313/// Compute summary statistics for a flat array.
314///
315/// Handles NaN and Inf gracefully: they are counted but excluded from
316/// min/max/mean/std calculations (which use only finite elements).
317pub fn compute_array_stats<S, D, F>(array: &ArrayBase<S, D>) -> CoreResult<ArrayStats<F>>
318where
319    S: ::ndarray::Data<Elem = F>,
320    D: Dimension,
321    F: Float + Display,
322{
323    let count = array.len();
324    if count == 0 {
325        return Err(CoreError::ValueError(ErrorContext::new(
326            "Cannot compute stats on empty array",
327        )));
328    }
329
330    let mut has_nan = false;
331    let mut has_inf = false;
332    let mut zero_count: usize = 0;
333    let mut negative_count: usize = 0;
334    let mut min_val = F::infinity();
335    let mut max_val = F::neg_infinity();
336    let mut sum = F::zero();
337    let mut finite_count: usize = 0;
338
339    for &val in array.iter() {
340        if val.is_nan() {
341            has_nan = true;
342            continue;
343        }
344        if val.is_infinite() {
345            has_inf = true;
346            continue;
347        }
348        if val.is_zero() {
349            zero_count += 1;
350        }
351        if val < F::zero() {
352            negative_count += 1;
353        }
354        if val < min_val {
355            min_val = val;
356        }
357        if val > max_val {
358            max_val = val;
359        }
360        sum = sum + val;
361        finite_count += 1;
362    }
363
364    let (mean, std_dev) = if finite_count > 0 {
365        let n = num_traits::cast::<usize, F>(finite_count).unwrap_or(F::one());
366        let mean = sum / n;
367        // Second pass for variance
368        let mut var_sum = F::zero();
369        for &val in array.iter() {
370            if val.is_finite() {
371                let diff = val - mean;
372                var_sum = var_sum + diff * diff;
373            }
374        }
375        let variance = var_sum / n;
376        (mean, variance.sqrt())
377    } else {
378        (F::nan(), F::nan())
379    };
380
381    // If no finite elements, set min/max to NaN
382    if finite_count == 0 {
383        min_val = F::nan();
384        max_val = F::nan();
385    }
386
387    Ok(ArrayStats {
388        count,
389        min: min_val,
390        max: max_val,
391        mean,
392        std_dev,
393        has_nan,
394        has_inf,
395        zero_count,
396        negative_count,
397    })
398}
399
400/// Comprehensive array health check, returning a human-readable diagnostic string.
401///
402/// Reports shape, stats, and any issues found (NaN, Inf, negative values, etc.).
403pub fn diagnose_array<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> String
404where
405    S: ::ndarray::Data<Elem = F>,
406    D: Dimension,
407    F: Float + Display,
408{
409    let shape = array.shape();
410    let mut report = format!("=== Diagnostics for '{name}' ===\n");
411    report.push_str(&format!("  Shape: {shape:?}\n"));
412    report.push_str(&format!("  Total elements: {}\n", array.len()));
413
414    match compute_array_stats(array) {
415        Ok(stats) => {
416            report.push_str(&format!("  Min: {}\n", stats.min));
417            report.push_str(&format!("  Max: {}\n", stats.max));
418            report.push_str(&format!("  Mean: {}\n", stats.mean));
419            report.push_str(&format!("  Std Dev: {}\n", stats.std_dev));
420            report.push_str(&format!("  Has NaN: {}\n", stats.has_nan));
421            report.push_str(&format!("  Has Inf: {}\n", stats.has_inf));
422            report.push_str(&format!("  Zero count: {}\n", stats.zero_count));
423            report.push_str(&format!("  Negative count: {}\n", stats.negative_count));
424
425            // Issue summary
426            let mut issues = Vec::new();
427            if stats.has_nan {
428                issues.push("contains NaN values");
429            }
430            if stats.has_inf {
431                issues.push("contains Inf values");
432            }
433            if stats.count > 0 && stats.zero_count == stats.count {
434                issues.push("all elements are zero");
435            }
436
437            if issues.is_empty() {
438                report.push_str("  Issues: none\n");
439            } else {
440                report.push_str(&format!("  Issues: {}\n", issues.join(", ")));
441            }
442        }
443        Err(e) => {
444            report.push_str(&format!("  Stats error: {e}\n"));
445        }
446    }
447    report
448}
449
450// ---------------------------------------------------------------------------
451// Tests
452// ---------------------------------------------------------------------------
453
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{array, Array1, Array2};

    /// Absolute tolerance used when comparing computed statistics.
    const EPS: f64 = 1e-12;

    /// True when `a` and `b` differ by less than `EPS`.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    // -- assert_finite --

    #[test]
    fn test_assert_finite_ok() {
        let values = array![1.0, 2.0, 3.0];
        let outcome = assert_finite(&values, "a");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_finite_nan() {
        let values = array![1.0, f64::NAN, 3.0];
        let outcome = assert_finite(&values, "a");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_assert_finite_inf() {
        let values = array![1.0, f64::INFINITY, 3.0];
        let outcome = assert_finite(&values, "a");
        assert!(outcome.is_err());
    }

    // -- assert_positive --

    #[test]
    fn test_assert_positive_ok() {
        let values = array![0.1, 1.0, 100.0];
        let outcome = assert_positive(&values, "a");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_positive_zero() {
        let values = array![0.0, 1.0];
        let outcome = assert_positive(&values, "a");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_assert_positive_neg() {
        let values = array![1.0, -0.5];
        let outcome = assert_positive(&values, "a");
        assert!(outcome.is_err());
    }

    // -- assert_non_negative --

    #[test]
    fn test_assert_non_negative_ok() {
        let values = array![0.0, 1.0, 100.0];
        let outcome = assert_non_negative(&values, "a");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_non_negative_neg() {
        let values = array![0.0, -0.001];
        let outcome = assert_non_negative(&values, "a");
        assert!(outcome.is_err());
    }

    // -- assert_symmetric --

    #[test]
    fn test_assert_symmetric_ok() {
        let mat = array![[1.0, 2.0, 3.0], [2.0, 5.0, 6.0], [3.0, 6.0, 9.0]];
        let outcome = assert_symmetric(&mat.view(), "m", 1e-12);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_symmetric_fail() {
        // Off-diagonal entries differ, so the matrix is not symmetric.
        let mat = array![[1.0, 2.0], [3.0, 4.0]];
        let outcome = assert_symmetric(&mat.view(), "m", 1e-12);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_assert_symmetric_non_square() {
        let mat = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let outcome = assert_symmetric(&mat.view(), "m", 1e-12);
        assert!(outcome.is_err());
    }

    // -- assert_orthogonal --

    #[test]
    fn test_assert_orthogonal_identity() {
        let identity = Array2::<f64>::eye(3);
        let outcome = assert_orthogonal(&identity.view(), "I", 1e-10);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_orthogonal_fail() {
        let mat = array![[1.0, 2.0], [3.0, 4.0]];
        let outcome = assert_orthogonal(&mat.view(), "m", 1e-10);
        assert!(outcome.is_err());
    }

    // -- assert_positive_definite --

    #[test]
    fn test_assert_positive_definite_ok() {
        // 2x2 positive definite
        let mat = array![[4.0, 2.0], [2.0, 3.0]];
        let outcome = assert_positive_definite(&mat.view(), "m");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_positive_definite_fail() {
        // Not positive definite (negative eigenvalue)
        let mat = array![[1.0, 5.0], [5.0, 1.0]];
        let outcome = assert_positive_definite(&mat.view(), "m");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_assert_positive_definite_3x3() {
        let mat = array![
            [4.0, 12.0, -16.0],
            [12.0, 37.0, -43.0],
            [-16.0, -43.0, 98.0]
        ];
        let outcome = assert_positive_definite(&mat.view(), "m");
        assert!(outcome.is_ok());
    }

    // -- assert_stochastic --

    #[test]
    fn test_assert_stochastic_ok() {
        let mat = array![[0.2, 0.3, 0.5], [0.1, 0.8, 0.1]];
        let outcome = assert_stochastic(&mat.view(), "m", 1e-10);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_stochastic_bad_sum() {
        // Row 0 sums to 0.9.
        let mat = array![[0.2, 0.3, 0.4], [0.1, 0.8, 0.1]];
        let outcome = assert_stochastic(&mat.view(), "m", 1e-10);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_assert_stochastic_negative() {
        let mat = array![[0.5, 0.5], [-0.1, 1.1]];
        let outcome = assert_stochastic(&mat.view(), "m", 1e-10);
        assert!(outcome.is_err());
    }

    // -- assert_shape --

    #[test]
    fn test_assert_shape_ok() {
        let grid = array![[1.0, 2.0], [3.0, 4.0]];
        let outcome = assert_shape(&grid, &[2, 2], "a");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_assert_shape_mismatch() {
        let grid = array![[1.0, 2.0], [3.0, 4.0]];
        let outcome = assert_shape(&grid, &[2, 3], "a");
        assert!(outcome.is_err());
    }

    // -- compute_array_stats --

    #[test]
    fn test_array_stats_basic() {
        let values = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let summary = compute_array_stats(&values).expect("should succeed");
        assert_eq!(summary.count, 5);
        assert!(close(summary.min, 1.0));
        assert!(close(summary.max, 5.0));
        assert!(close(summary.mean, 3.0));
        assert!(!summary.has_nan);
        assert!(!summary.has_inf);
        assert_eq!(summary.zero_count, 0);
        assert_eq!(summary.negative_count, 0);
    }

    #[test]
    fn test_array_stats_with_nan() {
        let values = array![1.0, f64::NAN, 3.0];
        let summary = compute_array_stats(&values).expect("should succeed");
        assert!(summary.has_nan);
        assert!(!summary.has_inf);
        // Only the finite elements 1.0 and 3.0 feed min/max.
        assert!(close(summary.min, 1.0));
        assert!(close(summary.max, 3.0));
    }

    #[test]
    fn test_array_stats_with_inf() {
        let values = array![1.0, f64::INFINITY, -1.0];
        let summary = compute_array_stats(&values).expect("should succeed");
        assert!(summary.has_inf);
        assert_eq!(summary.negative_count, 1);
    }

    #[test]
    fn test_array_stats_empty() {
        let empty: Array1<f64> = Array1::from_vec(vec![]);
        let outcome = compute_array_stats(&empty);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_array_stats_display() {
        let values = array![1.0, 2.0, 3.0];
        let summary = compute_array_stats(&values).expect("should succeed");
        let rendered = format!("{summary}");
        assert!(rendered.contains("ArrayStats"));
        assert!(rendered.contains("n=3"));
    }

    // -- diagnose_array --

    #[test]
    fn test_diagnose_array_clean() {
        let values = array![1.0, 2.0, 3.0];
        let text = diagnose_array(&values, "test_array");
        assert!(text.contains("test_array"));
        assert!(text.contains("Issues: none"));
    }

    #[test]
    fn test_diagnose_array_with_nan() {
        let values = array![1.0, f64::NAN, 3.0];
        let text = diagnose_array(&values, "nan_array");
        assert!(text.contains("contains NaN"));
    }

    #[test]
    fn test_diagnose_array_all_zeros() {
        let values = array![0.0, 0.0, 0.0];
        let text = diagnose_array(&values, "zero_array");
        assert!(text.contains("all elements are zero"));
    }

    // -- assert_orthogonal with rotation matrix --

    #[test]
    fn test_assert_orthogonal_rotation() {
        let angle: f64 = std::f64::consts::PI / 4.0;
        let (sin, cos) = (angle.sin(), angle.cos());
        let rotation = array![[cos, -sin], [sin, cos]];
        let outcome = assert_orthogonal(&rotation.view(), "rot", 1e-10);
        assert!(outcome.is_ok());
    }

    // -- assert_positive_definite non-square --

    #[test]
    fn test_assert_positive_definite_non_square() {
        let mat = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let outcome = assert_positive_definite(&mat.view(), "m");
        assert!(outcome.is_err());
    }

    // -- stochastic identity rows --

    #[test]
    fn test_assert_stochastic_identity() {
        let identity = Array2::<f64>::eye(3);
        let outcome = assert_stochastic(&identity.view(), "I", 1e-10);
        assert!(outcome.is_ok());
    }
}