1use nalgebra::{DMatrix, DVector, SymmetricEigen};
11use serde::{Deserialize, Serialize};
12
13use super::params::{analyze_eht_support, EhtSupport};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EhtResult {
18 pub energies: Vec<f64>,
20 pub coefficients: Vec<Vec<f64>>,
23 pub n_electrons: usize,
25 pub homo_index: usize,
27 pub lumo_index: usize,
29 pub homo_energy: f64,
31 pub lumo_energy: f64,
33 pub gap: f64,
35 pub support: EhtSupport,
37}
38
39pub fn solve_generalized_eigenproblem(
43 h: &DMatrix<f64>,
44 s: &DMatrix<f64>,
45) -> (DVector<f64>, DMatrix<f64>) {
46 let n = h.nrows();
47
48 let s_eigen = SymmetricEigen::new(s.clone());
50 let s_vals = &s_eigen.eigenvalues;
51 let s_vecs = &s_eigen.eigenvectors;
52
53 let mut s_inv_sqrt_diag = DMatrix::zeros(n, n);
55 for i in 0..n {
56 let val = s_vals[i];
57 if val > 1e-10 {
58 s_inv_sqrt_diag[(i, i)] = 1.0 / val.sqrt();
59 }
60 }
61 let s_inv_sqrt = s_vecs * &s_inv_sqrt_diag * s_vecs.transpose();
62
63 let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
65
66 let h_eigen = SymmetricEigen::new(h_prime);
68 let energies = h_eigen.eigenvalues.clone();
69 let c_prime = h_eigen.eigenvectors.clone();
70
71 let c = &s_inv_sqrt * c_prime;
73
74 let mut indices: Vec<usize> = (0..n).collect();
76 indices.sort_by(|&a, &b| energies[a].partial_cmp(&energies[b]).unwrap());
77
78 let mut sorted_energies = DVector::zeros(n);
79 let mut sorted_c = DMatrix::zeros(n, n);
80 for (new_idx, &old_idx) in indices.iter().enumerate() {
81 sorted_energies[new_idx] = energies[old_idx];
82 for row in 0..n {
83 sorted_c[(row, new_idx)] = c[(row, old_idx)];
84 }
85 }
86
87 (sorted_energies, sorted_c)
88}
89
/// Sums the valence-electron counts of the given atomic numbers.
///
/// Counts come from a fixed lookup table keyed by atomic number; any atomic
/// number without a table entry contributes zero electrons.
fn count_valence_electrons(elements: &[u8]) -> usize {
    let mut total = 0usize;
    for &atomic_number in elements {
        // Table grouped by electron count rather than by element.
        total += match atomic_number {
            1 => 1,
            5 | 21 | 39 => 3,
            6 | 14 | 22 | 40 | 72 => 4,
            7 | 15 | 23 | 41 | 73 => 5,
            8 | 16 | 24 | 42 | 74 => 6,
            9 | 17 | 35 | 53 | 25 | 43 | 75 => 7,
            26 | 44 | 76 => 8,
            27 | 45 | 77 => 9,
            28 | 46 | 78 => 10,
            29 | 47 | 79 => 11,
            30 | 48 | 80 => 12,
            _ => 0,
        };
    }
    total
}
121
122pub fn solve_eht(
128 elements: &[u8],
129 positions: &[[f64; 3]],
130 k: Option<f64>,
131) -> Result<EhtResult, String> {
132 use super::basis::build_basis;
133 use super::hamiltonian::build_hamiltonian;
134 use super::overlap::build_overlap_matrix;
135
136 if elements.len() != positions.len() {
137 return Err("Element and position arrays must have equal length".to_string());
138 }
139
140 let support = analyze_eht_support(elements);
141 if !support.unsupported_elements.is_empty() {
142 return Err(support.warnings.join(" "));
143 }
144
145 let basis = build_basis(elements, positions);
146 if basis.is_empty() {
147 return Err("No valence orbitals found for given elements".to_string());
148 }
149
150 let s = build_overlap_matrix(&basis);
151 let h = build_hamiltonian(&basis, &s, k);
152 let (energies, c) = solve_generalized_eigenproblem(&h, &s);
153
154 let n_electrons = count_valence_electrons(elements);
155 let n_orbitals = basis.len();
156
157 let n_occupied = n_electrons / 2;
159 let homo_idx = if n_occupied > 0 { n_occupied - 1 } else { 0 };
160 let lumo_idx = if homo_idx + 1 < n_orbitals {
161 homo_idx + 1
162 } else {
163 homo_idx
164 };
165
166 let homo_energy = energies[homo_idx];
167 let lumo_energy = energies[lumo_idx];
168
169 let coefficients: Vec<Vec<f64>> = (0..n_orbitals)
171 .map(|row| (0..n_orbitals).map(|col| c[(row, col)]).collect())
172 .collect();
173
174 Ok(EhtResult {
175 energies: energies.iter().copied().collect(),
176 coefficients,
177 n_electrons,
178 homo_index: homo_idx,
179 lumo_index: lumo_idx,
180 homo_energy,
181 lumo_energy,
182 gap: lumo_energy - homo_energy,
183 support,
184 })
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixtures: atomic numbers and coordinates used by several tests.
    const H2_ELEMENTS: [u8; 2] = [1, 1];
    const H2_POSITIONS: [[f64; 3]; 2] = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
    const WATER_ELEMENTS: [u8; 3] = [8, 1, 1];
    const WATER_POSITIONS: [[f64; 3]; 3] =
        [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];

    /// Solves the H2 fixture with the default `k`.
    fn solve_h2() -> EhtResult {
        solve_eht(&H2_ELEMENTS, &H2_POSITIONS, None).unwrap()
    }

    /// Solves the water fixture with the default `k`.
    fn solve_water() -> EhtResult {
        solve_eht(&WATER_ELEMENTS, &WATER_POSITIONS, None).unwrap()
    }

    #[test]
    fn test_h2_two_eigenvalues() {
        let result = solve_h2();
        assert_eq!(result.energies.len(), 2);
        assert!(result.energies[0] < result.energies[1]);
        assert_eq!(result.homo_index, 0);
        assert_eq!(result.lumo_index, 1);
        assert!(result.gap > 0.0, "H2 HOMO-LUMO gap should be positive");
    }

    #[test]
    fn test_h2_energies_sorted() {
        let result = solve_h2();
        // Each window holds (E[i - 1], E[i]) with i = lower + 1.
        for (lower, pair) in result.energies.windows(2).enumerate() {
            assert!(
                pair[1] >= pair[0],
                "Energies not sorted: E[{}]={} < E[{}]={}",
                lower + 1,
                pair[1],
                lower,
                pair[0]
            );
        }
    }

    #[test]
    fn test_h2_coefficients_shape() {
        let result = solve_h2();
        assert_eq!(result.coefficients.len(), 2);
        assert_eq!(result.coefficients[0].len(), 2);
    }

    #[test]
    fn test_h2o_six_orbitals() {
        let result = solve_water();
        assert_eq!(result.energies.len(), 6);
        assert_eq!(result.n_electrons, 8);
        assert_eq!(result.homo_index, 3);
        assert_eq!(result.lumo_index, 4);
    }

    #[test]
    fn test_h2o_gap_positive() {
        let result = solve_water();
        assert!(
            result.gap > 0.0,
            "H2O HOMO-LUMO gap = {} should be > 0",
            result.gap
        );
    }

    #[test]
    fn test_lowdin_preserves_orthogonality() {
        use super::super::basis::build_basis;
        use super::super::hamiltonian::build_hamiltonian;
        use super::super::overlap::build_overlap_matrix;

        let basis = build_basis(&WATER_ELEMENTS, &WATER_POSITIONS);
        let s = build_overlap_matrix(&basis);
        let h = build_hamiltonian(&basis, &s, None);
        let (_, c) = solve_generalized_eigenproblem(&h, &s);

        // The coefficients must be S-orthonormal: C^T S C == I.
        let ct_s_c = c.transpose() * &s * &c;
        let n = ct_s_c.nrows();
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (ct_s_c[(i, j)] - expected).abs() < 1e-8,
                    "C^T S C[{},{}] = {}, expected {}",
                    i,
                    j,
                    ct_s_c[(i, j)],
                    expected,
                );
            }
        }
    }

    #[test]
    fn test_error_mismatched_arrays() {
        // Two elements but only one position.
        let positions = [[0.0, 0.0, 0.0]];
        assert!(solve_eht(&H2_ELEMENTS, &positions, None).is_err());
    }

    #[test]
    fn test_valence_electron_count() {
        let cases: [(&[u8], usize); 4] = [
            (&[1, 1], 2),
            (&[8, 1, 1], 8),
            (&[6, 1, 1, 1, 1], 8),
            (&[7, 1, 1, 1], 8),
        ];
        for (elements, expected) in cases {
            assert_eq!(count_valence_electrons(elements), expected);
        }
    }

    #[test]
    fn test_valence_electron_count_transition_metals() {
        let cases: [(u8, usize); 18] = [
            (21, 3),
            (22, 4),
            (26, 8),
            (28, 10),
            (29, 11),
            (30, 12),
            (39, 3),
            (46, 10),
            (47, 11),
            (48, 12),
            (72, 4),
            (73, 5),
            (74, 6),
            (76, 8),
            (77, 9),
            (78, 10),
            (79, 11),
            (80, 12),
        ];
        for (atomic_number, expected) in cases {
            assert_eq!(count_valence_electrons(&[atomic_number]), expected);
        }
    }

    #[test]
    fn test_cisplatin_has_even_electron_count() {
        let elements = [78u8, 17, 17, 7, 7, 1, 1, 1, 1, 1, 1];
        assert_eq!(count_valence_electrons(&elements), 40);
    }

    #[test]
    fn test_transition_metal_support_metadata() {
        let elements = [26u8];
        let positions = [[0.0, 0.0, 0.0]];
        let result = solve_eht(&elements, &positions, None).unwrap();
        assert!(result.support.has_transition_metals);
        assert_eq!(result.support.provisional_elements, vec![26]);
        assert!(!result.support.warnings.is_empty());
    }

    #[test]
    fn test_unsupported_element_reports_capability_error() {
        let elements = [118u8];
        let positions = [[0.0, 0.0, 0.0]];
        let error = solve_eht(&elements, &positions, None).unwrap_err();
        assert!(error.contains("No EHT parameters are available"));
    }
}