1use nalgebra::DMatrix;
9
10use super::types::{ScfInput, SpectroscopyResult, TransitionInfo};
11
/// One Hartree expressed in electronvolts (CODATA 2018 recommended value).
const HARTREE_TO_EV: f64 = 27.211_386_245_988;
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
16pub struct StdaConfig {
17 pub occ_window_ev: f64,
19 pub virt_window_ev: f64,
21 pub n_roots: usize,
23 pub ax: f64,
25 pub threshold: f64,
27}
28
29impl Default for StdaConfig {
30 fn default() -> Self {
31 Self {
32 occ_window_ev: 7.0,
33 virt_window_ev: 9.0,
34 n_roots: 20,
35 ax: 0.5,
36 threshold: 1e-6,
37 }
38 }
39}
40
/// Orbitals selected for building single excitations.
///
/// The index lists hold absolute MO indices (positions in `orbital_energies`). As built by
/// `select_active_space`, they are ascending and the counts equal the list lengths.
struct ActiveSpace {
    /// Number of selected occupied orbitals (`occ_indices.len()`).
    n_occ: usize,
    /// Absolute indices of the selected occupied orbitals.
    occ_indices: Vec<usize>,
    /// Number of selected virtual orbitals (`virt_indices.len()`).
    n_virt: usize,
    /// Absolute indices of the selected virtual orbitals.
    virt_indices: Vec<usize>,
}
47
48fn select_active_space(scf: &ScfInput, config: &StdaConfig) -> ActiveSpace {
49 let n_occ = scf.n_electrons / 2;
50 let homo_e = scf.orbital_energies[n_occ - 1];
51
52 let lumo_e = if n_occ < scf.n_basis {
53 scf.orbital_energies[n_occ]
54 } else {
55 homo_e + 1.0
56 };
57
58 let occ_cutoff = homo_e - config.occ_window_ev / HARTREE_TO_EV;
59 let virt_cutoff = lumo_e + config.virt_window_ev / HARTREE_TO_EV;
60
61 let core_floor_hartree = -20.0 / HARTREE_TO_EV;
65 let effective_occ_cutoff = occ_cutoff.max(core_floor_hartree);
66
67 let mut occ_indices: Vec<usize> = (0..n_occ)
68 .filter(|&i| scf.orbital_energies[i] >= effective_occ_cutoff)
69 .collect();
70
71 if occ_indices.is_empty() && n_occ > 0 {
76 let n_include = n_occ.clamp(1, 3);
77 occ_indices = ((n_occ - n_include)..n_occ).collect();
78 }
79
80 let virt_indices: Vec<usize> = (n_occ..scf.n_basis)
81 .filter(|&a| scf.orbital_energies[a] <= virt_cutoff)
82 .collect();
83
84 ActiveSpace {
85 n_occ: occ_indices.len(),
86 n_virt: virt_indices.len(),
87 occ_indices,
88 virt_indices,
89 }
90}
91
92fn transition_charges(
94 scf: &ScfInput,
95 basis_to_atom: &[usize],
96 n_atoms: usize,
97) -> Vec<Vec<Vec<f64>>> {
98 let n_occ = scf.n_electrons / 2;
99 let n_basis = scf.n_basis;
100
101 let sc = &scf.overlap_matrix * &scf.mo_coefficients;
102
103 let mut q = vec![vec![vec![0.0; n_atoms]; n_basis - n_occ]; n_occ];
104
105 for i in 0..n_occ {
106 for (a_idx, a) in (n_occ..n_basis).enumerate() {
107 for mu in 0..n_basis {
108 let atom = basis_to_atom[mu];
109 q[i][a_idx][atom] += scf.mo_coefficients[(mu, i)] * sc[(mu, a)];
110 }
111 }
112 }
113
114 q
115}
116
117pub fn compute_stda(
121 scf: &ScfInput,
122 basis_to_atom: &[usize],
123 positions_bohr: &[[f64; 3]],
124 config: &StdaConfig,
125) -> SpectroscopyResult {
126 let active = select_active_space(scf, config);
127 let n_active_occ = active.n_occ;
128 let n_active_virt = active.n_virt;
129 let n_singles = n_active_occ * n_active_virt;
130
131 if n_singles == 0 {
132 return SpectroscopyResult {
133 transitions: Vec::new(),
134 method: "sTDA".to_string(),
135 };
136 }
137
138 let n_atoms = positions_bohr.len();
139
140 let mut a_matrix = DMatrix::zeros(n_singles, n_singles);
142
143 for (idx, (i_local, a_local)) in iproduct(n_active_occ, n_active_virt).enumerate() {
145 let i = active.occ_indices[i_local];
146 let a = active.virt_indices[a_local];
147 a_matrix[(idx, idx)] = scf.orbital_energies[a] - scf.orbital_energies[i];
148 }
149
150 let eta: Vec<f64> = (0..n_atoms).map(|_| 0.3).collect();
152 let gamma = compute_gamma(positions_bohr, &eta);
153
154 let n_occ_total = scf.n_electrons / 2;
155 let q = transition_charges(scf, basis_to_atom, n_atoms);
156
157 let q_norms: Vec<f64> = iproduct(n_active_occ, n_active_virt)
160 .map(|(i_l, a_l)| {
161 let i = active.occ_indices[i_l];
162 let a_abs = active.virt_indices[a_l] - n_occ_total;
163 q[i][a_abs].iter().map(|x| x * x).sum::<f64>().sqrt()
164 })
165 .collect();
166
167 #[cfg(feature = "parallel")]
168 {
169 use rayon::prelude::*;
170
171 let pairs_1: Vec<(usize, usize, usize)> = iproduct(n_active_occ, n_active_virt)
172 .enumerate()
173 .map(|(idx, (i_l, a_l))| {
174 (
175 idx,
176 active.occ_indices[i_l],
177 active.virt_indices[a_l] - n_occ_total,
178 )
179 })
180 .collect();
181
182 let pairs_2: Vec<(usize, usize, usize)> = iproduct(n_active_occ, n_active_virt)
183 .enumerate()
184 .map(|(idx, (j_l, b_l))| {
185 (
186 idx,
187 active.occ_indices[j_l],
188 active.virt_indices[b_l] - n_occ_total,
189 )
190 })
191 .collect();
192
193 let row_contribs: Vec<Vec<(usize, f64)>> = pairs_1
194 .par_iter()
195 .map(|&(idx1, i, a_abs)| {
196 let mut row = Vec::with_capacity(n_singles);
197 let norm1 = q_norms[idx1];
198 for &(idx2, j, b_abs) in &pairs_2 {
199 if norm1 * q_norms[idx2] < config.threshold {
201 continue;
202 }
203 let mut j_integral = 0.0;
204 for atom_a in 0..n_atoms {
205 let q_ia = q[i][a_abs][atom_a];
206 if q_ia.abs() < config.threshold {
207 continue;
208 }
209 for atom_b in 0..n_atoms {
210 j_integral += q_ia * gamma[(atom_a, atom_b)] * q[j][b_abs][atom_b];
211 }
212 }
213 row.push((idx2, 2.0 * j_integral));
214 }
215 row
216 })
217 .collect();
218
219 for (idx1, row) in row_contribs.into_iter().enumerate() {
220 for (idx2, val) in row {
221 a_matrix[(idx1, idx2)] += val;
222 }
223 }
224 }
225
226 #[cfg(not(feature = "parallel"))]
227 {
228 for (idx1, (i_l, a_l)) in iproduct(n_active_occ, n_active_virt).enumerate() {
229 let i = active.occ_indices[i_l];
230 let a_abs = active.virt_indices[a_l] - n_occ_total;
231 let norm1 = q_norms[idx1];
232
233 for (idx2, (j_l, b_l)) in iproduct(n_active_occ, n_active_virt).enumerate() {
234 let j = active.occ_indices[j_l];
235 let b_abs = active.virt_indices[b_l] - n_occ_total;
236
237 if norm1 * q_norms[idx2] < config.threshold {
239 continue;
240 }
241
242 let mut j_integral = 0.0;
243 for atom_a in 0..n_atoms {
244 let q_ia = q[i][a_abs][atom_a];
245 if q_ia.abs() < config.threshold {
246 continue;
247 }
248 for atom_b in 0..n_atoms {
249 j_integral += q_ia * gamma[(atom_a, atom_b)] * q[j][b_abs][atom_b];
250 }
251 }
252
253 a_matrix[(idx1, idx2)] += 2.0 * j_integral;
254 }
255 }
256 }
257
258 let eigen = a_matrix.symmetric_eigen();
260
261 let mut idx_sorted: Vec<usize> = (0..n_singles).collect();
262 idx_sorted.sort_by(|&a, &b| {
263 eigen.eigenvalues[a]
264 .partial_cmp(&eigen.eigenvalues[b])
265 .unwrap()
266 });
267
268 let n_roots = config.n_roots.min(n_singles);
269 let mut transitions = Vec::with_capacity(n_roots);
270
271 for root in 0..n_roots {
272 let idx = idx_sorted[root];
273 let energy_hartree = eigen.eigenvalues[idx];
274 let energy_ev = energy_hartree * HARTREE_TO_EV;
275
276 if energy_ev < 0.0 {
277 continue;
278 }
279
280 let ci_vector = eigen.eigenvectors.column(idx);
281 let (tdm, osc_strength) = transition_dipole_from_ci(
282 &ci_vector,
283 &active,
284 &q,
285 positions_bohr,
286 n_occ_total,
287 energy_hartree,
288 );
289
290 transitions.push(TransitionInfo {
291 energy_ev,
292 wavelength_nm: if energy_ev > 0.0 {
293 1239.84198 / energy_ev
294 } else {
295 0.0
296 },
297 oscillator_strength: osc_strength,
298 transition_dipole: tdm,
299 });
300 }
301
302 SpectroscopyResult {
303 transitions,
304 method: "sTDA".to_string(),
305 }
306}
307
308fn transition_dipole_from_ci(
309 ci: &nalgebra::DVectorView<f64>,
310 active: &ActiveSpace,
311 q: &[Vec<Vec<f64>>],
312 positions_bohr: &[[f64; 3]],
313 n_occ_total: usize,
314 energy_hartree: f64,
315) -> ([f64; 3], f64) {
316 let n_atoms = positions_bohr.len();
317 let mut tdm = [0.0f64; 3];
318
319 for (idx, (i_l, a_l)) in iproduct(active.n_occ, active.n_virt).enumerate() {
322 let i = active.occ_indices[i_l];
323 let a_abs = active.virt_indices[a_l] - n_occ_total;
324 let x_ia = ci[idx];
325
326 if x_ia.abs() < 1e-10 {
327 continue;
328 }
329
330 for atom in 0..n_atoms {
331 let charge = q[i][a_abs][atom];
332 tdm[0] += x_ia * charge * positions_bohr[atom][0];
333 tdm[1] += x_ia * charge * positions_bohr[atom][1];
334 tdm[2] += x_ia * charge * positions_bohr[atom][2];
335 }
336 }
337
338 let sqrt2 = std::f64::consts::SQRT_2;
340 tdm[0] *= sqrt2;
341 tdm[1] *= sqrt2;
342 tdm[2] *= sqrt2;
343
344 let tdm_sq = tdm[0] * tdm[0] + tdm[1] * tdm[1] + tdm[2] * tdm[2];
345 let osc = (2.0 / 3.0) * energy_hartree * tdm_sq;
346
347 (tdm, osc)
348}
349
350fn compute_gamma(positions: &[[f64; 3]], eta: &[f64]) -> DMatrix<f64> {
351 let n = positions.len();
352 let mut gamma = DMatrix::zeros(n, n);
353
354 for a in 0..n {
355 for b in 0..n {
356 if a == b {
357 gamma[(a, b)] = eta[a];
358 } else {
359 let dx = positions[a][0] - positions[b][0];
360 let dy = positions[a][1] - positions[b][1];
361 let dz = positions[a][2] - positions[b][2];
362 let r2 = dx * dx + dy * dy + dz * dz;
363
364 let avg_eta_inv = 1.0 / (2.0 * eta[a]) + 1.0 / (2.0 * eta[b]);
365 gamma[(a, b)] = 1.0 / (r2 + avg_eta_inv * avg_eta_inv).sqrt();
366 }
367 }
368 }
369
370 gamma
371}
372
/// Enumerates every `(occ, virt)` index pair in occ-major order:
/// `(0, 0), (0, 1), …, (0, n_virt - 1), (1, 0), …`.
///
/// This order defines the row/column layout of the excitation space, so every consumer
/// (matrix construction and CI-vector decoding) must iterate it the same way.
fn iproduct(n_occ: usize, n_virt: usize) -> impl Iterator<Item = (usize, usize)> + Clone {
    (0..n_occ).flat_map(move |occ| std::iter::repeat(occ).zip(0..n_virt))
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DMatrix;

    /// Closed-shell `ScfInput` with identity MO coefficients and overlap.
    ///
    /// With identity matrices every occupied -> virtual transition charge is zero, so
    /// excitation energies equal the bare orbital gaps.
    fn identity_scf(orbital_energies: Vec<f64>, n_electrons: usize) -> ScfInput {
        let n_basis = orbital_energies.len();
        ScfInput {
            orbital_energies,
            mo_coefficients: DMatrix::identity(n_basis, n_basis),
            density_matrix: DMatrix::zeros(n_basis, n_basis),
            overlap_matrix: DMatrix::identity(n_basis, n_basis),
            n_basis,
            n_electrons,
        }
    }

    #[test]
    fn test_gamma_matrix_symmetry() {
        let pos = vec![[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [0.0, 3.0, 0.0]];
        let eta = vec![0.3, 0.3, 0.3];
        let gamma = compute_gamma(&pos, &eta);

        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (gamma[(i, j)] - gamma[(j, i)]).abs() < 1e-14,
                    "Gamma should be symmetric"
                );
            }
        }
        assert!((gamma[(0, 0)] - 0.3).abs() < 1e-14);

        // Off-diagonal element for R = 3 bohr and equal eta.
        let damping = 1.0 / (2.0 * 0.3) + 1.0 / (2.0 * 0.3);
        let expected = 1.0 / (9.0_f64 + damping * damping).sqrt();
        assert!((gamma[(0, 1)] - expected).abs() < 1e-12);
    }

    #[test]
    fn test_active_space_selection() {
        let scf = identity_scf(vec![-1.0, -0.5, 0.2, 0.8, 1.5], 4);
        let active = select_active_space(&scf, &StdaConfig::default());

        // HOMO -0.5 Ha; the -20 eV floor excludes -1.0 Ha.
        // LUMO 0.2 Ha; the +9 eV window stops below 0.8 Ha.
        assert_eq!(active.occ_indices, vec![1]);
        assert_eq!(active.virt_indices, vec![2]);
        assert_eq!(active.n_occ, 1);
        assert_eq!(active.n_virt, 1);
    }

    #[test]
    fn test_active_space_falls_back_when_homo_below_floor() {
        let scf = identity_scf(vec![-10.0, 10.0], 2);
        let config = StdaConfig {
            occ_window_ev: 0.1,
            virt_window_ev: 0.1,
            ..Default::default()
        };
        let active = select_active_space(&scf, &config);

        // The HOMO lies below the -20 eV floor, so the highest occupied orbital is forced in.
        assert_eq!(active.occ_indices, vec![0]);
        assert_eq!(active.virt_indices, vec![1]);
    }

    #[test]
    fn test_stda_empty_on_no_space() {
        // Every orbital is occupied, so there is no virtual space to excite into.
        let scf = identity_scf(vec![-10.0, -5.0], 4);
        let config = StdaConfig {
            occ_window_ev: 0.1,
            virt_window_ev: 0.1,
            ..Default::default()
        };
        let result = compute_stda(&scf, &[0, 0], &[[0.0, 0.0, 0.0]], &config);

        assert_eq!(result.method, "sTDA");
        assert!(result.transitions.is_empty(), "No virtual orbitals -> no transitions");
    }

    #[test]
    fn test_stda_produces_transitions() {
        let scf = identity_scf(vec![-0.8, -0.3, 0.1, 0.5], 4);
        let config = StdaConfig::default();
        let positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        let basis_to_atom = [0, 0, 1, 1];
        let result = compute_stda(&scf, &basis_to_atom, &positions, &config);

        // Active space is HOMO (index 1) -> LUMO (index 2) only; charges vanish, so the single
        // root is the bare gap and carries no intensity.
        assert_eq!(result.transitions.len(), 1, "Should produce exactly one transition");
        let t = &result.transitions[0];
        let expected_ev = (0.1_f64 + 0.3) * HARTREE_TO_EV;
        assert!((t.energy_ev - expected_ev).abs() < 1e-9);
        assert!(t.energy_ev > 0.0);
        assert!((t.wavelength_nm - 1239.84198 / t.energy_ev).abs() < 1e-9);
        assert!(t.oscillator_strength.abs() < 1e-12);
    }

    #[test]
    fn test_iproduct() {
        let pairs: Vec<_> = iproduct(2, 3).collect();
        assert_eq!(pairs.len(), 6);
        assert_eq!(pairs[0], (0, 0));
        assert_eq!(pairs[5], (1, 2));
        assert_eq!(iproduct(0, 3).count(), 0);
        assert_eq!(iproduct(3, 0).count(), 0);
    }
}