Skip to main content

sci_form/eht/
solver.rs

1//! Generalized eigenproblem solver for EHT: HC = SCE.
2//!
3//! Uses Löwdin orthogonalization:
4//! 1. Diagonalize S → eigenvalues λ, eigenvectors U
5//! 2. Build S^{-1/2} = U diag(1/√λ) U^T
6//! 3. Transform H' = S^{-1/2} H S^{-1/2}
7//! 4. Diagonalize H' → eigenvalues E, eigenvectors C'
8//! 5. Back-transform C = S^{-1/2} C'
9
10use nalgebra::{DMatrix, DVector, SymmetricEigen};
11use serde::{Deserialize, Serialize};
12
13use super::params::{analyze_eht_support, EhtSupport};
14
15/// Result of an EHT calculation.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EhtResult {
18    /// Orbital energies (eigenvalues) in eV, sorted ascending.
19    pub energies: Vec<f64>,
20    /// MO coefficient matrix C (rows = AO index, cols = MO index).
21    /// Each column is one molecular orbital.
22    pub coefficients: Vec<Vec<f64>>,
23    /// Total number of valence electrons.
24    pub n_electrons: usize,
25    /// Index of the HOMO (0-based).
26    pub homo_index: usize,
27    /// Index of the LUMO (0-based).
28    pub lumo_index: usize,
29    /// HOMO energy in eV.
30    pub homo_energy: f64,
31    /// LUMO energy in eV.
32    pub lumo_energy: f64,
33    /// HOMO-LUMO gap in eV.
34    pub gap: f64,
35    /// Capability and confidence metadata for this element set.
36    pub support: EhtSupport,
37}
38
39/// Solve the generalized eigenproblem HC = SCE using Löwdin orthogonalization.
40///
41/// Returns eigenvalues (sorted ascending) and the coefficient matrix C.
42pub fn solve_generalized_eigenproblem(
43    h: &DMatrix<f64>,
44    s: &DMatrix<f64>,
45) -> (DVector<f64>, DMatrix<f64>) {
46    let n = h.nrows();
47
48    // Step 1: Diagonalize S
49    let s_eigen = SymmetricEigen::new(s.clone());
50    let s_vals = &s_eigen.eigenvalues;
51    let s_vecs = &s_eigen.eigenvectors;
52
53    // Step 2: Build S^{-1/2}
54    let mut s_inv_sqrt_diag = DMatrix::zeros(n, n);
55    for i in 0..n {
56        let val = s_vals[i];
57        if val > 1e-10 {
58            s_inv_sqrt_diag[(i, i)] = 1.0 / val.sqrt();
59        }
60    }
61    let s_inv_sqrt = s_vecs * &s_inv_sqrt_diag * s_vecs.transpose();
62
63    // Step 3: Transform Hamiltonian: H' = S^{-1/2} H S^{-1/2}
64    let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
65
66    // Step 4: Diagonalize H'
67    let h_eigen = SymmetricEigen::new(h_prime);
68    let energies = h_eigen.eigenvalues.clone();
69    let c_prime = h_eigen.eigenvectors.clone();
70
71    // Step 5: Back-transform C = S^{-1/2} C'
72    let c = &s_inv_sqrt * c_prime;
73
74    // Sort by energy (ascending)
75    let mut indices: Vec<usize> = (0..n).collect();
76    indices.sort_by(|&a, &b| energies[a].partial_cmp(&energies[b]).unwrap());
77
78    let mut sorted_energies = DVector::zeros(n);
79    let mut sorted_c = DMatrix::zeros(n, n);
80    for (new_idx, &old_idx) in indices.iter().enumerate() {
81        sorted_energies[new_idx] = energies[old_idx];
82        for row in 0..n {
83            sorted_c[(row, new_idx)] = c[(row, old_idx)];
84        }
85    }
86
87    (sorted_energies, sorted_c)
88}
89
/// Total valence-electron count for a list of atomic numbers.
///
/// Supported main-group elements contribute their s+p valence electrons.
/// Transition metals (3d, 4d, and Hf–Hg in the 5d row) contribute their group
/// number, i.e. d + s electrons. Any other atomic number contributes nothing.
fn count_valence_electrons(elements: &[u8]) -> usize {
    /// Valence electrons of a single atom with atomic number `z`.
    fn valence_of(z: u8) -> usize {
        let z = usize::from(z);
        match z {
            1 => 1,            // H
            5..=9 => z - 2,    // B, C, N, O, F
            14..=17 => z - 10, // Si, P, S, Cl
            21..=30 => z - 18, // 3d: Sc (group 3) through Zn (group 12)
            35 | 53 => 7,      // Br, I
            39..=48 => z - 36, // 4d: Y (group 3) through Cd (group 12)
            72..=80 => z - 68, // 5d: Hf (group 4) through Hg (group 12)
            _ => 0,
        }
    }

    elements
        .iter()
        .fold(0, |total, &z| total + valence_of(z))
}
121
122/// Run the full EHT calculation pipeline.
123///
124/// - `elements`: atomic numbers
125/// - `positions`: Cartesian coordinates in Ångström
126/// - `k`: Wolfsberg-Helmholtz constant (None = 1.75)
127pub fn solve_eht(
128    elements: &[u8],
129    positions: &[[f64; 3]],
130    k: Option<f64>,
131) -> Result<EhtResult, String> {
132    use super::basis::build_basis;
133    use super::hamiltonian::build_hamiltonian;
134    use super::overlap::build_overlap_matrix;
135
136    if elements.len() != positions.len() {
137        return Err("Element and position arrays must have equal length".to_string());
138    }
139
140    let support = analyze_eht_support(elements);
141    if !support.unsupported_elements.is_empty() {
142        return Err(support.warnings.join(" "));
143    }
144
145    let basis = build_basis(elements, positions);
146    if basis.is_empty() {
147        return Err("No valence orbitals found for given elements".to_string());
148    }
149
150    let s = build_overlap_matrix(&basis);
151    let h = build_hamiltonian(&basis, &s, k);
152    let (energies, c) = solve_generalized_eigenproblem(&h, &s);
153
154    let n_electrons = count_valence_electrons(elements);
155    let n_orbitals = basis.len();
156
157    // HOMO is the last occupied orbital (include SOMO for odd electrons)
158    let n_occupied = n_electrons.div_ceil(2); // ceil division for odd-electron systems
159    let homo_idx = if n_occupied > 0 && n_occupied <= n_orbitals {
160        n_occupied - 1
161    } else if n_orbitals > 0 {
162        0
163    } else {
164        return Err("No orbitals in EHT basis".to_string());
165    };
166    let lumo_idx = if n_occupied < n_orbitals {
167        n_occupied
168    } else {
169        homo_idx
170    };
171
172    let homo_energy = energies[homo_idx];
173    let lumo_energy = energies[lumo_idx];
174
175    // Convert nalgebra matrices to Vec<Vec<f64>>
176    let coefficients: Vec<Vec<f64>> = (0..n_orbitals)
177        .map(|row| (0..n_orbitals).map(|col| c[(row, col)]).collect())
178        .collect();
179
180    Ok(EhtResult {
181        energies: energies.iter().copied().collect(),
182        coefficients,
183        n_electrons,
184        homo_index: homo_idx,
185        lumo_index: lumo_idx,
186        homo_energy,
187        lumo_energy,
188        gap: lumo_energy - homo_energy,
189        support,
190    })
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// H₂ with a 0.74 Å bond along the x axis.
    fn h2() -> ([u8; 2], [[f64; 3]; 2]) {
        ([1, 1], [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]])
    }

    /// Water in the xy plane with oxygen at the origin.
    fn water() -> ([u8; 3], [[f64; 3]; 3]) {
        (
            [8, 1, 1],
            [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]],
        )
    }

    /// Run `solve_eht` with the default Wolfsberg-Helmholtz constant.
    fn run_default_k(elements: &[u8], positions: &[[f64; 3]]) -> EhtResult {
        solve_eht(elements, positions, None).unwrap()
    }

    #[test]
    fn test_h2_two_eigenvalues() {
        let (elements, positions) = h2();
        let result = run_default_k(&elements, &positions);
        assert_eq!(result.energies.len(), 2);
        // The bonding combination lies below the antibonding one.
        assert!(result.energies[0] < result.energies[1]);
        // Two electrons fill only the bonding orbital (index 0); the LUMO is index 1.
        assert_eq!(result.homo_index, 0);
        assert_eq!(result.lumo_index, 1);
        assert!(result.gap > 0.0, "H2 HOMO-LUMO gap should be positive");
    }

    #[test]
    fn test_h2_energies_sorted() {
        let (elements, positions) = h2();
        let result = run_default_k(&elements, &positions);
        for (lower, pair) in result.energies.windows(2).enumerate() {
            let upper = lower + 1;
            assert!(
                pair[1] >= pair[0],
                "Energies not sorted: E[{}]={} < E[{}]={}",
                upper,
                pair[1],
                lower,
                pair[0]
            );
        }
    }

    #[test]
    fn test_h2_coefficients_shape() {
        let (elements, positions) = h2();
        let coefficients = run_default_k(&elements, &positions).coefficients;
        assert_eq!(coefficients.len(), 2);
        assert_eq!(coefficients[0].len(), 2);
    }

    #[test]
    fn test_h2o_six_orbitals() {
        let (elements, positions) = water();
        let result = run_default_k(&elements, &positions);
        // Basis: O 2s + three O 2p + one 1s per hydrogen = 6 functions.
        assert_eq!(result.energies.len(), 6);
        // 8 valence electrons fill 4 orbitals, so the HOMO is 3 and the LUMO is 4.
        assert_eq!(result.n_electrons, 8);
        assert_eq!(result.homo_index, 3);
        assert_eq!(result.lumo_index, 4);
    }

    #[test]
    fn test_h2o_gap_positive() {
        let (elements, positions) = water();
        let gap = run_default_k(&elements, &positions).gap;
        assert!(gap > 0.0, "H2O HOMO-LUMO gap = {} should be > 0", gap);
    }

    #[test]
    fn test_lowdin_preserves_orthogonality() {
        // Löwdin-orthogonalized MOs must be S-orthonormal: C^T S C = I.
        use super::super::basis::build_basis;
        use super::super::hamiltonian::build_hamiltonian;
        use super::super::overlap::build_overlap_matrix;

        let (elements, positions) = water();
        let basis = build_basis(&elements, &positions);
        let s = build_overlap_matrix(&basis);
        let h = build_hamiltonian(&basis, &s, None);
        let (_, c) = solve_generalized_eigenproblem(&h, &s);

        let metric = c.transpose() * &s * &c;
        let dim = metric.nrows();
        for (i, j) in (0..dim).flat_map(|i| (0..dim).map(move |j| (i, j))) {
            let expected = if i == j { 1.0 } else { 0.0 };
            let actual = metric[(i, j)];
            assert!(
                (actual - expected).abs() < 1e-8,
                "C^T S C[{},{}] = {}, expected {}",
                i,
                j,
                actual,
                expected,
            );
        }
    }

    #[test]
    fn test_error_mismatched_arrays() {
        // Two elements but only one position.
        let outcome = solve_eht(&[1u8, 1], &[[0.0, 0.0, 0.0]], None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_valence_electron_count() {
        let cases: [(&[u8], usize); 4] = [
            (&[1, 1], 2),          // H2
            (&[8, 1, 1], 8),       // H2O
            (&[6, 1, 1, 1, 1], 8), // CH4
            (&[7, 1, 1, 1], 8),    // NH3
        ];
        for (molecule, expected) in cases {
            assert_eq!(count_valence_electrons(molecule), expected);
        }
    }

    #[test]
    fn test_valence_electron_count_transition_metals() {
        // (atomic number, expected valence electrons = group number)
        let cases: [(u8, usize); 18] = [
            // 3d series
            (21, 3),
            (22, 4),
            (26, 8),
            (28, 10),
            (29, 11),
            (30, 12),
            // 4d series
            (39, 3),
            (46, 10),
            (47, 11),
            (48, 12),
            // 5d series (Hf=72 is group 4, NOT group 3)
            (72, 4),
            (73, 5),
            (74, 6),
            (76, 8),
            (77, 9),
            (78, 10),
            (79, 11),
            (80, 12),
        ];
        for (z, expected) in cases {
            assert_eq!(count_valence_electrons(&[z]), expected);
        }
    }

    #[test]
    fn test_cisplatin_has_even_electron_count() {
        // Pt(10) + 2×Cl(7) + 2×N(5) + 6×H(1) = 40
        let cisplatin = [78u8, 17, 17, 7, 7, 1, 1, 1, 1, 1, 1];
        assert_eq!(count_valence_electrons(&cisplatin), 40);
    }

    #[test]
    fn test_transition_metal_support_metadata() {
        // A lone iron atom is accepted, but reported as provisional with warnings.
        let support = run_default_k(&[26], &[[0.0, 0.0, 0.0]]).support;
        assert!(support.has_transition_metals);
        assert_eq!(support.provisional_elements, vec![26]);
        assert!(!support.warnings.is_empty());
    }

    #[test]
    fn test_unsupported_element_reports_capability_error() {
        // Z = 118 is expected to be rejected with a capability error.
        let message = solve_eht(&[118u8], &[[0.0, 0.0, 0.0]], None).unwrap_err();
        assert!(message.contains("No EHT parameters are available"));
    }
}