Skip to main content

sci_form/scf/
validation.rs

//! Validation utilities for comparing CPU and GPU results.
//!
//! Provides tools to verify that GPU-computed matrices match their
//! CPU reference implementations within acceptable tolerances.
5
6use nalgebra::DMatrix;
7
/// Summary statistics from an element-wise comparison of two matrices.
///
/// Produced by [`compare_matrices`]. All error fields are absolute
/// differences `|a - b|`, expressed in the units of the compared matrices.
#[derive(Debug, Clone, PartialEq)]
pub struct ComparisonResult {
    /// Largest absolute difference over all compared elements.
    pub max_abs_error: f64,
    /// Mean of the absolute differences.
    pub mean_abs_error: f64,
    /// Root mean square of the absolute differences.
    pub rms_error: f64,
    /// `true` exactly when `n_failures == 0`.
    pub passed: bool,
    /// Number of elements whose difference exceeded the tolerance.
    pub n_failures: usize,
    /// Total number of elements compared (`nrows * ncols`).
    pub n_elements: usize,
}
24
25/// Compare two matrices element-wise with configurable tolerance.
26pub fn compare_matrices(a: &DMatrix<f64>, b: &DMatrix<f64>, tolerance: f64) -> ComparisonResult {
27    assert_eq!(a.nrows(), b.nrows(), "Row count mismatch");
28    assert_eq!(a.ncols(), b.ncols(), "Column count mismatch");
29
30    let n = a.nrows() * a.ncols();
31    let mut max_err = 0.0f64;
32    let mut sum_err = 0.0;
33    let mut sum_sq_err = 0.0;
34    let mut n_fail = 0;
35
36    for i in 0..a.nrows() {
37        for j in 0..a.ncols() {
38            let diff = (a[(i, j)] - b[(i, j)]).abs();
39            max_err = max_err.max(diff);
40            sum_err += diff;
41            sum_sq_err += diff * diff;
42            if diff > tolerance {
43                n_fail += 1;
44            }
45        }
46    }
47
48    ComparisonResult {
49        max_abs_error: max_err,
50        mean_abs_error: sum_err / n as f64,
51        rms_error: (sum_sq_err / n as f64).sqrt(),
52        passed: n_fail == 0,
53        n_failures: n_fail,
54        n_elements: n,
55    }
56}
57
58/// Check that a matrix is symmetric within some tolerance.
59pub fn check_symmetry(m: &DMatrix<f64>, tolerance: f64) -> bool {
60    if m.nrows() != m.ncols() {
61        return false;
62    }
63    for i in 0..m.nrows() {
64        for j in (i + 1)..m.ncols() {
65            if (m[(i, j)] - m[(j, i)]).abs() > tolerance {
66                return false;
67            }
68        }
69    }
70    true
71}
72
73/// Check that a matrix is positive definite (all eigenvalues > 0).
74pub fn check_positive_definite(m: &DMatrix<f64>) -> bool {
75    let eigen = m.clone().symmetric_eigen();
76    eigen.eigenvalues.iter().all(|&e| e > -1e-10)
77}
78
79/// Verify overlap matrix properties: symmetric, positive definite, diagonal > 0.
80pub fn validate_overlap_matrix(s: &DMatrix<f64>) -> Vec<String> {
81    let mut issues = Vec::new();
82
83    if !check_symmetry(s, 1e-12) {
84        issues.push("Overlap matrix is not symmetric".to_string());
85    }
86
87    if !check_positive_definite(s) {
88        issues.push("Overlap matrix is not positive definite".to_string());
89    }
90
91    for i in 0..s.nrows() {
92        if s[(i, i)] <= 0.0 {
93            issues.push(format!("S[{},{}] = {} is not positive", i, i, s[(i, i)]));
94        }
95    }
96
97    issues
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compare_identical_matrices() {
        let a = DMatrix::identity(3, 3);
        let b = DMatrix::identity(3, 3);
        let result = compare_matrices(&a, &b, 1e-10);
        assert!(result.passed);
        assert_eq!(result.n_failures, 0);
        assert_eq!(result.n_elements, 9);
        assert_eq!(result.max_abs_error, 0.0);
        assert_eq!(result.mean_abs_error, 0.0);
        assert_eq!(result.rms_error, 0.0);
    }

    #[test]
    fn test_compare_different_matrices() {
        let a = DMatrix::identity(3, 3);
        let mut b = DMatrix::identity(3, 3);
        b[(0, 0)] = 1.1;
        let result = compare_matrices(&a, &b, 0.05);
        assert!(!result.passed);
        assert_eq!(result.n_failures, 1);
        // Exactly one element differs by ~0.1 out of nine.
        assert!((result.max_abs_error - 0.1).abs() < 1e-12);
        assert!((result.mean_abs_error - 0.1 / 9.0).abs() < 1e-12);
        assert!((result.rms_error - 0.1 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_compare_within_tolerance_passes() {
        let a = DMatrix::identity(2, 2);
        let mut b = DMatrix::identity(2, 2);
        b[(1, 0)] = 1e-3;
        let result = compare_matrices(&a, &b, 1e-2);
        assert!(result.passed);
        assert_eq!(result.n_failures, 0);
    }

    #[test]
    #[should_panic(expected = "Row count mismatch")]
    fn test_compare_shape_mismatch_panics() {
        let a = DMatrix::<f64>::identity(2, 2);
        let b = DMatrix::<f64>::identity(3, 2);
        let _ = compare_matrices(&a, &b, 1e-10);
    }

    #[test]
    fn test_symmetry_check() {
        let mut m = DMatrix::zeros(3, 3);
        m[(0, 1)] = 1.0;
        m[(1, 0)] = 1.0;
        assert!(check_symmetry(&m, 1e-10));
    }

    #[test]
    fn test_symmetry_check_rejects_asymmetric_and_non_square() {
        let mut m = DMatrix::zeros(3, 3);
        m[(0, 2)] = 1.0;
        assert!(!check_symmetry(&m, 1e-10));
        // A loose enough tolerance accepts the same matrix.
        assert!(check_symmetry(&m, 2.0));
        assert!(!check_symmetry(&DMatrix::<f64>::zeros(2, 3), 1e-10));
    }

    #[test]
    fn test_positive_definite() {
        let m = DMatrix::identity(3, 3);
        assert!(check_positive_definite(&m));
    }

    #[test]
    fn test_not_positive_definite() {
        let m = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, -1.0]);
        assert!(!check_positive_definite(&m));
    }

    #[test]
    fn test_validate_overlap_identity_has_no_issues() {
        assert!(validate_overlap_matrix(&DMatrix::identity(4, 4)).is_empty());
    }

    #[test]
    fn test_validate_overlap_reports_negative_diagonal() {
        let s = DMatrix::from_row_slice(2, 2, &[1.0, 0.0, 0.0, -1.0]);
        assert_eq!(
            validate_overlap_matrix(&s),
            vec![
                "Overlap matrix is not positive definite".to_string(),
                "S[1,1] = -1 is not positive".to_string(),
            ]
        );
    }

    #[test]
    fn test_validate_overlap_reports_asymmetry() {
        // Diagonally dominant, so positive definite whichever triangle the solver reads.
        let s = DMatrix::from_row_slice(2, 2, &[2.0, 0.5, 0.4, 2.0]);
        assert_eq!(
            validate_overlap_matrix(&s),
            vec!["Overlap matrix is not symmetric".to_string()]
        );
    }
}