1use crate::eht::basis::{build_basis, AtomicOrbital};
7use crate::eht::overlap::build_overlap_matrix;
8use serde::{Deserialize, Serialize};
9
/// Evaluates a normalized Gaussian broadening kernel centred on `center` at `energy`.
///
/// Computes `norm * exp(-(energy - center)^2 * inv_2s2)`. Callers in this module pass
/// the precomputed factors `norm = 1 / (sigma * sqrt(2*pi))` and
/// `inv_2s2 = 1 / (2 * sigma^2)` so they are not recomputed per evaluation.
// `dead_code` is allowed because, depending on enabled features, this helper may
// have no callers in a given build.
#[allow(dead_code)]
fn gaussian_value(energy: f64, center: f64, norm: f64, inv_2s2: f64) -> f64 {
    let offset = energy - center;
    let exponent = -offset.powi(2) * inv_2s2;
    norm * exponent.exp()
}
14
/// Gaussian-broadened density of states sampled on an energy grid.
///
/// Produced by the `compute_dos*` / `compute_pdos*` functions in this module and
/// serialized by `export_dos_json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DosResult {
    /// Energy grid points at which the curves are sampled (evenly spaced from
    /// `e_min` to `e_max` as passed to the compute functions).
    pub energies: Vec<f64>,
    /// Total DOS value at each entry of `energies`.
    pub total_dos: Vec<f64>,
    /// Projected DOS per atom, indexed `pdos[atom][grid_point]`; empty when no
    /// projection was computed (e.g. by `compute_dos`).
    pub pdos: Vec<Vec<f64>>,
    /// Standard deviation of the Gaussian broadening used to build the curves.
    pub sigma: f64,
}
27
28pub fn compute_dos(
35 orbital_energies: &[f64],
36 sigma: f64,
37 e_min: f64,
38 e_max: f64,
39 n_points: usize,
40) -> DosResult {
41 let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
42 let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
43
44 let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
45 let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
46
47 let total_dos: Vec<f64> = energies
48 .iter()
49 .map(|&e| {
50 orbital_energies
51 .iter()
52 .map(|&ei| norm * (-(e - ei).powi(2) * inv_2s2).exp())
53 .sum()
54 })
55 .collect();
56
57 DosResult {
58 energies,
59 total_dos,
60 pdos: Vec::new(),
61 sigma,
62 }
63}
64
/// Parallel (rayon) variant of [`compute_dos`], evaluating grid points concurrently.
///
/// The grid is `n_points` evenly spaced energies from `e_min` to `e_max`
/// inclusive (`n_points == 0` yields an empty grid). Each orbital energy adds a
/// unit-area Gaussian with standard deviation `sigma`, which is expected to be
/// positive. The returned [`DosResult`] has an empty `pdos`.
#[cfg(feature = "parallel")]
pub fn compute_dos_parallel(
    orbital_energies: &[f64],
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    // `saturating_sub` prevents usize underflow (a debug-build panic) when
    // `n_points == 0`; `max(1)` avoids dividing by zero for a single point.
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    // Gaussian prefactor and exponent scale are shared by every evaluation.
    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    let total_dos: Vec<f64> = energies
        .par_iter()
        .map(|&energy| {
            orbital_energies
                .iter()
                .map(|&center| gaussian_value(energy, center, norm, inv_2s2))
                .sum()
        })
        .collect();

    DosResult {
        energies,
        total_dos,
        pdos: Vec::new(),
        sigma,
    }
}
99
100#[allow(clippy::too_many_arguments)]
111pub fn compute_pdos(
112 elements: &[u8],
113 positions: &[f64],
114 orbital_energies: &[f64],
115 coefficients: &[Vec<f64>],
116 n_electrons: usize,
117 sigma: f64,
118 e_min: f64,
119 e_max: f64,
120 n_points: usize,
121) -> DosResult {
122 let n_atoms = elements.len();
123 let pos_arr: Vec<[f64; 3]> = positions
124 .chunks_exact(3)
125 .map(|c| [c[0], c[1], c[2]])
126 .collect();
127 let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
128 let overlap = build_overlap_matrix(&basis);
129 let n_basis = basis.len();
130
131 let n_orb = orbital_energies.len().min(coefficients.len());
134 let mut orbital_atom_weight = vec![vec![0.0f64; n_atoms]; n_orb];
135
136 for k in 0..n_orb {
137 for mu in 0..n_basis {
138 if coefficients.len() <= mu || coefficients[mu].len() <= k {
139 continue;
140 }
141 let atom_mu = basis[mu].atom_index;
142 let mut w = 0.0;
143 for nu in 0..n_basis {
144 if coefficients.len() <= nu || coefficients[nu].len() <= k {
145 continue;
146 }
147 w += coefficients[mu][k] * overlap[(mu, nu)] * coefficients[nu][k];
149 }
150 orbital_atom_weight[k][atom_mu] += w;
151 }
152 let total_w: f64 = orbital_atom_weight[k].iter().sum();
154 if total_w.abs() > 1e-12 {
155 for a in 0..n_atoms {
156 orbital_atom_weight[k][a] /= total_w;
157 }
158 }
159 }
160
161 let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
162 let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
163
164 let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
165 let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
166
167 let total_dos: Vec<f64> = energies
169 .iter()
170 .map(|&e| {
171 (0..n_orb)
172 .map(|k| norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp())
173 .sum()
174 })
175 .collect();
176
177 let mut pdos = vec![vec![0.0f64; n_points]; n_atoms];
179 for a in 0..n_atoms {
180 for (gi, &e) in energies.iter().enumerate() {
181 let mut val = 0.0;
182 for k in 0..n_orb {
183 let gauss = norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp();
184 val += orbital_atom_weight[k][a] * gauss;
185 }
186 pdos[a][gi] = val;
187 }
188 }
189
190 let _ = n_electrons; DosResult {
193 energies,
194 total_dos,
195 pdos,
196 sigma,
197 }
198}
199
/// Parallel (rayon) variant of [`compute_pdos`], with the same inputs and output layout.
///
/// Orbital-to-atom weights (Mulliken-style gross populations, rescaled to sum to
/// 1 per orbital when their sum is not ~0) are computed in parallel over
/// orbitals. The DOS and PDOS are computed in parallel over grid points, and
/// each (grid point, orbital) Gaussian is evaluated only once.
/// `n_electrons` is currently unused. `n_points == 0` yields an empty grid.
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
pub fn compute_pdos_parallel(
    elements: &[u8],
    positions: &[f64],
    orbital_energies: &[f64],
    coefficients: &[Vec<f64>],
    n_electrons: usize,
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    // Not used: the DOS computed here does not depend on orbital occupations.
    let _ = n_electrons;

    let n_atoms = elements.len();
    let pos_arr: Vec<[f64; 3]> = positions
        .chunks_exact(3)
        .map(|c| [c[0], c[1], c[2]])
        .collect();
    let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
    let overlap = build_overlap_matrix(&basis);
    let n_basis = basis.len();
    let n_orb = orbital_energies.len().min(coefficients.len());

    // Coefficient of basis function `mu` in orbital `k`, or `None` when the
    // (possibly short or ragged) coefficient matrix has no such entry.
    let coeff = |mu: usize, k: usize| coefficients.get(mu).and_then(|row| row.get(k)).copied();

    // orbital_atom_weight[k][a]: fraction of orbital `k` attributed to atom `a`.
    let orbital_atom_weight: Vec<Vec<f64>> = (0..n_orb)
        .into_par_iter()
        .map(|k| {
            let mut weights = vec![0.0f64; n_atoms];
            for (mu, orbital) in basis.iter().enumerate() {
                if let Some(c_mu) = coeff(mu, k) {
                    let mut population = 0.0;
                    for nu in 0..n_basis {
                        if let Some(c_nu) = coeff(nu, k) {
                            population += c_mu * overlap[(mu, nu)] * c_nu;
                        }
                    }
                    weights[orbital.atom_index] += population;
                }
            }

            // Rescale so the weights sum to one, unless the total is ~0.
            let total_weight: f64 = weights.iter().sum();
            if total_weight.abs() > 1e-12 {
                for weight in &mut weights {
                    *weight /= total_weight;
                }
            }
            weights
        })
        .collect();

    // `saturating_sub` prevents usize underflow (a debug-build panic) when
    // `n_points == 0`; `max(1)` avoids dividing by zero for a single point.
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    // Parallelize over grid points: each task evaluates every orbital's Gaussian
    // once and returns (total DOS, per-atom PDOS) at its energy.
    let per_point: Vec<(f64, Vec<f64>)> = energies
        .par_iter()
        .map(|&energy| {
            let mut total = 0.0;
            let mut by_atom = vec![0.0f64; n_atoms];
            for (&center, weights) in orbital_energies[..n_orb].iter().zip(&orbital_atom_weight) {
                let gauss = gaussian_value(energy, center, norm, inv_2s2);
                total += gauss;
                for (acc, &w) in by_atom.iter_mut().zip(weights) {
                    *acc += w * gauss;
                }
            }
            (total, by_atom)
        })
        .collect();

    // Transpose the per-point results into the `pdos[atom][grid_point]` layout.
    let total_dos: Vec<f64> = per_point.iter().map(|(total, _)| *total).collect();
    let pdos: Vec<Vec<f64>> = (0..n_atoms)
        .map(|a| per_point.iter().map(|(_, by_atom)| by_atom[a]).collect())
        .collect();

    DosResult {
        energies,
        total_dos,
        pdos,
        sigma,
    }
}
296
/// Mean squared error between two DOS curves sampled on the same grid.
///
/// Two empty curves are considered identical and yield `0.0`. Without this
/// case the result would be the `NaN` produced by `0 / 0`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn dos_mse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "DOS curves must have same length");
    if a.is_empty() {
        return 0.0;
    }
    let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    sum_sq / a.len() as f64
}
310
311pub fn export_dos_json(result: &DosResult) -> String {
323 let mut json = String::from("{");
324 json.push_str("\"energies\":[");
325 for (i, e) in result.energies.iter().enumerate() {
326 if i > 0 {
327 json.push(',');
328 }
329 json.push_str(&format!("{:.6}", e));
330 }
331 json.push_str("],\"total_dos\":[");
332 for (i, d) in result.total_dos.iter().enumerate() {
333 if i > 0 {
334 json.push(',');
335 }
336 json.push_str(&format!("{:.6}", d));
337 }
338 json.push_str(&format!("],\"sigma\":{:.6}", result.sigma));
339 if !result.pdos.is_empty() {
340 json.push_str(",\"pdos\":{");
341 for (a, pdos_a) in result.pdos.iter().enumerate() {
342 if a > 0 {
343 json.push(',');
344 }
345 json.push_str(&format!("\"{}\":[", a));
346 for (i, v) in pdos_a.iter().enumerate() {
347 if i > 0 {
348 json.push(',');
349 }
350 json.push_str(&format!("{:.6}", v));
351 }
352 json.push(']');
353 }
354 json.push('}');
355 }
356 json.push('}');
357 json
358}
359
// Unit tests for DOS/PDOS construction, the MSE metric and JSON export.
// The PDOS tests rely on `crate::eht::solve_eht` to produce orbitals.
#[cfg(test)]
mod tests {
    use super::*;

    // A single level at 0 on a symmetric grid must peak at the central point:
    // spacing is 2.0 / 200 = 0.01, so index 100 corresponds to E = 0.0.
    #[test]
    fn test_dos_single_level() {
        let res = compute_dos(&[0.0], 0.1, -1.0, 1.0, 201);
        assert_eq!(res.energies.len(), 201);
        assert_eq!(res.total_dos.len(), 201);
        let peak_idx = res
            .total_dos
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .unwrap()
            .0;
        assert_eq!(peak_idx, 100);
    }

    // Each Gaussian has unit area; the grid spans +/-15 sigma, so a simple
    // Riemann sum of the DOS should be ~1.
    #[test]
    fn test_dos_integral_approx_one() {
        let res = compute_dos(&[0.0], 0.2, -3.0, 3.0, 1001);
        let de = (3.0 - (-3.0)) / 1000.0;
        let integral: f64 = res.total_dos.iter().sum::<f64>() * de;
        assert!((integral - 1.0).abs() < 0.01, "integral = {integral}");
    }

    // Two well-separated levels: grid spacing is 0.04, so indices 125/375 sit on
    // E = -5/+5 and index 250 on E = 0, far from both peaks.
    #[test]
    fn test_dos_two_peaks() {
        let res = compute_dos(&[-5.0, 5.0], 0.3, -10.0, 10.0, 501);
        let mid = res.total_dos[250];
        let left_peak = res.total_dos[125];
        let right_peak = res.total_dos[375];
        assert!(left_peak > mid * 5.0);
        assert!(right_peak > mid * 5.0);
    }

    // The two atoms of H2 are equivalent, so wherever the PDOS is significant
    // (above 10% of atom 0's peak) both atoms must agree within ~5%.
    #[test]
    fn test_pdos_h2() {
        let elements = vec![1u8, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let positions: Vec<f64> = pos_arr.iter().flat_map(|p| p.iter().copied()).collect();
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.2,
            -20.0,
            5.0,
            201,
        );
        assert_eq!(res.pdos.len(), 2);
        let peak_val = res.pdos[0].iter().cloned().fold(0.0f64, f64::max);
        let threshold = peak_val * 0.1;
        for i in 0..201 {
            if res.pdos[0][i].abs() > threshold || res.pdos[1][i].abs() > threshold {
                let diff = (res.pdos[0][i] - res.pdos[1][i]).abs();
                let avg = (res.pdos[0][i].abs() + res.pdos[1][i].abs()) / 2.0;
                assert!(
                    diff < avg * 0.05 + 1e-6,
                    "PDOS mismatch at grid point {i}: {} vs {} (peak={})",
                    res.pdos[0][i],
                    res.pdos[1][i],
                    peak_val
                );
            }
        }
    }

    // Per-orbital atom weights are normalized, so summing the PDOS over atoms
    // (here an H2O-like geometry) should reproduce the total DOS.
    #[test]
    fn test_pdos_sums_to_total() {
        let elements = vec![8u8, 1, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]];
        let positions: Vec<f64> = pos_arr.iter().flat_map(|p| p.iter().copied()).collect();
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.3,
            -30.0,
            5.0,
            201,
        );
        for i in 0..201 {
            let pdos_sum: f64 = res.pdos.iter().map(|p| p[i]).sum();
            let diff = (pdos_sum - res.total_dos[i]).abs();
            assert!(
                diff < res.total_dos[i].abs() * 0.05 + 1e-10,
                "PDOS sum {pdos_sum} vs total {} at grid {i}",
                res.total_dos[i]
            );
        }
    }

    // Identical curves have zero error.
    #[test]
    fn test_dos_mse_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        assert!((dos_mse(&a, &a)) < 1e-15);
    }

    // (0.01 + 0.01 + 0.04) / 3 = 0.02
    #[test]
    fn test_dos_mse_known() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 1.9, 3.2];
        assert!((dos_mse(&a, &b) - 0.02).abs() < 1e-10);
    }

    // The hand-built JSON must parse and keep one entry per grid point.
    #[test]
    fn test_export_dos_json_roundtrip() {
        let res = compute_dos(&[0.0, -5.0], 0.3, -10.0, 5.0, 51);
        let json = export_dos_json(&res);

        let parsed: serde_json::Value = serde_json::from_str(&json).expect("valid JSON");
        assert!(parsed["energies"].is_array());
        assert!(parsed["total_dos"].is_array());
        assert_eq!(parsed["energies"].as_array().unwrap().len(), 51);
        assert_eq!(parsed["total_dos"].as_array().unwrap().len(), 51);
    }

    // With a PDOS present, the export contains a "pdos" object keyed by atom index.
    #[test]
    fn test_export_pdos_json() {
        let elements = vec![1u8, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let positions: Vec<f64> = pos_arr.iter().flat_map(|p| p.iter().copied()).collect();
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.2,
            -20.0,
            5.0,
            51,
        );
        let json = export_dos_json(&res);
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("valid JSON");
        assert!(parsed["pdos"].is_object());
        assert!(parsed["pdos"]["0"].is_array());
        assert!(parsed["pdos"]["1"].is_array());
    }
}