1use crate::eht::basis::{build_basis, AtomicOrbital};
7use crate::eht::overlap::build_overlap_matrix;
8
/// Evaluates one broadening Gaussian, `norm * exp(-(energy - center)^2 * inv_2s2)`.
///
/// `norm` and `inv_2s2` are passed in precomputed so grid loops do not recompute them
/// for every sample; callers in this module pass `1 / (sigma * sqrt(2*pi))` and
/// `1 / (2 * sigma^2)`, which makes the curve a unit-area normal density.
fn gaussian_value(energy: f64, center: f64, norm: f64, inv_2s2: f64) -> f64 {
    let offset = energy - center;
    let exponent = -offset.powi(2) * inv_2s2;
    norm * exponent.exp()
}
12
/// Gaussian-broadened density of states sampled on an energy grid, as produced by the
/// DOS/PDOS functions in this module.
#[derive(Clone, Debug)]
pub struct DosResult {
    /// Energy grid, evenly spaced from the requested minimum towards the maximum.
    pub energies: Vec<f64>,
    /// Total DOS at each grid point (same length as `energies`).
    pub total_dos: Vec<f64>,
    /// Projected DOS indexed as `pdos[atom][grid_point]`; empty when no projection was computed.
    pub pdos: Vec<Vec<f64>>,
    /// Gaussian broadening width (standard deviation) used to build the curves.
    pub sigma: f64,
}
25
26pub fn compute_dos(
33 orbital_energies: &[f64],
34 sigma: f64,
35 e_min: f64,
36 e_max: f64,
37 n_points: usize,
38) -> DosResult {
39 let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
40 let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
41
42 let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
43 let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
44
45 let total_dos: Vec<f64> = energies
46 .iter()
47 .map(|&e| {
48 orbital_energies
49 .iter()
50 .map(|&ei| norm * (-(e - ei).powi(2) * inv_2s2).exp())
51 .sum()
52 })
53 .collect();
54
55 DosResult {
56 energies,
57 total_dos,
58 pdos: Vec::new(),
59 sigma,
60 }
61}
62
/// Rayon-parallel counterpart of [`compute_dos`]: same energy grid and Gaussian
/// broadening, with the grid points evaluated concurrently.
///
/// Only compiled with the `parallel` cargo feature. The returned `pdos` is empty and
/// `n_points == 0` yields empty curves.
#[cfg(feature = "parallel")]
pub fn compute_dos_parallel(
    orbital_energies: &[f64],
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    // `saturating_sub` keeps `n_points == 0` from underflowing (a panic in debug builds).
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    // Normal-density prefactor and exponent scale, hoisted out of the grid loop.
    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    let total_dos: Vec<f64> = energies
        .par_iter()
        .map(|&energy| {
            orbital_energies
                .iter()
                .map(|&center| gaussian_value(energy, center, norm, inv_2s2))
                .sum()
        })
        .collect();

    DosResult {
        energies,
        total_dos,
        pdos: Vec::new(),
        sigma,
    }
}
97
98#[allow(clippy::too_many_arguments)]
109pub fn compute_pdos(
110 elements: &[u8],
111 positions: &[f64],
112 orbital_energies: &[f64],
113 coefficients: &[Vec<f64>],
114 n_electrons: usize,
115 sigma: f64,
116 e_min: f64,
117 e_max: f64,
118 n_points: usize,
119) -> DosResult {
120 let n_atoms = elements.len();
121 let pos_arr: Vec<[f64; 3]> = positions
122 .chunks_exact(3)
123 .map(|c| [c[0], c[1], c[2]])
124 .collect();
125 let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
126 let overlap = build_overlap_matrix(&basis);
127 let n_basis = basis.len();
128
129 let n_orb = orbital_energies.len().min(coefficients.len());
132 let mut orbital_atom_weight = vec![vec![0.0f64; n_atoms]; n_orb];
133
134 for k in 0..n_orb {
135 for mu in 0..n_basis {
136 if coefficients.len() <= mu || coefficients[mu].len() <= k {
137 continue;
138 }
139 let atom_mu = basis[mu].atom_index;
140 let mut w = 0.0;
141 for nu in 0..n_basis {
142 if coefficients.len() <= nu || coefficients[nu].len() <= k {
143 continue;
144 }
145 w += coefficients[mu][k] * overlap[(mu, nu)] * coefficients[nu][k];
147 }
148 orbital_atom_weight[k][atom_mu] += w;
149 }
150 let total_w: f64 = orbital_atom_weight[k].iter().sum();
152 if total_w.abs() > 1e-12 {
153 for a in 0..n_atoms {
154 orbital_atom_weight[k][a] /= total_w;
155 }
156 }
157 }
158
159 let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
160 let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
161
162 let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
163 let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
164
165 let total_dos: Vec<f64> = energies
167 .iter()
168 .map(|&e| {
169 (0..n_orb)
170 .map(|k| norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp())
171 .sum()
172 })
173 .collect();
174
175 let mut pdos = vec![vec![0.0f64; n_points]; n_atoms];
177 for a in 0..n_atoms {
178 for (gi, &e) in energies.iter().enumerate() {
179 let mut val = 0.0;
180 for k in 0..n_orb {
181 let gauss = norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp();
182 val += orbital_atom_weight[k][a] * gauss;
183 }
184 pdos[a][gi] = val;
185 }
186 }
187
188 let _ = n_electrons; DosResult {
191 energies,
192 total_dos,
193 pdos,
194 sigma,
195 }
196}
197
/// Rayon-parallel counterpart of [`compute_pdos`].
///
/// Orbital weights are computed in parallel over orbitals, the total DOS in parallel over
/// grid points, and the PDOS in parallel over atoms. The projection (Mulliken-style gross
/// populations renormalised per orbital) and the energy grid match [`compute_pdos`].
/// `n_electrons` is unused and kept for API compatibility. `n_points == 0` yields empty
/// curves. Only compiled with the `parallel` cargo feature.
#[cfg(feature = "parallel")]
// The flat argument list is part of the public API, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
pub fn compute_pdos_parallel(
    elements: &[u8],
    positions: &[f64],
    orbital_energies: &[f64],
    coefficients: &[Vec<f64>],
    n_electrons: usize,
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    // Not needed by the projection; kept so the signature stays stable.
    let _ = n_electrons;

    let n_atoms = elements.len();
    let pos_arr: Vec<[f64; 3]> = positions
        .chunks_exact(3)
        .map(|c| [c[0], c[1], c[2]])
        .collect();
    let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
    let overlap = build_overlap_matrix(&basis);
    let n_basis = basis.len();
    // NOTE(review): presumes a square coefficient matrix, as in `compute_pdos` — confirm.
    let n_orb = orbital_energies.len().min(coefficients.len());

    // C[mu][k], or `None` when the coefficient table has no such entry.
    let coeff = |mu: usize, k: usize| coefficients.get(mu).and_then(|row| row.get(k)).copied();

    let orbital_atom_weight: Vec<Vec<f64>> = (0..n_orb)
        .into_par_iter()
        .map(|k| {
            let mut weights = vec![0.0f64; n_atoms];
            for (mu, orbital) in basis.iter().enumerate() {
                if let Some(c_mu) = coeff(mu, k) {
                    let mut population = 0.0;
                    for nu in 0..n_basis {
                        if let Some(c_nu) = coeff(nu, k) {
                            population += c_mu * overlap[(mu, nu)] * c_nu;
                        }
                    }
                    weights[orbital.atom_index] += population;
                }
            }

            // Renormalise so the atomic shares sum to one; skip near-zero totals.
            let total_weight: f64 = weights.iter().sum();
            if total_weight.abs() > 1e-12 {
                for weight in &mut weights {
                    *weight /= total_weight;
                }
            }
            weights
        })
        .collect();

    // `saturating_sub` keeps `n_points == 0` from underflowing (a panic in debug builds).
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    let total_dos: Vec<f64> = energies
        .par_iter()
        .map(|&energy| {
            (0..n_orb)
                .map(|k| gaussian_value(energy, orbital_energies[k], norm, inv_2s2))
                .sum()
        })
        .collect();

    let pdos: Vec<Vec<f64>> = (0..n_atoms)
        .into_par_iter()
        .map(|atom_index| {
            energies
                .iter()
                .map(|&energy| {
                    (0..n_orb)
                        .map(|k| {
                            orbital_atom_weight[k][atom_index]
                                * gaussian_value(energy, orbital_energies[k], norm, inv_2s2)
                        })
                        .sum()
                })
                .collect()
        })
        .collect();

    DosResult {
        energies,
        total_dos,
        pdos,
        sigma,
    }
}
294
/// Mean squared error between two DOS curves sampled on the same grid.
///
/// Returns `0.0` for two empty curves, instead of the `NaN` that a `0 / 0` division
/// would produce.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
pub fn dos_mse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "DOS curves must have same length");
    if a.is_empty() {
        return 0.0;
    }
    let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    sum_sq / a.len() as f64
}
308
309pub fn export_dos_json(result: &DosResult) -> String {
321 let mut json = String::from("{");
322 json.push_str("\"energies\":[");
323 for (i, e) in result.energies.iter().enumerate() {
324 if i > 0 {
325 json.push(',');
326 }
327 json.push_str(&format!("{:.6}", e));
328 }
329 json.push_str("],\"total_dos\":[");
330 for (i, d) in result.total_dos.iter().enumerate() {
331 if i > 0 {
332 json.push(',');
333 }
334 json.push_str(&format!("{:.6}", d));
335 }
336 json.push_str(&format!("],\"sigma\":{:.6}", result.sigma));
337 if !result.pdos.is_empty() {
338 json.push_str(",\"pdos\":{");
339 for (a, pdos_a) in result.pdos.iter().enumerate() {
340 if a > 0 {
341 json.push(',');
342 }
343 json.push_str(&format!("\"{}\":[", a));
344 for (i, v) in pdos_a.iter().enumerate() {
345 if i > 0 {
346 json.push(',');
347 }
348 json.push_str(&format!("{:.6}", v));
349 }
350 json.push(']');
351 }
352 json.push('}');
353 }
354 json.push('}');
355 json
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    // One level at 0.0 on a 201-point grid over [-1, 1]: the grid step is 0.01, so the
    // maximum must land on the middle index (E ≈ 0.0).
    #[test]
    fn test_dos_single_level() {
        let res = compute_dos(&[0.0], 0.1, -1.0, 1.0, 201);
        assert_eq!(res.energies.len(), 201);
        assert_eq!(res.total_dos.len(), 201);
        let peak_idx = res
            .total_dos
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(peak_idx, 100);
    }

    // Each broadening Gaussian has unit area; over [-3, 3] (±15 sigma for sigma = 0.2) a
    // rectangle-rule sum of the sampled curve should integrate to ~1.
    #[test]
    fn test_dos_integral_approx_one() {
        let res = compute_dos(&[0.0], 0.2, -3.0, 3.0, 1001);
        // Grid spacing for 1001 points over [-3, 3].
        let de = (3.0 - (-3.0)) / 1000.0;
        let integral: f64 = res.total_dos.iter().sum::<f64>() * de;
        assert!((integral - 1.0).abs() < 0.01, "integral = {integral}");
    }

    // Two well-separated levels at ±5 on a grid with step 0.04: indices 125 and 375 sit on
    // the peaks, index 250 (E ≈ 0) sits in the gap and must be much lower.
    #[test]
    fn test_dos_two_peaks() {
        let res = compute_dos(&[-5.0, 5.0], 0.3, -10.0, 10.0, 501);
        let mid = res.total_dos[250];
        let left_peak = res.total_dos[125];
        let right_peak = res.total_dos[375];
        assert!(left_peak > mid * 5.0);
        assert!(right_peak > mid * 5.0);
    }

    // H2 is symmetric, so both atoms should carry matching PDOS (within 5% of their mean)
    // wherever either curve exceeds 10% of atom 0's peak value.
    #[test]
    fn test_pdos_h2() {
        let elements = vec![1u8, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let positions: Vec<f64> = pos_arr.iter().flat_map(|p| p.iter().copied()).collect();
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.2,
            -20.0,
            5.0,
            201,
        );
        assert_eq!(res.pdos.len(), 2);
        // Largest PDOS value of atom 0; used to ignore the negligible tails.
        let peak_val = res.pdos[0].iter().cloned().fold(0.0f64, f64::max);
        let threshold = peak_val * 0.1;
        for i in 0..201 {
            if res.pdos[0][i].abs() > threshold || res.pdos[1][i].abs() > threshold {
                let diff = (res.pdos[0][i] - res.pdos[1][i]).abs();
                let avg = (res.pdos[0][i].abs() + res.pdos[1][i].abs()) / 2.0;
                assert!(
                    diff < avg * 0.05 + 1e-6,
                    "PDOS mismatch at grid point {i}: {} vs {} (peak={})",
                    res.pdos[0][i],
                    res.pdos[1][i],
                    peak_val
                );
            }
        }
    }

    // compute_pdos renormalises each orbital's atomic weights to sum to one, so the per-atom
    // curves should add up to the total DOS (checked on an O + 2 H geometry, 5% tolerance).
    #[test]
    fn test_pdos_sums_to_total() {
        let elements = vec![8u8, 1, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]];
        let positions: Vec<f64> = pos_arr.iter().flat_map(|p| p.iter().copied()).collect();
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.3,
            -30.0,
            5.0,
            201,
        );
        for i in 0..201 {
            let pdos_sum: f64 = res.pdos.iter().map(|p| p[i]).sum();
            let diff = (pdos_sum - res.total_dos[i]).abs();
            assert!(
                diff < res.total_dos[i].abs() * 0.05 + 1e-10,
                "PDOS sum {pdos_sum} vs total {} at grid {i}",
                res.total_dos[i]
            );
        }
    }

    // A curve compared with itself has zero error.
    #[test]
    fn test_dos_mse_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        assert!((dos_mse(&a, &a)) < 1e-15);
    }

    // Squared differences 0.01, 0.01, 0.04 average to 0.02.
    #[test]
    fn test_dos_mse_known() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 1.9, 3.2];
        assert!((dos_mse(&a, &b) - 0.02).abs() < 1e-10);
    }

    // The hand-built JSON must parse with serde_json and keep one entry per grid point.
    #[test]
    fn test_export_dos_json_roundtrip() {
        let res = compute_dos(&[0.0, -5.0], 0.3, -10.0, 5.0, 51);
        let json = export_dos_json(&res);

        let parsed: serde_json::Value = serde_json::from_str(&json).expect("valid JSON");
        assert!(parsed["energies"].is_array());
        assert!(parsed["total_dos"].is_array());
        assert_eq!(parsed["energies"].as_array().unwrap().len(), 51);
        assert_eq!(parsed["total_dos"].as_array().unwrap().len(), 51);
    }

    // With PDOS present, "pdos" is an object keyed by atom index strings ("0", "1", ...).
    #[test]
    fn test_export_pdos_json() {
        let elements = vec![1u8, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let positions: Vec<f64> = pos_arr.iter().flat_map(|p| p.iter().copied()).collect();
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.2,
            -20.0,
            5.0,
            51,
        );
        let json = export_dos_json(&res);
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("valid JSON");
        assert!(parsed["pdos"].is_object());
        assert!(parsed["pdos"]["0"].is_array());
        assert!(parsed["pdos"]["1"].is_array());
    }
}