Skip to main content

sci_form/eht/
solver.rs

//! Generalized eigenproblem solver for EHT: HC = SCE.
//!
//! Uses Löwdin orthogonalization:
//! 1. Diagonalize S → eigenvalues λ, eigenvectors U
//! 2. Build S^{-1/2} = U diag(1/√λ) U^T
//! 3. Transform H' = S^{-1/2} H S^{-1/2}
//! 4. Diagonalize H' → eigenvalues E, eigenvectors C'
//! 5. Back-transform C = S^{-1/2} C'
//! 6. Sort the eigenpairs by ascending energy E
9
10use nalgebra::{DMatrix, DVector, SymmetricEigen};
11use serde::{Deserialize, Serialize};
12
13use super::params::{analyze_eht_support, EhtSupport};
14
15/// Result of an EHT calculation.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EhtResult {
18    /// Orbital energies (eigenvalues) in eV, sorted ascending.
19    pub energies: Vec<f64>,
20    /// MO coefficient matrix C (rows = AO index, cols = MO index).
21    /// Each column is one molecular orbital.
22    pub coefficients: Vec<Vec<f64>>,
23    /// Total number of valence electrons.
24    pub n_electrons: usize,
25    /// Index of the HOMO (0-based).
26    pub homo_index: usize,
27    /// Index of the LUMO (0-based).
28    pub lumo_index: usize,
29    /// HOMO energy in eV.
30    pub homo_energy: f64,
31    /// LUMO energy in eV.
32    pub lumo_energy: f64,
33    /// HOMO-LUMO gap in eV.
34    pub gap: f64,
35    /// Capability and confidence metadata for this element set.
36    pub support: EhtSupport,
37}
38
39/// Solve the generalized eigenproblem HC = SCE using Löwdin orthogonalization.
40///
41/// Returns eigenvalues (sorted ascending) and the coefficient matrix C.
42pub fn solve_generalized_eigenproblem(
43    h: &DMatrix<f64>,
44    s: &DMatrix<f64>,
45) -> (DVector<f64>, DMatrix<f64>) {
46    let n = h.nrows();
47
48    // Step 1: Diagonalize S
49    let s_eigen = SymmetricEigen::new(s.clone());
50    let s_vals = &s_eigen.eigenvalues;
51    let s_vecs = &s_eigen.eigenvectors;
52
53    // Step 2: Build S^{-1/2}
54    let mut s_inv_sqrt_diag = DMatrix::zeros(n, n);
55    for i in 0..n {
56        let val = s_vals[i];
57        if val > 1e-10 {
58            s_inv_sqrt_diag[(i, i)] = 1.0 / val.sqrt();
59        }
60    }
61    let s_inv_sqrt = s_vecs * &s_inv_sqrt_diag * s_vecs.transpose();
62
63    // Step 3: Transform Hamiltonian: H' = S^{-1/2} H S^{-1/2}
64    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
65
66    // Step 4: Diagonalize H'
67    let h_eigen = SymmetricEigen::new(h_prime);
68    let energies = h_eigen.eigenvalues.clone();
69    let c_prime = h_eigen.eigenvectors.clone();
70
71    // Step 5: Back-transform C = S^{-1/2} C'
72    let c = &s_inv_sqrt * c_prime;
73
74    // Sort by energy (ascending)
75    let mut indices: Vec<usize> = (0..n).collect();
76    indices.sort_by(|&a, &b| energies[a].partial_cmp(&energies[b]).unwrap());
77
78    let mut sorted_energies = DVector::zeros(n);
79    let mut sorted_c = DMatrix::zeros(n, n);
80    for (new_idx, &old_idx) in indices.iter().enumerate() {
81        sorted_energies[new_idx] = energies[old_idx];
82        for row in 0..n {
83            sorted_c[(row, new_idx)] = c[(row, old_idx)];
84        }
85    }
86
87    (sorted_energies, sorted_c)
88}
89
/// Count valence electrons for a set of atomic numbers.
///
/// Main-group atoms contribute their `ns` + `np` electrons: the group number
/// for groups 1-2, and the group number minus 10 for groups 13-18 (He = 2).
/// Transition metals contribute `ns` + `(n-1)d` electrons, which equals their
/// group number (so Hf, Z = 72, counts as group 4, not group 3).
///
/// Atomic numbers not listed (La, Ac, lanthanides, actinides, superheavy
/// elements, ...) contribute 0. `solve_eht` rejects elements without EHT
/// parameters before calling this.
fn count_valence_electrons(elements: &[u8]) -> usize {
    elements
        .iter()
        .map(|&z| match z {
            1 | 3 | 11 | 19 | 37 | 55 | 87 => 1, // H + group 1: Li, Na, K, Rb, Cs, Fr
            4 | 12 | 20 | 38 | 56 | 88 => 2,     // group 2: Be, Mg, Ca, Sr, Ba, Ra
            2 => 2,                              // He
            21 | 39 => 3,                        // group 3: Sc, Y
            22 | 40 | 72 => 4,                   // group 4: Ti, Zr, Hf
            23 | 41 | 73 => 5,                   // group 5: V, Nb, Ta
            24 | 42 | 74 => 6,                   // group 6: Cr, Mo, W
            25 | 43 | 75 => 7,                   // group 7: Mn, Tc, Re
            26 | 44 | 76 => 8,                   // group 8: Fe, Ru, Os
            27 | 45 | 77 => 9,                   // group 9: Co, Rh, Ir
            28 | 46 | 78 => 10,                  // group 10: Ni, Pd, Pt
            29 | 47 | 79 => 11,                  // group 11: Cu, Ag, Au
            30 | 48 | 80 => 12,                  // group 12: Zn, Cd, Hg
            5 | 13 | 31 | 49 | 81 => 3,          // group 13: B, Al, Ga, In, Tl
            6 | 14 | 32 | 50 | 82 => 4,          // group 14: C, Si, Ge, Sn, Pb
            7 | 15 | 33 | 51 | 83 => 5,          // group 15: N, P, As, Sb, Bi
            8 | 16 | 34 | 52 | 84 => 6,          // group 16: O, S, Se, Te, Po
            9 | 17 | 35 | 53 | 85 => 7,          // group 17: F, Cl, Br, I, At
            10 | 18 | 36 | 54 | 86 => 8,         // group 18: Ne, Ar, Kr, Xe, Rn
            _ => 0,
        })
        .sum()
}
121
122/// Run the full EHT calculation pipeline.
123///
124/// - `elements`: atomic numbers
125/// - `positions`: Cartesian coordinates in Ångström
126/// - `k`: Wolfsberg-Helmholtz constant (None = 1.75)
127pub fn solve_eht(
128    elements: &[u8],
129    positions: &[[f64; 3]],
130    k: Option<f64>,
131) -> Result<EhtResult, String> {
132    use super::basis::build_basis;
133    use super::hamiltonian::build_hamiltonian;
134    use super::overlap::build_overlap_matrix;
135
136    if elements.len() != positions.len() {
137        return Err("Element and position arrays must have equal length".to_string());
138    }
139
140    let support = analyze_eht_support(elements);
141    if !support.unsupported_elements.is_empty() {
142        return Err(support.warnings.join(" "));
143    }
144
145    let basis = build_basis(elements, positions);
146    if basis.is_empty() {
147        return Err("No valence orbitals found for given elements".to_string());
148    }
149
150    let s = build_overlap_matrix(&basis);
151    let h = build_hamiltonian(&basis, &s, k);
152    let (energies, c) = solve_generalized_eigenproblem(&h, &s);
153
154    let n_electrons = count_valence_electrons(elements);
155    let n_orbitals = basis.len();
156
157    // HOMO is the last occupied orbital (electrons fill in pairs)
158    let n_occupied = n_electrons / 2;
159    let homo_idx = if n_occupied > 0 { n_occupied - 1 } else { 0 };
160    let lumo_idx = if homo_idx + 1 < n_orbitals {
161        homo_idx + 1
162    } else {
163        homo_idx
164    };
165
166    let homo_energy = energies[homo_idx];
167    let lumo_energy = energies[lumo_idx];
168
169    // Convert nalgebra matrices to Vec<Vec<f64>>
170    let coefficients: Vec<Vec<f64>> = (0..n_orbitals)
171        .map(|row| (0..n_orbitals).map(|col| c[(row, col)]).collect())
172        .collect();
173
174    Ok(EhtResult {
175        energies: energies.iter().copied().collect(),
176        coefficients,
177        n_electrons,
178        homo_index: homo_idx,
179        lumo_index: lumo_idx,
180        homo_energy,
181        lumo_energy,
182        gap: lumo_energy - homo_energy,
183        support,
184    })
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// H2 near its equilibrium bond length (0.74 Å along x).
    fn h2() -> ([u8; 2], [[f64; 3]; 2]) {
        ([1, 1], [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]])
    }

    /// Water in the xy-plane with the oxygen at the origin.
    fn h2o() -> ([u8; 3], [[f64; 3]; 3]) {
        (
            [8, 1, 1],
            [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]],
        )
    }

    /// Runs the pipeline with the default Wolfsberg-Helmholtz constant and
    /// panics if the calculation fails.
    fn run_default(elements: &[u8], positions: &[[f64; 3]]) -> EhtResult {
        solve_eht(elements, positions, None).unwrap()
    }

    #[test]
    fn test_h2_two_eigenvalues() {
        let (elements, positions) = h2();
        let result = run_default(&elements, &positions);

        assert_eq!(result.energies.len(), 2);
        // The bonding combination lies below the antibonding one...
        assert!(result.energies[0] < result.energies[1]);
        // ...and holds both electrons, so it is the HOMO.
        assert_eq!(result.homo_index, 0);
        assert_eq!(result.lumo_index, 1);
        assert!(result.gap > 0.0, "H2 HOMO-LUMO gap should be positive");
    }

    #[test]
    fn test_h2_energies_sorted() {
        let (elements, positions) = h2();
        let result = run_default(&elements, &positions);

        for (lower, pair) in result.energies.windows(2).enumerate() {
            let upper = lower + 1;
            assert!(
                pair[1] >= pair[0],
                "Energies not sorted: E[{}]={} < E[{}]={}",
                upper,
                pair[1],
                lower,
                pair[0]
            );
        }
    }

    #[test]
    fn test_h2_coefficients_shape() {
        let (elements, positions) = h2();
        let coefficients = run_default(&elements, &positions).coefficients;
        assert_eq!(coefficients.len(), 2);
        assert_eq!(coefficients[0].len(), 2);
    }

    #[test]
    fn test_h2o_six_orbitals() {
        let (elements, positions) = h2o();
        let result = run_default(&elements, &positions);

        // O(2s, 2px, 2py, 2pz) + 2 × H(1s) = 6 basis functions.
        assert_eq!(result.energies.len(), 6);
        // 8 valence electrons → 4 doubly occupied orbitals → HOMO at index 3.
        assert_eq!(result.n_electrons, 8);
        assert_eq!(result.homo_index, 3);
        assert_eq!(result.lumo_index, 4);
    }

    #[test]
    fn test_h2o_gap_positive() {
        let (elements, positions) = h2o();
        let gap = run_default(&elements, &positions).gap;
        assert!(gap > 0.0, "H2O HOMO-LUMO gap = {} should be > 0", gap);
    }

    #[test]
    fn test_lowdin_preserves_orthogonality() {
        use super::super::basis::build_basis;
        use super::super::hamiltonian::build_hamiltonian;
        use super::super::overlap::build_overlap_matrix;

        let (elements, positions) = h2o();
        let basis = build_basis(&elements, &positions);
        let s = build_overlap_matrix(&basis);
        let h = build_hamiltonian(&basis, &s, None);
        let (_, c) = solve_generalized_eigenproblem(&h, &s);

        // The MOs must be S-orthonormal: C^T S C ≈ I.
        let metric = c.transpose() * &s * &c;
        let dim = metric.nrows();
        for i in 0..dim {
            for j in 0..dim {
                let target = if i == j { 1.0 } else { 0.0 };
                let actual = metric[(i, j)];
                assert!(
                    (actual - target).abs() < 1e-8,
                    "C^T S C[{},{}] = {}, expected {}",
                    i,
                    j,
                    actual,
                    target,
                );
            }
        }
    }

    #[test]
    fn test_error_mismatched_arrays() {
        // Two elements but only one position.
        let outcome = solve_eht(&[1u8, 1], &[[0.0, 0.0, 0.0]], None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_valence_electron_count() {
        let molecules: [(&[u8], usize); 4] = [
            (&[1, 1], 2),          // H2
            (&[8, 1, 1], 8),       // H2O
            (&[6, 1, 1, 1, 1], 8), // CH4
            (&[7, 1, 1, 1], 8),    // NH3
        ];
        for &(elements, expected) in molecules.iter() {
            assert_eq!(count_valence_electrons(elements), expected);
        }
    }

    #[test]
    fn test_valence_electron_count_transition_metals() {
        // (Z, expected valence electrons = group number)
        let cases: [(u8, usize); 18] = [
            // 3d series
            (21, 3),  // Sc
            (22, 4),  // Ti
            (26, 8),  // Fe
            (28, 10), // Ni
            (29, 11), // Cu
            (30, 12), // Zn
            // 4d series
            (39, 3),  // Y
            (46, 10), // Pd
            (47, 11), // Ag
            (48, 12), // Cd
            // 5d series (Hf = 72 is group 4, NOT group 3)
            (72, 4),  // Hf
            (73, 5),  // Ta
            (74, 6),  // W
            (76, 8),  // Os
            (77, 9),  // Ir
            (78, 10), // Pt
            (79, 11), // Au
            (80, 12), // Hg
        ];
        for &(z, expected) in cases.iter() {
            assert_eq!(count_valence_electrons(&[z]), expected, "Z = {}", z);
        }
    }

    #[test]
    fn test_cisplatin_has_even_electron_count() {
        // Pt(10) + 2×Cl(7) + 2×N(5) + 6×H(1) = 40
        let cisplatin = [78u8, 17, 17, 7, 7, 1, 1, 1, 1, 1, 1];
        assert_eq!(count_valence_electrons(&cisplatin), 40);
    }

    #[test]
    fn test_transition_metal_support_metadata() {
        let support = run_default(&[26u8], &[[0.0, 0.0, 0.0]]).support;
        assert!(support.has_transition_metals);
        assert_eq!(support.provisional_elements, vec![26]);
        assert!(!support.warnings.is_empty());
    }

    #[test]
    fn test_unsupported_element_reports_capability_error() {
        let message = solve_eht(&[118u8], &[[0.0, 0.0, 0.0]], None).unwrap_err();
        assert!(message.contains("No EHT parameters are available"));
    }
}