1use nalgebra::{DMatrix, DVector, SymmetricEigen};
11use serde::{Deserialize, Serialize};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct EhtResult {
16 pub energies: Vec<f64>,
18 pub coefficients: Vec<Vec<f64>>,
21 pub n_electrons: usize,
23 pub homo_index: usize,
25 pub lumo_index: usize,
27 pub homo_energy: f64,
29 pub lumo_energy: f64,
31 pub gap: f64,
33}
34
35pub fn solve_generalized_eigenproblem(
39 h: &DMatrix<f64>,
40 s: &DMatrix<f64>,
41) -> (DVector<f64>, DMatrix<f64>) {
42 let n = h.nrows();
43
44 let s_eigen = SymmetricEigen::new(s.clone());
46 let s_vals = &s_eigen.eigenvalues;
47 let s_vecs = &s_eigen.eigenvectors;
48
49 let mut s_inv_sqrt_diag = DMatrix::zeros(n, n);
51 for i in 0..n {
52 let val = s_vals[i];
53 if val > 1e-10 {
54 s_inv_sqrt_diag[(i, i)] = 1.0 / val.sqrt();
55 }
56 }
57 let s_inv_sqrt = s_vecs * &s_inv_sqrt_diag * s_vecs.transpose();
58
59 let h_prime = &s_inv_sqrt * h * &s_inv_sqrt;
61
62 let h_eigen = SymmetricEigen::new(h_prime);
64 let energies = h_eigen.eigenvalues.clone();
65 let c_prime = h_eigen.eigenvectors.clone();
66
67 let c = &s_inv_sqrt * c_prime;
69
70 let mut indices: Vec<usize> = (0..n).collect();
72 indices.sort_by(|&a, &b| energies[a].partial_cmp(&energies[b]).unwrap());
73
74 let mut sorted_energies = DVector::zeros(n);
75 let mut sorted_c = DMatrix::zeros(n, n);
76 for (new_idx, &old_idx) in indices.iter().enumerate() {
77 sorted_energies[new_idx] = energies[old_idx];
78 for row in 0..n {
79 sorted_c[(row, new_idx)] = c[(row, old_idx)];
80 }
81 }
82
83 (sorted_energies, sorted_c)
84}
85
/// Sums the valence electrons of the given atoms.
///
/// `elements` holds atomic numbers. Recognized values are H (1), B (5),
/// C (6), N (7), O (8), F (9), Si (14), P (15), S (16), Cl (17), Br (35) and
/// I (53); every other atomic number contributes zero electrons.
fn count_valence_electrons(elements: &[u8]) -> usize {
    let mut total = 0usize;
    for &atomic_number in elements {
        total += match atomic_number {
            1 => 1,
            5 => 3,
            6 | 14 => 4,
            7 | 15 => 5,
            8 | 16 => 6,
            9 | 17 | 35 | 53 => 7,
            _ => 0,
        };
    }
    total
}
107
108pub fn solve_eht(
114 elements: &[u8],
115 positions: &[[f64; 3]],
116 k: Option<f64>,
117) -> Result<EhtResult, String> {
118 use super::basis::build_basis;
119 use super::hamiltonian::build_hamiltonian;
120 use super::overlap::build_overlap_matrix;
121
122 if elements.len() != positions.len() {
123 return Err("Element and position arrays must have equal length".to_string());
124 }
125
126 let basis = build_basis(elements, positions);
127 if basis.is_empty() {
128 return Err("No valence orbitals found for given elements".to_string());
129 }
130
131 let s = build_overlap_matrix(&basis);
132 let h = build_hamiltonian(&basis, &s, k);
133 let (energies, c) = solve_generalized_eigenproblem(&h, &s);
134
135 let n_electrons = count_valence_electrons(elements);
136 let n_orbitals = basis.len();
137
138 let n_occupied = n_electrons / 2;
140 let homo_idx = if n_occupied > 0 { n_occupied - 1 } else { 0 };
141 let lumo_idx = if homo_idx + 1 < n_orbitals {
142 homo_idx + 1
143 } else {
144 homo_idx
145 };
146
147 let homo_energy = energies[homo_idx];
148 let lumo_energy = energies[lumo_idx];
149
150 let coefficients: Vec<Vec<f64>> = (0..n_orbitals)
152 .map(|row| (0..n_orbitals).map(|col| c[(row, col)]).collect())
153 .collect();
154
155 Ok(EhtResult {
156 energies: energies.iter().copied().collect(),
157 coefficients,
158 n_electrons,
159 homo_index: homo_idx,
160 lumo_index: lumo_idx,
161 homo_energy,
162 lumo_energy,
163 gap: lumo_energy - homo_energy,
164 })
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixtures: H2 along the x axis and a bent water molecule.
    const H2_ELEMENTS: [u8; 2] = [1, 1];
    const H2_POSITIONS: [[f64; 3]; 2] = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
    const WATER_ELEMENTS: [u8; 3] = [8, 1, 1];
    const WATER_POSITIONS: [[f64; 3]; 3] =
        [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];

    /// Solves H2 with the default `k`.
    fn solve_h2() -> EhtResult {
        solve_eht(&H2_ELEMENTS, &H2_POSITIONS, None).unwrap()
    }

    /// Solves water with the default `k`.
    fn solve_water() -> EhtResult {
        solve_eht(&WATER_ELEMENTS, &WATER_POSITIONS, None).unwrap()
    }

    #[test]
    fn test_h2_two_eigenvalues() {
        let result = solve_h2();
        assert_eq!(result.energies.len(), 2);
        assert!(result.energies[0] < result.energies[1]);
        // Two electrons fill the lower orbital only.
        assert_eq!(result.homo_index, 0);
        assert_eq!(result.lumo_index, 1);
        assert!(result.gap > 0.0, "H2 HOMO-LUMO gap should be positive");
    }

    #[test]
    fn test_h2_energies_sorted() {
        let result = solve_h2();
        for (prev, pair) in result.energies.windows(2).enumerate() {
            let i = prev + 1;
            assert!(
                pair[1] >= pair[0],
                "Energies not sorted: E[{}]={} < E[{}]={}",
                i,
                pair[1],
                prev,
                pair[0]
            );
        }
    }

    #[test]
    fn test_h2_coefficients_shape() {
        let result = solve_h2();
        assert_eq!(result.coefficients.len(), 2);
        assert_eq!(result.coefficients[0].len(), 2);
    }

    #[test]
    fn test_h2o_six_orbitals() {
        let result = solve_water();
        assert_eq!(result.energies.len(), 6);
        assert_eq!(result.n_electrons, 8);
        // Eight electrons occupy orbitals 0..=3.
        assert_eq!(result.homo_index, 3);
        assert_eq!(result.lumo_index, 4);
    }

    #[test]
    fn test_h2o_gap_positive() {
        let result = solve_water();
        assert!(
            result.gap > 0.0,
            "H2O HOMO-LUMO gap = {} should be > 0",
            result.gap
        );
    }

    #[test]
    fn test_lowdin_preserves_orthogonality() {
        use super::super::basis::build_basis;
        use super::super::hamiltonian::build_hamiltonian;
        use super::super::overlap::build_overlap_matrix;

        let basis = build_basis(&WATER_ELEMENTS, &WATER_POSITIONS);
        let s = build_overlap_matrix(&basis);
        let h = build_hamiltonian(&basis, &s, None);
        let (_, c) = solve_generalized_eigenproblem(&h, &s);

        // The eigenvectors must be orthonormal under the overlap metric.
        let ct_s_c = c.transpose() * &s * &c;
        let n = ct_s_c.nrows();
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                let actual = ct_s_c[(i, j)];
                assert!(
                    (actual - expected).abs() < 1e-8,
                    "C^T S C[{},{}] = {}, expected {}",
                    i,
                    j,
                    actual,
                    expected,
                );
            }
        }
    }

    #[test]
    fn test_error_mismatched_arrays() {
        // Two elements but only one position.
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0]];
        let outcome = solve_eht(&elements, &positions, None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_valence_electron_count() {
        assert_eq!(count_valence_electrons(&[1, 1]), 2);
        assert_eq!(count_valence_electrons(&[8, 1, 1]), 8);
        assert_eq!(count_valence_electrons(&[6, 1, 1, 1, 1]), 8);
        assert_eq!(count_valence_electrons(&[7, 1, 1, 1]), 8);
    }
}