Skip to main content

sci_form/eht/
solver.rs

1//! Generalized eigenproblem solver for EHT: HC = SCE.
2//!
3//! Uses Löwdin orthogonalization:
4//! 1. Diagonalize S → eigenvalues λ, eigenvectors U
5//! 2. Build S^{-1/2} = U diag(1/√λ) U^T
6//! 3. Transform H' = S^{-1/2} H S^{-1/2}
7//! 4. Diagonalize H' → eigenvalues E, eigenvectors C'
8//! 5. Back-transform C = S^{-1/2} C'
9
10use nalgebra::{DMatrix, DVector, SymmetricEigen};
11use serde::{Deserialize, Serialize};
12
/// Result of an EHT calculation, as returned by [`solve_eht`].
///
/// Orbital indices (`homo_index`, `lumo_index`) index into `energies` and into
/// the second (MO) index of `coefficients`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EhtResult {
    /// Orbital energies (eigenvalues) in eV, sorted ascending.
    pub energies: Vec<f64>,
    /// MO coefficient matrix C, stored as `coefficients[ao][mo]`
    /// (rows = AO index, cols = MO index).
    /// Each column is one molecular orbital: MO `j` consists of
    /// `coefficients[i][j]` for every AO index `i`.
    pub coefficients: Vec<Vec<f64>>,
    /// Total number of valence electrons. Elements without a valence count
    /// defined in this module contribute zero.
    pub n_electrons: usize,
    /// Index of the HOMO (0-based) into `energies`.
    /// Reported as 0 when there are no valence electrons.
    pub homo_index: usize,
    /// Index of the LUMO (0-based) into `energies`.
    /// Equal to `homo_index` when there is no orbital above the HOMO.
    pub lumo_index: usize,
    /// HOMO energy in eV (`energies[homo_index]`).
    pub homo_energy: f64,
    /// LUMO energy in eV (`energies[lumo_index]`).
    pub lumo_energy: f64,
    /// HOMO-LUMO gap in eV, computed as `lumo_energy - homo_energy`
    /// (0 when `lumo_index == homo_index`).
    pub gap: f64,
}
34
35/// Solve the generalized eigenproblem HC = SCE using Löwdin orthogonalization.
36///
37/// Returns eigenvalues (sorted ascending) and the coefficient matrix C.
38pub fn solve_generalized_eigenproblem(
39    h: &DMatrix<f64>,
40    s: &DMatrix<f64>,
41) -> (DVector<f64>, DMatrix<f64>) {
42    let n = h.nrows();
43
44    // Step 1: Diagonalize S
45    let s_eigen = SymmetricEigen::new(s.clone());
46    let s_vals = &s_eigen.eigenvalues;
47    let s_vecs = &s_eigen.eigenvectors;
48
49    // Step 2: Build S^{-1/2}
50    let mut s_inv_sqrt_diag = DMatrix::zeros(n, n);
51    for i in 0..n {
52        let val = s_vals[i];
53        if val > 1e-10 {
54            s_inv_sqrt_diag[(i, i)] = 1.0 / val.sqrt();
55        }
56    }
57    let s_inv_sqrt = s_vecs * &s_inv_sqrt_diag * s_vecs.transpose();
58
59    // Step 3: Transform Hamiltonian: H' = S^{-1/2} H S^{-1/2}
60    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
61
62    // Step 4: Diagonalize H'
63    let h_eigen = SymmetricEigen::new(h_prime);
64    let energies = h_eigen.eigenvalues.clone();
65    let c_prime = h_eigen.eigenvectors.clone();
66
67    // Step 5: Back-transform C = S^{-1/2} C'
68    let c = &s_inv_sqrt * c_prime;
69
70    // Sort by energy (ascending)
71    let mut indices: Vec<usize> = (0..n).collect();
72    indices.sort_by(|&a, &b| energies[a].partial_cmp(&energies[b]).unwrap());
73
74    let mut sorted_energies = DVector::zeros(n);
75    let mut sorted_c = DMatrix::zeros(n, n);
76    for (new_idx, &old_idx) in indices.iter().enumerate() {
77        sorted_energies[new_idx] = energies[old_idx];
78        for row in 0..n {
79            sorted_c[(row, new_idx)] = c[(row, old_idx)];
80        }
81    }
82
83    (sorted_energies, sorted_c)
84}
85
/// Sum the valence-electron counts of the given atoms.
///
/// Only H, B, C, N, O, F, Si, P, S, Cl, Br and I are recognized; any other
/// atomic number contributes zero electrons.
fn count_valence_electrons(elements: &[u8]) -> usize {
    // (atomic number, valence electrons)
    const VALENCE_TABLE: [(u8, usize); 12] = [
        (1, 1),  // H
        (5, 3),  // B
        (6, 4),  // C
        (7, 5),  // N
        (8, 6),  // O
        (9, 7),  // F
        (14, 4), // Si
        (15, 5), // P
        (16, 6), // S
        (17, 7), // Cl
        (35, 7), // Br
        (53, 7), // I
    ];

    let mut total = 0;
    for &atomic_number in elements {
        total += VALENCE_TABLE
            .iter()
            .find(|&&(z, _)| z == atomic_number)
            .map_or(0, |&(_, valence)| valence);
    }
    total
}
107
108/// Run the full EHT calculation pipeline.
109///
110/// - `elements`: atomic numbers
111/// - `positions`: Cartesian coordinates in Ångström
112/// - `k`: Wolfsberg-Helmholtz constant (None = 1.75)
113pub fn solve_eht(
114    elements: &[u8],
115    positions: &[[f64; 3]],
116    k: Option<f64>,
117) -> Result<EhtResult, String> {
118    use super::basis::build_basis;
119    use super::hamiltonian::build_hamiltonian;
120    use super::overlap::build_overlap_matrix;
121
122    if elements.len() != positions.len() {
123        return Err("Element and position arrays must have equal length".to_string());
124    }
125
126    let basis = build_basis(elements, positions);
127    if basis.is_empty() {
128        return Err("No valence orbitals found for given elements".to_string());
129    }
130
131    let s = build_overlap_matrix(&basis);
132    let h = build_hamiltonian(&basis, &s, k);
133    let (energies, c) = solve_generalized_eigenproblem(&h, &s);
134
135    let n_electrons = count_valence_electrons(elements);
136    let n_orbitals = basis.len();
137
138    // HOMO is the last occupied orbital (electrons fill in pairs)
139    let n_occupied = n_electrons / 2;
140    let homo_idx = if n_occupied > 0 { n_occupied - 1 } else { 0 };
141    let lumo_idx = if homo_idx + 1 < n_orbitals {
142        homo_idx + 1
143    } else {
144        homo_idx
145    };
146
147    let homo_energy = energies[homo_idx];
148    let lumo_energy = energies[lumo_idx];
149
150    // Convert nalgebra matrices to Vec<Vec<f64>>
151    let coefficients: Vec<Vec<f64>> = (0..n_orbitals)
152        .map(|row| (0..n_orbitals).map(|col| c[(row, col)]).collect())
153        .collect();
154
155    Ok(EhtResult {
156        energies: energies.iter().copied().collect(),
157        coefficients,
158        n_electrons,
159        homo_index: homo_idx,
160        lumo_index: lumo_idx,
161        homo_energy,
162        lumo_energy,
163        gap: lumo_energy - homo_energy,
164    })
165}
166
#[cfg(test)]
mod tests {
    use super::super::basis::build_basis;
    use super::super::hamiltonian::build_hamiltonian;
    use super::super::overlap::build_overlap_matrix;
    use super::*;

    /// H2 with a 0.74 Å bond along x.
    const H2_ELEMENTS: [u8; 2] = [1, 1];
    const H2_POSITIONS: [[f64; 3]; 2] = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];

    /// Water in the xy-plane with O at the origin.
    const H2O_ELEMENTS: [u8; 3] = [8, 1, 1];
    const H2O_POSITIONS: [[f64; 3]; 3] =
        [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];

    /// Run EHT with the default K, failing the test on any error.
    fn run_default(elements: &[u8], positions: &[[f64; 3]]) -> EhtResult {
        solve_eht(elements, positions, None).unwrap()
    }

    #[test]
    fn test_h2_two_eigenvalues() {
        let result = run_default(&H2_ELEMENTS, &H2_POSITIONS);
        assert_eq!(result.energies.len(), 2);
        // The bonding combination lies below the antibonding one.
        assert!(result.energies[0] < result.energies[1]);
        // Two electrons fill only the bonding orbital.
        assert_eq!(result.homo_index, 0);
        assert_eq!(result.lumo_index, 1);
        assert!(result.gap > 0.0, "H2 HOMO-LUMO gap should be positive");
    }

    #[test]
    fn test_h2_energies_sorted() {
        let result = run_default(&H2_ELEMENTS, &H2_POSITIONS);
        for (lower, pair) in result.energies.windows(2).enumerate() {
            let upper = lower + 1;
            assert!(
                pair[1] >= pair[0],
                "Energies not sorted: E[{}]={} < E[{}]={}",
                upper,
                pair[1],
                lower,
                pair[0]
            );
        }
    }

    #[test]
    fn test_h2_coefficients_shape() {
        let result = run_default(&H2_ELEMENTS, &H2_POSITIONS);
        assert_eq!(result.coefficients.len(), 2);
        assert_eq!(result.coefficients[0].len(), 2);
    }

    #[test]
    fn test_h2o_six_orbitals() {
        let result = run_default(&H2O_ELEMENTS, &H2O_POSITIONS);
        // O(2s, 2px, 2py, 2pz) + 2 × H(1s) = 6 basis functions.
        assert_eq!(result.energies.len(), 6);
        // 8 valence electrons occupy 4 orbitals: HOMO = 3, LUMO = 4.
        assert_eq!(result.n_electrons, 8);
        assert_eq!(result.homo_index, 3);
        assert_eq!(result.lumo_index, 4);
    }

    #[test]
    fn test_h2o_gap_positive() {
        let result = run_default(&H2O_ELEMENTS, &H2O_POSITIONS);
        assert!(
            result.gap > 0.0,
            "H2O HOMO-LUMO gap = {} should be > 0",
            result.gap
        );
    }

    #[test]
    fn test_lowdin_preserves_orthogonality() {
        // Löwdin-orthogonalized MOs must satisfy Cᵀ S C = I.
        let basis = build_basis(&H2O_ELEMENTS, &H2O_POSITIONS);
        let s = build_overlap_matrix(&basis);
        let h = build_hamiltonian(&basis, &s, None);
        let (_, c) = solve_generalized_eigenproblem(&h, &s);

        let ct_s_c = c.transpose() * &s * &c;
        let dim = ct_s_c.nrows();
        for (i, j) in (0..dim).flat_map(|i| (0..dim).map(move |j| (i, j))) {
            let expected = if i == j { 1.0 } else { 0.0 };
            let actual = ct_s_c[(i, j)];
            assert!(
                (actual - expected).abs() < 1e-8,
                "C^T S C[{},{}] = {}, expected {}",
                i,
                j,
                actual,
                expected,
            );
        }
    }

    #[test]
    fn test_error_mismatched_arrays() {
        // Two elements but only one position.
        let positions = [[0.0, 0.0, 0.0]];
        assert!(solve_eht(&H2_ELEMENTS, &positions, None).is_err());
    }

    #[test]
    fn test_valence_electron_count() {
        let cases: [(&[u8], usize); 4] = [
            (&[1, 1], 2),          // H2
            (&[8, 1, 1], 8),       // H2O
            (&[6, 1, 1, 1, 1], 8), // CH4
            (&[7, 1, 1, 1], 8),    // NH3
        ];
        for &(elements, expected) in &cases {
            assert_eq!(count_valence_electrons(elements), expected);
        }
    }
}