1use nalgebra::{DMatrix, DVector, SymmetricEigen};
11use serde::{Deserialize, Serialize};
12
13use super::params::{analyze_eht_support, EhtSupport};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EhtResult {
18 pub energies: Vec<f64>,
20 pub coefficients: Vec<Vec<f64>>,
23 pub n_electrons: usize,
25 pub homo_index: usize,
27 pub lumo_index: usize,
29 pub homo_energy: f64,
31 pub lumo_energy: f64,
33 pub gap: f64,
35 pub support: EhtSupport,
37}
38
39pub fn solve_generalized_eigenproblem(
43 h: &DMatrix<f64>,
44 s: &DMatrix<f64>,
45) -> (DVector<f64>, DMatrix<f64>) {
46 let n = h.nrows();
47
48 let s_eigen = SymmetricEigen::new(s.clone());
50 let s_vals = &s_eigen.eigenvalues;
51 let s_vecs = &s_eigen.eigenvectors;
52
53 let mut s_inv_sqrt_diag = DMatrix::zeros(n, n);
55 for i in 0..n {
56 let val = s_vals[i];
57 if val > 1e-10 {
58 s_inv_sqrt_diag[(i, i)] = 1.0 / val.sqrt();
59 }
60 }
61 let s_inv_sqrt = s_vecs * &s_inv_sqrt_diag * s_vecs.transpose();
62
63 let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
65
66 let h_eigen = SymmetricEigen::new(h_prime);
68 let energies = h_eigen.eigenvalues.clone();
69 let c_prime = h_eigen.eigenvectors.clone();
70
71 let c = &s_inv_sqrt * c_prime;
73
74 let mut indices: Vec<usize> = (0..n).collect();
76 indices.sort_by(|&a, &b| energies[a].partial_cmp(&energies[b]).unwrap());
77
78 let mut sorted_energies = DVector::zeros(n);
79 let mut sorted_c = DMatrix::zeros(n, n);
80 for (new_idx, &old_idx) in indices.iter().enumerate() {
81 sorted_energies[new_idx] = energies[old_idx];
82 for row in 0..n {
83 sorted_c[(row, new_idx)] = c[(row, old_idx)];
84 }
85 }
86
87 (sorted_energies, sorted_c)
88}
89
90fn count_valence_electrons(elements: &[u8]) -> usize {
92 elements
93 .iter()
94 .map(|&z| match z {
95 1 => 1, 5 => 3, 6 => 4, 7 => 5, 8 => 6, 9 => 7, 14 => 4, 15 => 5, 16 => 6, 17 => 7, 35 => 7, 53 => 7, 21 | 39 => 3, 22 | 40 | 72 => 4, 23 | 41 | 73 => 5, 24 | 42 | 74 => 6, 25 | 43 | 75 => 7, 26 | 44 | 76 => 8, 27 | 45 | 77 => 9, 28 | 46 | 78 => 10, 29 | 47 | 79 => 11, 30 | 48 | 80 => 12, _ => 0,
118 })
119 .sum()
120}
121
122pub fn solve_eht(
128 elements: &[u8],
129 positions: &[[f64; 3]],
130 k: Option<f64>,
131) -> Result<EhtResult, String> {
132 use super::basis::build_basis;
133 use super::hamiltonian::build_hamiltonian;
134 use super::overlap::build_overlap_matrix;
135
136 if elements.len() != positions.len() {
137 return Err("Element and position arrays must have equal length".to_string());
138 }
139
140 let support = analyze_eht_support(elements);
141 if !support.unsupported_elements.is_empty() {
142 return Err(support.warnings.join(" "));
143 }
144
145 let basis = build_basis(elements, positions);
146 if basis.is_empty() {
147 return Err("No valence orbitals found for given elements".to_string());
148 }
149
150 let s = build_overlap_matrix(&basis);
151 let h = build_hamiltonian(&basis, &s, k);
152 let (energies, c) = solve_generalized_eigenproblem(&h, &s);
153
154 let n_electrons = count_valence_electrons(elements);
155 let n_orbitals = basis.len();
156
157 let n_occupied = n_electrons.div_ceil(2); let homo_idx = if n_occupied > 0 && n_occupied <= n_orbitals {
160 n_occupied - 1
161 } else if n_orbitals > 0 {
162 0
163 } else {
164 return Err("No orbitals in EHT basis".to_string());
165 };
166 let lumo_idx = if n_occupied < n_orbitals {
167 n_occupied
168 } else {
169 homo_idx
170 };
171
172 let homo_energy = energies[homo_idx];
173 let lumo_energy = energies[lumo_idx];
174
175 let coefficients: Vec<Vec<f64>> = (0..n_orbitals)
177 .map(|row| (0..n_orbitals).map(|col| c[(row, col)]).collect())
178 .collect();
179
180 Ok(EhtResult {
181 energies: energies.iter().copied().collect(),
182 coefficients,
183 n_electrons,
184 homo_index: homo_idx,
185 lumo_index: lumo_idx,
186 homo_energy,
187 lumo_energy,
188 gap: lumo_energy - homo_energy,
189 support,
190 })
191}
192
193#[cfg(test)]
194mod tests {
195 use super::*;
196
197 #[test]
198 fn test_h2_two_eigenvalues() {
199 let elements = [1u8, 1];
200 let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
201 let result = solve_eht(&elements, &positions, None).unwrap();
202 assert_eq!(result.energies.len(), 2);
203 assert!(result.energies[0] < result.energies[1]);
205 assert_eq!(result.homo_index, 0);
207 assert_eq!(result.lumo_index, 1);
208 assert!(result.gap > 0.0, "H2 HOMO-LUMO gap should be positive");
209 }
210
211 #[test]
212 fn test_h2_energies_sorted() {
213 let elements = [1u8, 1];
214 let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
215 let result = solve_eht(&elements, &positions, None).unwrap();
216 for i in 1..result.energies.len() {
217 assert!(
218 result.energies[i] >= result.energies[i - 1],
219 "Energies not sorted: E[{}]={} < E[{}]={}",
220 i,
221 result.energies[i],
222 i - 1,
223 result.energies[i - 1]
224 );
225 }
226 }
227
228 #[test]
229 fn test_h2_coefficients_shape() {
230 let elements = [1u8, 1];
231 let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
232 let result = solve_eht(&elements, &positions, None).unwrap();
233 assert_eq!(result.coefficients.len(), 2);
234 assert_eq!(result.coefficients[0].len(), 2);
235 }
236
237 #[test]
238 fn test_h2o_six_orbitals() {
239 let elements = [8u8, 1, 1];
240 let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
241 let result = solve_eht(&elements, &positions, None).unwrap();
242 assert_eq!(result.energies.len(), 6);
244 assert_eq!(result.n_electrons, 8);
246 assert_eq!(result.homo_index, 3);
247 assert_eq!(result.lumo_index, 4);
248 }
249
250 #[test]
251 fn test_h2o_gap_positive() {
252 let elements = [8u8, 1, 1];
253 let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
254 let result = solve_eht(&elements, &positions, None).unwrap();
255 assert!(
256 result.gap > 0.0,
257 "H2O HOMO-LUMO gap = {} should be > 0",
258 result.gap
259 );
260 }
261
262 #[test]
263 fn test_lowdin_preserves_orthogonality() {
264 use super::super::basis::build_basis;
266 use super::super::hamiltonian::build_hamiltonian;
267 use super::super::overlap::build_overlap_matrix;
268
269 let elements = [8u8, 1, 1];
270 let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
271 let basis = build_basis(&elements, &positions);
272 let s = build_overlap_matrix(&basis);
273 let h = build_hamiltonian(&basis, &s, None);
274 let (_, c) = solve_generalized_eigenproblem(&h, &s);
275
276 let ct_s_c = c.transpose() * &s * &c;
278 let n = ct_s_c.nrows();
279 for i in 0..n {
280 for j in 0..n {
281 let expected = if i == j { 1.0 } else { 0.0 };
282 assert!(
283 (ct_s_c[(i, j)] - expected).abs() < 1e-8,
284 "C^T S C[{},{}] = {}, expected {}",
285 i,
286 j,
287 ct_s_c[(i, j)],
288 expected,
289 );
290 }
291 }
292 }
293
294 #[test]
295 fn test_error_mismatched_arrays() {
296 let elements = [1u8, 1];
297 let positions = [[0.0, 0.0, 0.0]]; assert!(solve_eht(&elements, &positions, None).is_err());
299 }
300
301 #[test]
302 fn test_valence_electron_count() {
303 assert_eq!(count_valence_electrons(&[1, 1]), 2); assert_eq!(count_valence_electrons(&[8, 1, 1]), 8); assert_eq!(count_valence_electrons(&[6, 1, 1, 1, 1]), 8); assert_eq!(count_valence_electrons(&[7, 1, 1, 1]), 8); }
308
309 #[test]
310 fn test_valence_electron_count_transition_metals() {
311 assert_eq!(count_valence_electrons(&[21]), 3); assert_eq!(count_valence_electrons(&[22]), 4); assert_eq!(count_valence_electrons(&[26]), 8); assert_eq!(count_valence_electrons(&[28]), 10); assert_eq!(count_valence_electrons(&[29]), 11); assert_eq!(count_valence_electrons(&[30]), 12); assert_eq!(count_valence_electrons(&[39]), 3); assert_eq!(count_valence_electrons(&[46]), 10); assert_eq!(count_valence_electrons(&[47]), 11); assert_eq!(count_valence_electrons(&[48]), 12); assert_eq!(count_valence_electrons(&[72]), 4); assert_eq!(count_valence_electrons(&[73]), 5); assert_eq!(count_valence_electrons(&[74]), 6); assert_eq!(count_valence_electrons(&[76]), 8); assert_eq!(count_valence_electrons(&[77]), 9); assert_eq!(count_valence_electrons(&[78]), 10); assert_eq!(count_valence_electrons(&[79]), 11); assert_eq!(count_valence_electrons(&[80]), 12); }
333
334 #[test]
335 fn test_cisplatin_has_even_electron_count() {
336 let elements = [78u8, 17, 17, 7, 7, 1, 1, 1, 1, 1, 1];
338 assert_eq!(count_valence_electrons(&elements), 40);
339 }
340
341 #[test]
342 fn test_transition_metal_support_metadata() {
343 let elements = [26u8];
344 let positions = [[0.0, 0.0, 0.0]];
345 let result = solve_eht(&elements, &positions, None).unwrap();
346 assert!(result.support.has_transition_metals);
347 assert_eq!(result.support.provisional_elements, vec![26]);
348 assert!(!result.support.warnings.is_empty());
349 }
350
351 #[test]
352 fn test_unsupported_element_reports_capability_error() {
353 let elements = [118u8];
354 let positions = [[0.0, 0.0, 0.0]];
355 let error = solve_eht(&elements, &positions, None).unwrap_err();
356 assert!(error.contains("No EHT parameters are available"));
357 }
358}