1use nalgebra::DMatrix;
9
10use super::types::{ScfInput, SpectroscopyResult, TransitionInfo};
11
/// Conversion factor from Hartree to electron-volts (eV per Hartree).
const HARTREE_TO_EV: f64 = 27.211_386_245_988;
13
14#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
16pub struct StdaConfig {
17 pub occ_window_ev: f64,
19 pub virt_window_ev: f64,
21 pub n_roots: usize,
23 pub ax: f64,
25 pub threshold: f64,
27}
28
29impl Default for StdaConfig {
30 fn default() -> Self {
31 Self {
32 occ_window_ev: 7.0,
33 virt_window_ev: 9.0,
34 n_roots: 20,
35 ax: 0.5,
36 threshold: 1e-6,
37 }
38 }
39}
40
/// The subset of molecular orbitals that take part in the excitation space.
struct ActiveSpace {
    /// Number of entries in `occ_indices`.
    n_occ: usize,
    /// Number of entries in `virt_indices`.
    n_virt: usize,
    /// Absolute orbital indices of the selected occupied orbitals, ascending.
    occ_indices: Vec<usize>,
    /// Absolute orbital indices of the selected virtual orbitals, ascending.
    virt_indices: Vec<usize>,
}
47
48fn select_active_space(scf: &ScfInput, config: &StdaConfig) -> ActiveSpace {
49 let n_occ = scf.n_electrons / 2;
50 let homo_e = scf.orbital_energies[n_occ - 1];
51
52 let lumo_e = if n_occ < scf.n_basis {
53 scf.orbital_energies[n_occ]
54 } else {
55 homo_e + 1.0
56 };
57
58 let occ_cutoff = homo_e - config.occ_window_ev / HARTREE_TO_EV;
59 let virt_cutoff = lumo_e + config.virt_window_ev / HARTREE_TO_EV;
60
61 let core_floor_hartree = -20.0 / HARTREE_TO_EV;
65 let effective_occ_cutoff = occ_cutoff.max(core_floor_hartree);
66
67 let occ_indices: Vec<usize> = (0..n_occ)
68 .filter(|&i| scf.orbital_energies[i] >= effective_occ_cutoff)
69 .collect();
70
71 let virt_indices: Vec<usize> = (n_occ..scf.n_basis)
72 .filter(|&a| scf.orbital_energies[a] <= virt_cutoff)
73 .collect();
74
75 ActiveSpace {
76 n_occ: occ_indices.len(),
77 n_virt: virt_indices.len(),
78 occ_indices,
79 virt_indices,
80 }
81}
82
83fn transition_charges(
85 scf: &ScfInput,
86 basis_to_atom: &[usize],
87 n_atoms: usize,
88) -> Vec<Vec<Vec<f64>>> {
89 let n_occ = scf.n_electrons / 2;
90 let n_basis = scf.n_basis;
91
92 let sc = &scf.overlap_matrix * &scf.mo_coefficients;
93
94 let mut q = vec![vec![vec![0.0; n_atoms]; n_basis - n_occ]; n_occ];
95
96 for i in 0..n_occ {
97 for (a_idx, a) in (n_occ..n_basis).enumerate() {
98 for mu in 0..n_basis {
99 let atom = basis_to_atom[mu];
100 q[i][a_idx][atom] += scf.mo_coefficients[(mu, i)] * sc[(mu, a)];
101 }
102 }
103 }
104
105 q
106}
107
108pub fn compute_stda(
112 scf: &ScfInput,
113 basis_to_atom: &[usize],
114 positions_bohr: &[[f64; 3]],
115 config: &StdaConfig,
116) -> SpectroscopyResult {
117 let active = select_active_space(scf, config);
118 let n_active_occ = active.n_occ;
119 let n_active_virt = active.n_virt;
120 let n_singles = n_active_occ * n_active_virt;
121
122 if n_singles == 0 {
123 return SpectroscopyResult {
124 transitions: Vec::new(),
125 method: "sTDA".to_string(),
126 };
127 }
128
129 let n_atoms = positions_bohr.len();
130
131 let mut a_matrix = DMatrix::zeros(n_singles, n_singles);
133
134 for (idx, (i_local, a_local)) in iproduct(n_active_occ, n_active_virt).enumerate() {
136 let i = active.occ_indices[i_local];
137 let a = active.virt_indices[a_local];
138 a_matrix[(idx, idx)] = scf.orbital_energies[a] - scf.orbital_energies[i];
139 }
140
141 let eta: Vec<f64> = (0..n_atoms).map(|_| 0.3).collect();
143 let gamma = compute_gamma(positions_bohr, &eta);
144
145 let n_occ_total = scf.n_electrons / 2;
146 let q = transition_charges(scf, basis_to_atom, n_atoms);
147
148 let q_norms: Vec<f64> = iproduct(n_active_occ, n_active_virt)
151 .map(|(i_l, a_l)| {
152 let i = active.occ_indices[i_l];
153 let a_abs = active.virt_indices[a_l] - n_occ_total;
154 q[i][a_abs].iter().map(|x| x * x).sum::<f64>().sqrt()
155 })
156 .collect();
157
158 #[cfg(feature = "parallel")]
159 {
160 use rayon::prelude::*;
161
162 let pairs_1: Vec<(usize, usize, usize)> = iproduct(n_active_occ, n_active_virt)
163 .enumerate()
164 .map(|(idx, (i_l, a_l))| {
165 (
166 idx,
167 active.occ_indices[i_l],
168 active.virt_indices[a_l] - n_occ_total,
169 )
170 })
171 .collect();
172
173 let pairs_2: Vec<(usize, usize, usize)> = iproduct(n_active_occ, n_active_virt)
174 .enumerate()
175 .map(|(idx, (j_l, b_l))| {
176 (
177 idx,
178 active.occ_indices[j_l],
179 active.virt_indices[b_l] - n_occ_total,
180 )
181 })
182 .collect();
183
184 let row_contribs: Vec<Vec<(usize, f64)>> = pairs_1
185 .par_iter()
186 .map(|&(idx1, i, a_abs)| {
187 let mut row = Vec::with_capacity(n_singles);
188 let norm1 = q_norms[idx1];
189 for &(idx2, j, b_abs) in &pairs_2 {
190 if norm1 * q_norms[idx2] < config.threshold {
192 continue;
193 }
194 let mut j_integral = 0.0;
195 for atom_a in 0..n_atoms {
196 let q_ia = q[i][a_abs][atom_a];
197 if q_ia.abs() < config.threshold {
198 continue;
199 }
200 for atom_b in 0..n_atoms {
201 j_integral += q_ia * gamma[(atom_a, atom_b)] * q[j][b_abs][atom_b];
202 }
203 }
204 row.push((idx2, 2.0 * j_integral));
205 }
206 row
207 })
208 .collect();
209
210 for (idx1, row) in row_contribs.into_iter().enumerate() {
211 for (idx2, val) in row {
212 a_matrix[(idx1, idx2)] += val;
213 }
214 }
215 }
216
217 #[cfg(not(feature = "parallel"))]
218 {
219 for (idx1, (i_l, a_l)) in iproduct(n_active_occ, n_active_virt).enumerate() {
220 let i = active.occ_indices[i_l];
221 let a_abs = active.virt_indices[a_l] - n_occ_total;
222 let norm1 = q_norms[idx1];
223
224 for (idx2, (j_l, b_l)) in iproduct(n_active_occ, n_active_virt).enumerate() {
225 let j = active.occ_indices[j_l];
226 let b_abs = active.virt_indices[b_l] - n_occ_total;
227
228 if norm1 * q_norms[idx2] < config.threshold {
230 continue;
231 }
232
233 let mut j_integral = 0.0;
234 for atom_a in 0..n_atoms {
235 let q_ia = q[i][a_abs][atom_a];
236 if q_ia.abs() < config.threshold {
237 continue;
238 }
239 for atom_b in 0..n_atoms {
240 j_integral += q_ia * gamma[(atom_a, atom_b)] * q[j][b_abs][atom_b];
241 }
242 }
243
244 a_matrix[(idx1, idx2)] += 2.0 * j_integral;
245 }
246 }
247 }
248
249 let eigen = a_matrix.symmetric_eigen();
251
252 let mut idx_sorted: Vec<usize> = (0..n_singles).collect();
253 idx_sorted.sort_by(|&a, &b| {
254 eigen.eigenvalues[a]
255 .partial_cmp(&eigen.eigenvalues[b])
256 .unwrap()
257 });
258
259 let n_roots = config.n_roots.min(n_singles);
260 let mut transitions = Vec::with_capacity(n_roots);
261
262 for root in 0..n_roots {
263 let idx = idx_sorted[root];
264 let energy_hartree = eigen.eigenvalues[idx];
265 let energy_ev = energy_hartree * HARTREE_TO_EV;
266
267 if energy_ev < 0.0 {
268 continue;
269 }
270
271 let ci_vector = eigen.eigenvectors.column(idx);
272 let (tdm, osc_strength) = transition_dipole_from_ci(
273 &ci_vector,
274 &active,
275 &q,
276 positions_bohr,
277 n_occ_total,
278 energy_hartree,
279 );
280
281 transitions.push(TransitionInfo {
282 energy_ev,
283 wavelength_nm: if energy_ev > 0.0 {
284 1239.84198 / energy_ev
285 } else {
286 0.0
287 },
288 oscillator_strength: osc_strength,
289 transition_dipole: tdm,
290 });
291 }
292
293 SpectroscopyResult {
294 transitions,
295 method: "sTDA".to_string(),
296 }
297}
298
299fn transition_dipole_from_ci(
300 ci: &nalgebra::DVectorView<f64>,
301 active: &ActiveSpace,
302 q: &[Vec<Vec<f64>>],
303 positions_bohr: &[[f64; 3]],
304 n_occ_total: usize,
305 energy_hartree: f64,
306) -> ([f64; 3], f64) {
307 let n_atoms = positions_bohr.len();
308 let mut tdm = [0.0f64; 3];
309
310 for (idx, (i_l, a_l)) in iproduct(active.n_occ, active.n_virt).enumerate() {
313 let i = active.occ_indices[i_l];
314 let a_abs = active.virt_indices[a_l] - n_occ_total;
315 let x_ia = ci[idx];
316
317 if x_ia.abs() < 1e-10 {
318 continue;
319 }
320
321 for atom in 0..n_atoms {
322 let charge = q[i][a_abs][atom];
323 tdm[0] += x_ia * charge * positions_bohr[atom][0];
324 tdm[1] += x_ia * charge * positions_bohr[atom][1];
325 tdm[2] += x_ia * charge * positions_bohr[atom][2];
326 }
327 }
328
329 let sqrt2 = std::f64::consts::SQRT_2;
331 tdm[0] *= sqrt2;
332 tdm[1] *= sqrt2;
333 tdm[2] *= sqrt2;
334
335 let tdm_sq = tdm[0] * tdm[0] + tdm[1] * tdm[1] + tdm[2] * tdm[2];
336 let osc = (2.0 / 3.0) * energy_hartree * tdm_sq;
337
338 (tdm, osc)
339}
340
341fn compute_gamma(positions: &[[f64; 3]], eta: &[f64]) -> DMatrix<f64> {
342 let n = positions.len();
343 let mut gamma = DMatrix::zeros(n, n);
344
345 for a in 0..n {
346 for b in 0..n {
347 if a == b {
348 gamma[(a, b)] = eta[a];
349 } else {
350 let dx = positions[a][0] - positions[b][0];
351 let dy = positions[a][1] - positions[b][1];
352 let dz = positions[a][2] - positions[b][2];
353 let r2 = dx * dx + dy * dy + dz * dz;
354
355 let avg_eta_inv = 1.0 / (2.0 * eta[a]) + 1.0 / (2.0 * eta[b]);
356 gamma[(a, b)] = 1.0 / (r2 + avg_eta_inv * avg_eta_inv).sqrt();
357 }
358 }
359 }
360
361 gamma
362}
363
/// Yields every `(occupied, virtual)` index pair with `occupied < n_occ` and
/// `virtual < n_virt`, occupied index varying slowest.
fn iproduct(n_occ: usize, n_virt: usize) -> impl Iterator<Item = (usize, usize)> + Clone {
    (0..n_occ).flat_map(move |occ| std::iter::repeat(occ).zip(0..n_virt))
}
367
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DMatrix;

    /// Gamma must be symmetric, carry `eta` on the diagonal and follow the
    /// damped-Coulomb formula off the diagonal.
    #[test]
    fn test_gamma_matrix_symmetry() {
        let pos = vec![[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [0.0, 3.0, 0.0]];
        let eta = vec![0.3, 0.3, 0.3];
        let gamma = compute_gamma(&pos, &eta);

        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (gamma[(i, j)] - gamma[(j, i)]).abs() < 1e-14,
                    "Gamma should be symmetric"
                );
            }
            assert!((gamma[(i, i)] - 0.3).abs() < 1e-14);
        }

        // Atoms 0 and 1 are 3 bohr apart.
        let softening: f64 = 1.0 / (2.0 * 0.3) + 1.0 / (2.0 * 0.3);
        let expected = 1.0 / (9.0 + softening * softening).sqrt();
        assert!((gamma[(0, 1)] - expected).abs() < 1e-14);
    }

    /// With the default windows only the HOMO and LUMO fall inside the active
    /// space for this spectrum (the lowest orbital lies below the -20 eV floor).
    #[test]
    fn test_active_space_selection() {
        let n_basis = 5;
        let scf = ScfInput {
            orbital_energies: vec![-1.0, -0.5, 0.2, 0.8, 1.5],
            mo_coefficients: DMatrix::identity(n_basis, n_basis),
            density_matrix: DMatrix::zeros(n_basis, n_basis),
            overlap_matrix: DMatrix::identity(n_basis, n_basis),
            n_basis,
            n_electrons: 4,
        };

        let config = StdaConfig::default();
        let active = select_active_space(&scf, &config);

        assert!(active.n_occ > 0, "Should have active occupied orbitals");
        assert!(active.n_virt > 0, "Should have active virtual orbitals");
        assert_eq!(active.occ_indices, vec![1]);
        assert_eq!(active.virt_indices, vec![2]);
        assert_eq!(active.n_occ, active.occ_indices.len());
        assert_eq!(active.n_virt, active.virt_indices.len());
    }

    /// The only occupied orbital lies far below the core floor, so no single
    /// excitation exists and the result must be empty.
    #[test]
    fn test_stda_empty_on_no_space() {
        let scf = ScfInput {
            orbital_energies: vec![-10.0, 10.0],
            mo_coefficients: DMatrix::identity(2, 2),
            density_matrix: DMatrix::zeros(2, 2),
            overlap_matrix: DMatrix::identity(2, 2),
            n_basis: 2,
            n_electrons: 2,
        };

        let config = StdaConfig {
            occ_window_ev: 0.1,
            virt_window_ev: 0.1,
            ..Default::default()
        };
        let result = compute_stda(&scf, &[0, 0], &[[0.0, 0.0, 0.0]], &config);
        assert_eq!(result.method, "sTDA");
        assert!(
            result.transitions.is_empty(),
            "No active space should yield no transitions"
        );
    }

    /// One HOMO -> LUMO single with vanishing transition charges: the root
    /// energy equals the orbital-energy gap of 0.4 Hartree.
    #[test]
    fn test_stda_produces_transitions() {
        let n_basis = 4;
        let scf = ScfInput {
            orbital_energies: vec![-0.8, -0.3, 0.1, 0.5],
            mo_coefficients: DMatrix::identity(n_basis, n_basis),
            density_matrix: DMatrix::zeros(n_basis, n_basis),
            overlap_matrix: DMatrix::identity(n_basis, n_basis),
            n_basis,
            n_electrons: 4,
        };

        let config = StdaConfig::default();
        let positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        let basis_to_atom = [0, 0, 1, 1];
        let result = compute_stda(&scf, &basis_to_atom, &positions, &config);

        assert!(!result.transitions.is_empty(), "Should produce transitions");
        assert_eq!(result.transitions.len(), 1);
        for t in &result.transitions {
            assert!(t.energy_ev > 0.0);
            assert!(t.wavelength_nm > 0.0);
            assert!((t.energy_ev - 0.4 * HARTREE_TO_EV).abs() < 1e-9);
        }
    }

    /// Pairs come out in row-major order and empty dimensions yield nothing.
    #[test]
    fn test_iproduct() {
        let pairs: Vec<_> = iproduct(2, 3).collect();
        assert_eq!(pairs.len(), 6);
        assert_eq!(pairs[0], (0, 0));
        assert_eq!(pairs[5], (1, 2));
        assert_eq!(
            pairs,
            vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
        );
        assert_eq!(iproduct(0, 3).count(), 0);
        assert_eq!(iproduct(3, 0).count(), 0);
    }
}