1use super::fock::{build_fock, electronic_energy};
10use nalgebra::{DMatrix, DVector, SymmetricEigen};
11
12#[derive(Debug, Clone)]
14pub struct ScfResult {
15 pub energy: f64,
17 pub orbital_energies: Vec<f64>,
19 pub coefficients: DMatrix<f64>,
21 pub density: DMatrix<f64>,
23 pub iterations: usize,
25 pub converged: bool,
27}
28
/// Tuning parameters for [`solve_scf`].
#[derive(Debug, Clone, PartialEq)]
pub struct ScfConfig {
    /// Maximum number of SCF iterations before giving up.
    pub max_iter: usize,
    /// Convergence threshold on the absolute change in energy between consecutive iterations.
    pub energy_threshold: f64,
    /// Convergence threshold on the Frobenius norm of the DIIS error `F·D·S − S·D·F`.
    ///
    /// Despite its name, this is not compared against a density-matrix difference.
    pub density_threshold: f64,
    /// Maximum number of Fock/error pairs retained for DIIS extrapolation.
    pub diis_size: usize,
    /// Virtual-orbital level shift, in the same units as the Fock matrix.
    /// Values `<= 0.0` disable level shifting.
    pub level_shift: f64,
}

impl Default for ScfConfig {
    /// 100 iterations, energy threshold 1e-8, DIIS-error threshold 1e-6, DIIS history of 6,
    /// and no level shift.
    fn default() -> Self {
        Self {
            max_iter: 100,
            energy_threshold: 1e-8,
            density_threshold: 1e-6,
            diis_size: 6,
            level_shift: 0.0,
        }
    }
}
51
52pub fn solve_scf(
54 h_core: &DMatrix<f64>,
55 s_mat: &DMatrix<f64>,
56 eris: &[f64],
57 n_electrons: usize,
58 config: &ScfConfig,
59) -> ScfResult {
60 let n = h_core.nrows();
61 let n_occ = n_electrons / 2;
62
63 let s_half_inv = lowdin_orthogonalization(s_mat);
65
66 let (mut energies, mut coeffs) = diagonalize_fock(h_core, &s_half_inv);
68 let mut density = build_density(&coeffs, n_occ);
69
70 let mut prev_energy = 0.0;
71 let mut converged = false;
72 let mut iterations = 0;
73
74 let mut diis_focks: Vec<DMatrix<f64>> = Vec::new();
76 let mut diis_errors: Vec<DMatrix<f64>> = Vec::new();
77
78 let mut prev_error_norm = f64::MAX;
79
80 for iter in 0..config.max_iter {
81 iterations = iter + 1;
82
83 let fock = build_fock(h_core, &density, eris, n);
84 let energy = electronic_energy(&density, h_core, &fock);
85
86 let error = &fock * &density * s_mat - s_mat * &density * &fock;
88 let error_norm = error.iter().map(|v| v * v).sum::<f64>().sqrt();
89
90 if error_norm > prev_error_norm * 10.0 && diis_focks.len() > 2 {
93 diis_focks.clear();
94 diis_errors.clear();
95 }
96 prev_error_norm = error_norm;
97
98 diis_focks.push(fock.clone());
100 diis_errors.push(error);
101 if diis_focks.len() > config.diis_size {
102 diis_focks.remove(0);
103 diis_errors.remove(0);
104 }
105
106 let fock_diis = if diis_focks.len() >= 2 {
107 diis_extrapolate(&diis_focks, &diis_errors)
108 } else {
109 fock
110 };
111
112 let fock_shifted = if config.level_shift > 0.0 && n_occ < n {
114 let f_orth = s_half_inv.transpose() * &fock_diis * &s_half_inv;
117 let eigen = SymmetricEigen::new(f_orth);
118 let mut shifted_evals = eigen.eigenvalues.clone();
119 let mut idx_sorted: Vec<usize> = (0..n).collect();
120 idx_sorted.sort_by(|&a, &b| shifted_evals[a].partial_cmp(&shifted_evals[b]).unwrap());
121 for &idx in idx_sorted.iter().skip(n_occ) {
122 shifted_evals[idx] += config.level_shift;
123 }
124 let v = &eigen.eigenvectors;
125 let d = DMatrix::from_diagonal(&shifted_evals);
126 let f_shifted = v * d * v.transpose();
127 &s_half_inv * f_shifted * s_half_inv.transpose()
130 } else {
131 fock_diis
132 };
133
134 let (new_energies, new_coeffs) = diagonalize_fock(&fock_shifted, &s_half_inv);
135 let new_density = build_density(&new_coeffs, n_occ);
136
137 let de = (energy - prev_energy).abs();
138
139 if de < config.energy_threshold && error_norm < config.density_threshold {
140 converged = true;
141 energies = new_energies;
142 coeffs = new_coeffs;
143 density = new_density;
144 break;
145 }
146
147 prev_energy = energy;
148 energies = new_energies;
149 coeffs = new_coeffs;
150 density = new_density;
151 }
152
153 let final_energy = electronic_energy(&density, h_core, &build_fock(h_core, &density, eris, n));
154
155 ScfResult {
156 energy: final_energy,
157 orbital_energies: energies.as_slice().to_vec(),
158 coefficients: coeffs,
159 density,
160 iterations,
161 converged,
162 }
163}
164
165fn lowdin_orthogonalization(s: &DMatrix<f64>) -> DMatrix<f64> {
166 let eigen = SymmetricEigen::new(s.clone());
167 let n = s.nrows();
168 let mut s_inv_half = DMatrix::zeros(n, n);
169
170 for i in 0..n {
171 let val = eigen.eigenvalues[i];
172 if val > 1e-10 {
173 let factor = 1.0 / val.sqrt();
174 let col = eigen.eigenvectors.column(i);
175 s_inv_half += factor * col * col.transpose();
176 }
177 }
178 s_inv_half
179}
180
181fn diagonalize_fock(
182 fock: &DMatrix<f64>,
183 s_half_inv: &DMatrix<f64>,
184) -> (DVector<f64>, DMatrix<f64>) {
185 let f_prime = s_half_inv.transpose() * fock * s_half_inv;
186 let eigen = SymmetricEigen::new(f_prime);
187
188 let n = eigen.eigenvalues.len();
190 let mut indices: Vec<usize> = (0..n).collect();
191 indices.sort_by(|&a, &b| {
192 eigen.eigenvalues[a]
193 .partial_cmp(&eigen.eigenvalues[b])
194 .unwrap()
195 });
196
197 let sorted_energies = DVector::from_fn(n, |i, _| eigen.eigenvalues[indices[i]]);
198 let sorted_vecs = DMatrix::from_fn(n, n, |r, c| eigen.eigenvectors[(r, indices[c])]);
199
200 let coeffs = s_half_inv * sorted_vecs;
202 (sorted_energies, coeffs)
203}
204
205fn build_density(coeffs: &DMatrix<f64>, n_occ: usize) -> DMatrix<f64> {
206 let n = coeffs.nrows();
207 let mut density = DMatrix::zeros(n, n);
208 for i in 0..n_occ {
209 let col = coeffs.column(i);
210 density += 2.0 * &col * col.transpose();
211 }
212 density
213}
214
215fn diis_extrapolate(focks: &[DMatrix<f64>], errors: &[DMatrix<f64>]) -> DMatrix<f64> {
216 let m = errors.len();
217 let mut b = DMatrix::zeros(m + 1, m + 1);
218
219 for i in 0..m {
220 for j in 0..=i {
221 let bij: f64 = errors[i]
222 .iter()
223 .zip(errors[j].iter())
224 .map(|(a, b)| a * b)
225 .sum();
226 b[(i, j)] = bij;
227 b[(j, i)] = bij;
228 }
229 }
230 for i in 0..m {
231 b[(m, i)] = -1.0;
232 b[(i, m)] = -1.0;
233 }
234
235 let mut rhs = DVector::zeros(m + 1);
236 rhs[m] = -1.0;
237
238 let svd = b.svd(true, true);
240 let c = match svd.solve(&rhs, 1e-10) {
241 Ok(c) => c,
242 Err(_) => {
243 return focks.last().unwrap().clone();
245 }
246 };
247
248 let mut f_diis = DMatrix::zeros(focks[0].nrows(), focks[0].ncols());
249 for i in 0..m {
250 f_diis += c[i] * &focks[i];
251 }
252 f_diis
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts element-wise agreement of two equally shaped matrices within `tol`.
    fn assert_matrix_close(actual: &DMatrix<f64>, expected: &DMatrix<f64>, tol: f64, what: &str) {
        assert_eq!(actual.shape(), expected.shape(), "{}: shape mismatch", what);
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < tol, "{}: got {}, expected {}", what, a, e);
        }
    }

    #[test]
    fn test_lowdin_identity() {
        let s = DMatrix::identity(3, 3);
        let s_inv = lowdin_orthogonalization(&s);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (s_inv[(i, j)] - expected).abs() < 1e-10,
                    "S^{{-1/2}} of identity should be identity"
                );
            }
        }
    }

    /// With S = I, the result cannot distinguish S^{-1/2} from S or S^{-1}.
    /// A coupled, positive-definite overlap pins down the actual power.
    #[test]
    fn test_lowdin_nontrivial_overlap() {
        let s = DMatrix::from_row_slice(3, 3, &[1.0, 0.4, 0.1, 0.4, 1.0, 0.3, 0.1, 0.3, 1.0]);
        let x = lowdin_orthogonalization(&s);
        let identity = DMatrix::<f64>::identity(3, 3);

        assert_matrix_close(&x, &x.transpose(), 1e-12, "S^{-1/2} should be symmetric");
        assert_matrix_close(&(x.transpose() * &s * &x), &identity, 1e-10, "X^T S X should be I");
        assert_matrix_close(&(&x * &x * &s), &identity, 1e-10, "X^2 should equal S^-1");
    }
}