1use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
10use ::ndarray::{Array2, ArrayBase, ArrayView2, Axis, Dimension};
11use num_traits::Float;
12use std::fmt::{Debug, Display};
13
14pub fn assert_finite<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> CoreResult<()>
20where
21 S: ::ndarray::Data<Elem = F>,
22 D: Dimension,
23 F: Float + Display,
24{
25 for (idx, &val) in array.indexed_iter() {
26 if !val.is_finite() {
27 return Err(CoreError::ValueError(
28 ErrorContext::new(format!(
29 "{name} contains non-finite value {val} at index {idx:?}"
30 ))
31 .with_location(ErrorLocation::new(file!(), line!())),
32 ));
33 }
34 }
35 Ok(())
36}
37
38pub fn assert_positive<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> CoreResult<()>
40where
41 S: ::ndarray::Data<Elem = F>,
42 D: Dimension,
43 F: Float + Display,
44{
45 for (idx, &val) in array.indexed_iter() {
46 if val.partial_cmp(&F::zero()) != Some(std::cmp::Ordering::Greater) {
47 return Err(CoreError::ValueError(
48 ErrorContext::new(format!(
49 "{name} contains non-positive value {val} at index {idx:?}"
50 ))
51 .with_location(ErrorLocation::new(file!(), line!())),
52 ));
53 }
54 }
55 Ok(())
56}
57
58pub fn assert_non_negative<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> CoreResult<()>
60where
61 S: ::ndarray::Data<Elem = F>,
62 D: Dimension,
63 F: Float + Display,
64{
65 for (idx, &val) in array.indexed_iter() {
66 if val < F::zero() {
67 return Err(CoreError::ValueError(
68 ErrorContext::new(format!(
69 "{name} contains negative value {val} at index {idx:?}"
70 ))
71 .with_location(ErrorLocation::new(file!(), line!())),
72 ));
73 }
74 }
75 Ok(())
76}
77
78pub fn assert_symmetric<F>(matrix: &ArrayView2<F>, name: &str, tolerance: F) -> CoreResult<()>
86where
87 F: Float + Display,
88{
89 let shape = matrix.shape();
90 if shape[0] != shape[1] {
91 return Err(CoreError::ShapeError(
92 ErrorContext::new(format!(
93 "{name} is not square ({} x {}), cannot be symmetric",
94 shape[0], shape[1]
95 ))
96 .with_location(ErrorLocation::new(file!(), line!())),
97 ));
98 }
99 let n = shape[0];
100 for i in 0..n {
101 for j in (i + 1)..n {
102 let diff = (matrix[[i, j]] - matrix[[j, i]]).abs();
103 if diff > tolerance {
104 return Err(CoreError::ValueError(
105 ErrorContext::new(format!(
106 "{name} is not symmetric: |A[{i},{j}] - A[{j},{i}]| = {diff} > {tolerance}"
107 ))
108 .with_location(ErrorLocation::new(file!(), line!())),
109 ));
110 }
111 }
112 }
113 Ok(())
114}
115
116pub fn assert_orthogonal<F>(matrix: &ArrayView2<F>, name: &str, tolerance: F) -> CoreResult<()>
120where
121 F: Float + Display + std::ops::AddAssign + Debug,
122{
123 let shape = matrix.shape();
124 if shape[0] != shape[1] {
125 return Err(CoreError::ShapeError(
126 ErrorContext::new(format!(
127 "{name} is not square ({} x {}), cannot check orthogonality",
128 shape[0], shape[1]
129 ))
130 .with_location(ErrorLocation::new(file!(), line!())),
131 ));
132 }
133 let n = shape[0];
134
135 for i in 0..n {
137 for j in 0..n {
138 let mut dot = F::zero();
139 for k in 0..n {
140 dot += matrix[[k, i]] * matrix[[k, j]];
141 }
142 let expected = if i == j { F::one() } else { F::zero() };
143 let diff = (dot - expected).abs();
144 if diff > tolerance {
145 return Err(CoreError::ValueError(
146 ErrorContext::new(format!(
147 "{name} is not orthogonal: (A^T A)[{i},{j}] = {dot}, expected {expected} (diff={diff})"
148 ))
149 .with_location(ErrorLocation::new(file!(), line!())),
150 ));
151 }
152 }
153 }
154 Ok(())
155}
156
157pub fn assert_positive_definite<F>(matrix: &ArrayView2<F>, name: &str) -> CoreResult<()>
161where
162 F: Float + Display + Debug,
163{
164 let shape = matrix.shape();
165 if shape[0] != shape[1] {
166 return Err(CoreError::ShapeError(
167 ErrorContext::new(format!(
168 "{name} is not square ({} x {}), cannot check positive definiteness",
169 shape[0], shape[1]
170 ))
171 .with_location(ErrorLocation::new(file!(), line!())),
172 ));
173 }
174 let n = shape[0];
175 let mut l = Array2::<F>::zeros((n, n));
176
177 for i in 0..n {
178 for j in 0..=i {
179 let mut sum = matrix[[i, j]];
180 for k in 0..j {
181 sum = sum - l[[i, k]] * l[[j, k]];
182 }
183 if i == j {
184 if sum <= F::zero() {
185 return Err(CoreError::ValueError(
186 ErrorContext::new(format!(
187 "{name} is not positive definite: Cholesky failed at diagonal element [{i},{i}] with value {sum}"
188 ))
189 .with_location(ErrorLocation::new(file!(), line!())),
190 ));
191 }
192 l[[i, j]] = sum.sqrt();
193 } else {
194 if l[[j, j]].is_zero() {
195 return Err(CoreError::ValueError(
196 ErrorContext::new(format!(
197 "{name} is not positive definite: zero diagonal in Cholesky at [{j},{j}]"
198 ))
199 .with_location(ErrorLocation::new(file!(), line!())),
200 ));
201 }
202 l[[i, j]] = sum / l[[j, j]];
203 }
204 }
205 }
206 Ok(())
207}
208
209pub fn assert_stochastic<F>(matrix: &ArrayView2<F>, name: &str, tolerance: F) -> CoreResult<()>
212where
213 F: Float + Display + std::iter::Sum,
214{
215 let shape = matrix.shape();
216 for i in 0..shape[0] {
218 for j in 0..shape[1] {
219 if matrix[[i, j]] < F::zero() {
220 return Err(CoreError::ValueError(
221 ErrorContext::new(format!(
222 "{name} has negative entry {val} at [{i},{j}]; not stochastic",
223 val = matrix[[i, j]]
224 ))
225 .with_location(ErrorLocation::new(file!(), line!())),
226 ));
227 }
228 }
229 }
230 for (i, row) in matrix.axis_iter(Axis(0)).enumerate() {
232 let row_sum: F = row.iter().copied().sum();
233 let diff = (row_sum - F::one()).abs();
234 if diff > tolerance {
235 return Err(CoreError::ValueError(
236 ErrorContext::new(format!(
237 "{name} row {i} sums to {row_sum}, not 1.0 (diff={diff})"
238 ))
239 .with_location(ErrorLocation::new(file!(), line!())),
240 ));
241 }
242 }
243 Ok(())
244}
245
246pub fn assert_shape<S, D>(array: &ArrayBase<S, D>, expected: &[usize], name: &str) -> CoreResult<()>
252where
253 S: ::ndarray::Data,
254 D: Dimension,
255{
256 let actual = array.shape();
257 if actual != expected {
258 return Err(CoreError::ShapeError(
259 ErrorContext::new(format!(
260 "{name} shape mismatch: expected {expected:?}, got {actual:?}"
261 ))
262 .with_location(ErrorLocation::new(file!(), line!())),
263 ));
264 }
265 Ok(())
266}
267
/// Summary statistics for an array of floating-point values.
///
/// No trait bound is placed on `F` in the type definition. Bounds live on the `impl`
/// blocks and functions that need them.
///
/// The field descriptions below follow how [`compute_array_stats`] fills the struct.
/// Values built elsewhere may not follow these conventions.
#[derive(Debug, Clone)]
pub struct ArrayStats<F> {
    /// Total number of elements, including NaN and infinite ones.
    pub count: usize,
    /// Smallest finite element, or NaN when there are no finite elements.
    pub min: F,
    /// Largest finite element, or NaN when there are no finite elements.
    pub max: F,
    /// Mean of the finite elements, or NaN when there are none.
    pub mean: F,
    /// Population standard deviation (divides by the finite count) of the finite
    /// elements, or NaN when there are none.
    pub std_dev: F,
    /// Whether at least one element is NaN.
    pub has_nan: bool,
    /// Whether at least one element is +/- infinity.
    pub has_inf: bool,
    /// Number of finite elements equal to zero (either sign).
    pub zero_count: usize,
    /// Number of finite elements strictly below zero.
    pub negative_count: usize,
}
294
295impl<F: Float + Display> std::fmt::Display for ArrayStats<F> {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 write!(
298 f,
299 "ArrayStats(n={}, min={}, max={}, mean={}, std={}, nan={}, inf={}, zeros={}, neg={})",
300 self.count,
301 self.min,
302 self.max,
303 self.mean,
304 self.std_dev,
305 self.has_nan,
306 self.has_inf,
307 self.zero_count,
308 self.negative_count,
309 )
310 }
311}
312
313pub fn compute_array_stats<S, D, F>(array: &ArrayBase<S, D>) -> CoreResult<ArrayStats<F>>
318where
319 S: ::ndarray::Data<Elem = F>,
320 D: Dimension,
321 F: Float + Display,
322{
323 let count = array.len();
324 if count == 0 {
325 return Err(CoreError::ValueError(ErrorContext::new(
326 "Cannot compute stats on empty array",
327 )));
328 }
329
330 let mut has_nan = false;
331 let mut has_inf = false;
332 let mut zero_count: usize = 0;
333 let mut negative_count: usize = 0;
334 let mut min_val = F::infinity();
335 let mut max_val = F::neg_infinity();
336 let mut sum = F::zero();
337 let mut finite_count: usize = 0;
338
339 for &val in array.iter() {
340 if val.is_nan() {
341 has_nan = true;
342 continue;
343 }
344 if val.is_infinite() {
345 has_inf = true;
346 continue;
347 }
348 if val.is_zero() {
349 zero_count += 1;
350 }
351 if val < F::zero() {
352 negative_count += 1;
353 }
354 if val < min_val {
355 min_val = val;
356 }
357 if val > max_val {
358 max_val = val;
359 }
360 sum = sum + val;
361 finite_count += 1;
362 }
363
364 let (mean, std_dev) = if finite_count > 0 {
365 let n = num_traits::cast::<usize, F>(finite_count).unwrap_or(F::one());
366 let mean = sum / n;
367 let mut var_sum = F::zero();
369 for &val in array.iter() {
370 if val.is_finite() {
371 let diff = val - mean;
372 var_sum = var_sum + diff * diff;
373 }
374 }
375 let variance = var_sum / n;
376 (mean, variance.sqrt())
377 } else {
378 (F::nan(), F::nan())
379 };
380
381 if finite_count == 0 {
383 min_val = F::nan();
384 max_val = F::nan();
385 }
386
387 Ok(ArrayStats {
388 count,
389 min: min_val,
390 max: max_val,
391 mean,
392 std_dev,
393 has_nan,
394 has_inf,
395 zero_count,
396 negative_count,
397 })
398}
399
400pub fn diagnose_array<S, D, F>(array: &ArrayBase<S, D>, name: &str) -> String
404where
405 S: ::ndarray::Data<Elem = F>,
406 D: Dimension,
407 F: Float + Display,
408{
409 let shape = array.shape();
410 let mut report = format!("=== Diagnostics for '{name}' ===\n");
411 report.push_str(&format!(" Shape: {shape:?}\n"));
412 report.push_str(&format!(" Total elements: {}\n", array.len()));
413
414 match compute_array_stats(array) {
415 Ok(stats) => {
416 report.push_str(&format!(" Min: {}\n", stats.min));
417 report.push_str(&format!(" Max: {}\n", stats.max));
418 report.push_str(&format!(" Mean: {}\n", stats.mean));
419 report.push_str(&format!(" Std Dev: {}\n", stats.std_dev));
420 report.push_str(&format!(" Has NaN: {}\n", stats.has_nan));
421 report.push_str(&format!(" Has Inf: {}\n", stats.has_inf));
422 report.push_str(&format!(" Zero count: {}\n", stats.zero_count));
423 report.push_str(&format!(" Negative count: {}\n", stats.negative_count));
424
425 let mut issues = Vec::new();
427 if stats.has_nan {
428 issues.push("contains NaN values");
429 }
430 if stats.has_inf {
431 issues.push("contains Inf values");
432 }
433 if stats.count > 0 && stats.zero_count == stats.count {
434 issues.push("all elements are zero");
435 }
436
437 if issues.is_empty() {
438 report.push_str(" Issues: none\n");
439 } else {
440 report.push_str(&format!(" Issues: {}\n", issues.join(", ")));
441 }
442 }
443 Err(e) => {
444 report.push_str(&format!(" Stats error: {e}\n"));
445 }
446 }
447 report
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use ::ndarray::{array, Array1, Array2};

    // --- assert_finite -----------------------------------------------------

    #[test]
    fn test_assert_finite_ok() {
        let data = array![1.0, 2.0, 3.0];
        let verdict = assert_finite(&data, "a");
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_finite_nan() {
        let data = array![1.0, f64::NAN, 3.0];
        let verdict = assert_finite(&data, "a");
        assert!(verdict.is_err());
    }

    #[test]
    fn test_assert_finite_inf() {
        let data = array![1.0, f64::INFINITY, 3.0];
        let verdict = assert_finite(&data, "a");
        assert!(verdict.is_err());
    }

    // --- assert_positive ---------------------------------------------------

    #[test]
    fn test_assert_positive_ok() {
        let data = array![0.1, 1.0, 100.0];
        let verdict = assert_positive(&data, "a");
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_positive_zero() {
        let data = array![0.0, 1.0];
        let verdict = assert_positive(&data, "a");
        assert!(verdict.is_err());
    }

    #[test]
    fn test_assert_positive_neg() {
        let data = array![1.0, -0.5];
        let verdict = assert_positive(&data, "a");
        assert!(verdict.is_err());
    }

    // --- assert_non_negative -----------------------------------------------

    #[test]
    fn test_assert_non_negative_ok() {
        let data = array![0.0, 1.0, 100.0];
        let verdict = assert_non_negative(&data, "a");
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_non_negative_neg() {
        let data = array![0.0, -0.001];
        let verdict = assert_non_negative(&data, "a");
        assert!(verdict.is_err());
    }

    // --- assert_symmetric --------------------------------------------------

    #[test]
    fn test_assert_symmetric_ok() {
        let mat = array![[1.0, 2.0, 3.0], [2.0, 5.0, 6.0], [3.0, 6.0, 9.0]];
        let verdict = assert_symmetric(&mat.view(), "m", 1e-12);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_symmetric_fail() {
        let mat = array![[1.0, 2.0], [3.0, 4.0]];
        let verdict = assert_symmetric(&mat.view(), "m", 1e-12);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_assert_symmetric_non_square() {
        let mat = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let verdict = assert_symmetric(&mat.view(), "m", 1e-12);
        assert!(verdict.is_err());
    }

    // --- assert_orthogonal -------------------------------------------------

    #[test]
    fn test_assert_orthogonal_identity() {
        let mat = Array2::<f64>::eye(3);
        let verdict = assert_orthogonal(&mat.view(), "I", 1e-10);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_orthogonal_fail() {
        let mat = array![[1.0, 2.0], [3.0, 4.0]];
        let verdict = assert_orthogonal(&mat.view(), "m", 1e-10);
        assert!(verdict.is_err());
    }

    // --- assert_positive_definite ------------------------------------------

    #[test]
    fn test_assert_positive_definite_ok() {
        let mat = array![[4.0, 2.0], [2.0, 3.0]];
        let verdict = assert_positive_definite(&mat.view(), "m");
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_positive_definite_fail() {
        let mat = array![[1.0, 5.0], [5.0, 1.0]];
        let verdict = assert_positive_definite(&mat.view(), "m");
        assert!(verdict.is_err());
    }

    #[test]
    fn test_assert_positive_definite_3x3() {
        let mat = array![
            [4.0, 12.0, -16.0],
            [12.0, 37.0, -43.0],
            [-16.0, -43.0, 98.0]
        ];
        let verdict = assert_positive_definite(&mat.view(), "m");
        assert!(verdict.is_ok());
    }

    // --- assert_stochastic -------------------------------------------------

    #[test]
    fn test_assert_stochastic_ok() {
        let mat = array![[0.2, 0.3, 0.5], [0.1, 0.8, 0.1]];
        let verdict = assert_stochastic(&mat.view(), "m", 1e-10);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_stochastic_bad_sum() {
        let mat = array![[0.2, 0.3, 0.4], [0.1, 0.8, 0.1]];
        let verdict = assert_stochastic(&mat.view(), "m", 1e-10);
        assert!(verdict.is_err());
    }

    #[test]
    fn test_assert_stochastic_negative() {
        let mat = array![[0.5, 0.5], [-0.1, 1.1]];
        let verdict = assert_stochastic(&mat.view(), "m", 1e-10);
        assert!(verdict.is_err());
    }

    // --- assert_shape ------------------------------------------------------

    #[test]
    fn test_assert_shape_ok() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let verdict = assert_shape(&data, &[2, 2], "a");
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_shape_mismatch() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let verdict = assert_shape(&data, &[2, 3], "a");
        assert!(verdict.is_err());
    }

    // --- compute_array_stats -----------------------------------------------

    #[test]
    fn test_array_stats_basic() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let s = compute_array_stats(&data).expect("should succeed");
        let close = |x: f64, y: f64| (x - y).abs() < 1e-12;
        assert_eq!(s.count, 5);
        assert!(close(s.min, 1.0));
        assert!(close(s.max, 5.0));
        assert!(close(s.mean, 3.0));
        assert!(!s.has_nan);
        assert!(!s.has_inf);
        assert_eq!(s.zero_count, 0);
        assert_eq!(s.negative_count, 0);
    }

    #[test]
    fn test_array_stats_with_nan() {
        let data = array![1.0, f64::NAN, 3.0];
        let s = compute_array_stats(&data).expect("should succeed");
        let close = |x: f64, y: f64| (x - y).abs() < 1e-12;
        assert!(s.has_nan);
        assert!(!s.has_inf);
        // NaN is skipped, so the extrema come from the finite values only.
        assert!(close(s.min, 1.0));
        assert!(close(s.max, 3.0));
    }

    #[test]
    fn test_array_stats_with_inf() {
        let data = array![1.0, f64::INFINITY, -1.0];
        let s = compute_array_stats(&data).expect("should succeed");
        assert!(s.has_inf);
        assert_eq!(s.negative_count, 1);
    }

    #[test]
    fn test_array_stats_empty() {
        let data: Array1<f64> = Array1::from_vec(vec![]);
        let outcome = compute_array_stats(&data);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_array_stats_display() {
        let data = array![1.0, 2.0, 3.0];
        let s = compute_array_stats(&data).expect("should succeed");
        let rendered = format!("{s}");
        assert!(rendered.contains("ArrayStats"));
        assert!(rendered.contains("n=3"));
    }

    // --- diagnose_array ----------------------------------------------------

    #[test]
    fn test_diagnose_array_clean() {
        let data = array![1.0, 2.0, 3.0];
        let text = diagnose_array(&data, "test_array");
        assert!(text.contains("test_array"));
        assert!(text.contains("Issues: none"));
    }

    #[test]
    fn test_diagnose_array_with_nan() {
        let data = array![1.0, f64::NAN, 3.0];
        let text = diagnose_array(&data, "nan_array");
        assert!(text.contains("contains NaN"));
    }

    #[test]
    fn test_diagnose_array_all_zeros() {
        let data = array![0.0, 0.0, 0.0];
        let text = diagnose_array(&data, "zero_array");
        assert!(text.contains("all elements are zero"));
    }

    // --- additional matrix cases -------------------------------------------

    #[test]
    fn test_assert_orthogonal_rotation() {
        let theta: f64 = std::f64::consts::PI / 4.0;
        let (sin, cos) = (theta.sin(), theta.cos());
        let rotation = array![[cos, -sin], [sin, cos]];
        let verdict = assert_orthogonal(&rotation.view(), "rot", 1e-10);
        assert!(verdict.is_ok());
    }

    #[test]
    fn test_assert_positive_definite_non_square() {
        let mat = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let verdict = assert_positive_definite(&mat.view(), "m");
        assert!(verdict.is_err());
    }

    #[test]
    fn test_assert_stochastic_identity() {
        let mat = Array2::<f64>::eye(3);
        let verdict = assert_stochastic(&mat.view(), "I", 1e-10);
        assert!(verdict.is_ok());
    }
}