Skip to main content

sci_form/dos/
dos.rs

//! Density of states (DOS) and atom-projected density of states (PDOS).
//!
//! The total DOS is obtained by Gaussian-smearing EHT orbital energies; the
//! atom-projected DOS weights each orbital's Gaussian by the fraction of that
//! orbital's Mulliken population assigned to the atom.
5
6use crate::eht::basis::{build_basis, AtomicOrbital};
7use crate::eht::overlap::build_overlap_matrix;
8
/// Evaluate a pre-normalised Gaussian line shape at `energy`.
///
/// Returns `norm * exp(-(energy - center)^2 * inv_2s2)`. Callers in this module
/// pass `norm = 1 / (sigma * sqrt(2π))` and `inv_2s2 = 1 / (2 * sigma^2)`, so
/// each evaluation costs one subtraction, one square and one `exp`.
fn gaussian_value(energy: f64, center: f64, norm: f64, inv_2s2: f64) -> f64 {
    let offset = energy - center;
    let exponent = -offset.powi(2) * inv_2s2;
    norm * exponent.exp()
}
12
/// Result of a DOS/PDOS calculation.
///
/// All curves share one energy grid: `total_dos[i]` and every
/// `pdos[atom][i]` are evaluated at `energies[i]`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct DosResult {
    /// Energy grid values (eV).
    pub energies: Vec<f64>,
    /// Total DOS values (states/eV), one per grid point.
    pub total_dos: Vec<f64>,
    /// Per-atom PDOS: `pdos[atom_idx][grid_idx]`.
    ///
    /// Empty when only the total DOS was computed (e.g. by [`compute_dos`]).
    pub pdos: Vec<Vec<f64>>,
    /// Gaussian smearing width used (eV).
    pub sigma: f64,
}
25
26/// Compute total density of states from EHT orbital energies.
27///
28/// `orbital_energies`: eigenvalues from EHT (eV).
29/// `sigma`: Gaussian smearing width (eV).
30/// `e_min`, `e_max`: energy window.
31/// `n_points`: grid resolution.
32pub fn compute_dos(
33    orbital_energies: &[f64],
34    sigma: f64,
35    e_min: f64,
36    e_max: f64,
37    n_points: usize,
38) -> DosResult {
39    let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
40    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
41
42    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
43    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
44
45    let total_dos: Vec<f64> = energies
46        .iter()
47        .map(|&e| {
48            orbital_energies
49                .iter()
50                .map(|&ei| norm * (-(e - ei).powi(2) * inv_2s2).exp())
51                .sum()
52        })
53        .collect();
54
55    DosResult {
56        energies,
57        total_dos,
58        pdos: Vec::new(),
59        sigma,
60    }
61}
62
/// Compute total DOS, parallelising the evaluation over the energy grid with rayon.
///
/// * `orbital_energies` — eigenvalues from EHT (eV).
/// * `sigma` — Gaussian smearing width (eV). Must be positive: zero gives
///   non-finite values and a negative width flips the sign of the curve.
/// * `e_min`, `e_max` — energy window; for `n_points >= 2` the grid includes
///   both end points.
/// * `n_points` — grid resolution. `0` yields empty curves; `1` yields a
///   single point at `e_min`.
///
/// The returned [`DosResult`] has an empty `pdos`.
#[cfg(feature = "parallel")]
pub fn compute_dos_parallel(
    orbital_energies: &[f64],
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    // `saturating_sub` keeps `n_points == 0` from underflowing (a panic in
    // debug builds); `max(1)` keeps a one-point grid from dividing by zero.
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    // Each grid point is independent, so the grid is split across threads;
    // the per-point sum over orbitals stays sequential and ordered.
    let total_dos: Vec<f64> = energies
        .par_iter()
        .map(|&energy| {
            orbital_energies
                .iter()
                .map(|&center| gaussian_value(energy, center, norm, inv_2s2))
                .sum()
        })
        .collect();

    DosResult {
        energies,
        total_dos,
        pdos: Vec::new(),
        sigma,
    }
}
97
98/// Compute atom-projected DOS from EHT results.
99///
100/// `elements`: atomic numbers per atom.
101/// `positions`: flat \[x0,y0,z0, x1,y1,z1,...\] in Å.
102/// `orbital_energies`: eigenvalues from EHT (eV).
103/// `coefficients`: coefficients\[orbital\]\[basis\] from EHT.
104/// `n_electrons`: number of electrons.
105/// `sigma`: Gaussian smearing width (eV).
106/// `e_min`, `e_max`: energy window.
107/// `n_points`: grid resolution.
108#[allow(clippy::too_many_arguments)]
109pub fn compute_pdos(
110    elements: &[u8],
111    positions: &[f64],
112    orbital_energies: &[f64],
113    coefficients: &[Vec<f64>],
114    n_electrons: usize,
115    sigma: f64,
116    e_min: f64,
117    e_max: f64,
118    n_points: usize,
119) -> DosResult {
120    let n_atoms = elements.len();
121    let pos_arr: Vec<[f64; 3]> = positions
122        .chunks_exact(3)
123        .map(|c| [c[0], c[1], c[2]])
124        .collect();
125    let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
126    let overlap = build_overlap_matrix(&basis);
127    let n_basis = basis.len();
128
129    // Build density-like Mulliken weight per (orbital, atom):
130    // w_{k,A} = Σ_{μ∈A} Σ_ν c_{k,μ} S_{μν} c_{k,ν}
131    let n_orb = orbital_energies.len().min(coefficients.len());
132    let mut orbital_atom_weight = vec![vec![0.0f64; n_atoms]; n_orb];
133
134    for k in 0..n_orb {
135        for mu in 0..n_basis {
136            if coefficients.len() <= mu || coefficients[mu].len() <= k {
137                continue;
138            }
139            let atom_mu = basis[mu].atom_index;
140            let mut w = 0.0;
141            for nu in 0..n_basis {
142                if coefficients.len() <= nu || coefficients[nu].len() <= k {
143                    continue;
144                }
145                // coefficients[mu][k] = AO mu, MO k
146                w += coefficients[mu][k] * overlap[(mu, nu)] * coefficients[nu][k];
147            }
148            orbital_atom_weight[k][atom_mu] += w;
149        }
150        // Normalize weights so they sum to 1 per orbital
151        let total_w: f64 = orbital_atom_weight[k].iter().sum();
152        if total_w.abs() > 1e-12 {
153            for a in 0..n_atoms {
154                orbital_atom_weight[k][a] /= total_w;
155            }
156        }
157    }
158
159    let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
160    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
161
162    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
163    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
164
165    // Total DOS
166    let total_dos: Vec<f64> = energies
167        .iter()
168        .map(|&e| {
169            (0..n_orb)
170                .map(|k| norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp())
171                .sum()
172        })
173        .collect();
174
175    // Per-atom PDOS
176    let mut pdos = vec![vec![0.0f64; n_points]; n_atoms];
177    for a in 0..n_atoms {
178        for (gi, &e) in energies.iter().enumerate() {
179            let mut val = 0.0;
180            for k in 0..n_orb {
181                let gauss = norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp();
182                val += orbital_atom_weight[k][a] * gauss;
183            }
184            pdos[a][gi] = val;
185        }
186    }
187
188    let _ = n_electrons; // used contextually; weight already normalized via Mulliken
189
190    DosResult {
191        energies,
192        total_dos,
193        pdos,
194        sigma,
195    }
196}
197
/// Compute atom-projected DOS using rayon for orbital weights and atom-grid accumulation.
///
/// Produces the same quantities as [`compute_pdos`]. The per-orbital Mulliken
/// weights, the total DOS over the energy grid, and the per-atom PDOS curves
/// are each computed in parallel.
///
/// * `elements` — atomic numbers per atom.
/// * `positions` — flat `[x0, y0, z0, x1, y1, z1, ...]` in Å; a trailing
///   partial triple is ignored.
/// * `orbital_energies` — eigenvalues from EHT (eV).
/// * `coefficients` — MO coefficients read as `coefficients[basis][orbital]`
///   (row = atomic orbital μ, column = molecular orbital k).
/// * `n_electrons` — currently unused (kept for API compatibility).
/// * `sigma` — Gaussian smearing width (eV); must be positive.
/// * `e_min`, `e_max` — energy window; for `n_points >= 2` the grid includes
///   both end points.
/// * `n_points` — grid resolution; `0` yields empty curves.
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
pub fn compute_pdos_parallel(
    elements: &[u8],
    positions: &[f64],
    orbital_energies: &[f64],
    coefficients: &[Vec<f64>],
    n_electrons: usize,
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    let n_atoms = elements.len();
    let pos_arr: Vec<[f64; 3]> = positions
        .chunks_exact(3)
        .map(|c| [c[0], c[1], c[2]])
        .collect();
    let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
    let overlap = build_overlap_matrix(&basis);
    let n_basis = basis.len();
    // NOTE(review): with the [basis][orbital] layout, `coefficients.len()` is
    // the basis size. This bound assumes a square coefficient matrix — confirm.
    let n_orb = orbital_energies.len().min(coefficients.len());

    // Mulliken gross population of each atom in orbital k, normalised to sum to 1.
    let orbital_atom_weight: Vec<Vec<f64>> = (0..n_orb)
        .into_par_iter()
        .map(|k| {
            let mut weights = vec![0.0f64; n_atoms];
            for mu in 0..n_basis {
                // Ragged or short coefficient rows are skipped rather than indexed.
                if coefficients.len() <= mu || coefficients[mu].len() <= k {
                    continue;
                }
                let atom_mu = basis[mu].atom_index;
                let mut weight = 0.0;
                for nu in 0..n_basis {
                    if coefficients.len() <= nu || coefficients[nu].len() <= k {
                        continue;
                    }
                    // coefficients[mu][k] = AO mu, MO k
                    weight += coefficients[mu][k] * overlap[(mu, nu)] * coefficients[nu][k];
                }
                weights[atom_mu] += weight;
            }

            let total_weight: f64 = weights.iter().sum();
            if total_weight.abs() > 1e-12 {
                for weight in &mut weights {
                    *weight /= total_weight;
                }
            }
            weights
        })
        .collect();

    // `saturating_sub` keeps `n_points == 0` from underflowing (a panic in
    // debug builds); `max(1)` keeps a one-point grid from dividing by zero.
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    let total_dos: Vec<f64> = energies
        .par_iter()
        .map(|&energy| {
            (0..n_orb)
                .map(|k| gaussian_value(energy, orbital_energies[k], norm, inv_2s2))
                .sum()
        })
        .collect();

    let pdos: Vec<Vec<f64>> = (0..n_atoms)
        .into_par_iter()
        .map(|atom_index| {
            energies
                .iter()
                .map(|&energy| {
                    (0..n_orb)
                        .map(|k| {
                            orbital_atom_weight[k][atom_index]
                                * gaussian_value(energy, orbital_energies[k], norm, inv_2s2)
                        })
                        .sum()
                })
                .collect()
        })
        .collect();

    let _ = n_electrons; // unused; kept for API compatibility

    DosResult {
        energies,
        total_dos,
        pdos,
        sigma,
    }
}
294
/// Compute the mean-squared error between two DOS curves.
///
/// Useful for comparing against a reference DOS (e.g. Multiwfn output)
/// sampled on the same grid. Two empty curves are identical, so their MSE is
/// `0.0` rather than the `NaN` that a `0 / 0` division would produce.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn dos_mse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "DOS curves must have same length");
    if a.is_empty() {
        return 0.0;
    }
    let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    sum_sq / a.len() as f64
}
308
309/// Serialize DOS/PDOS result to JSON for web visualization.
310///
311/// Format:
312/// ```json
313/// {
314///   "energies": [...],
315///   "total_dos": [...],
316///   "sigma": 0.3,
317///   "pdos": { "0": [...], "1": [...], ... }
318/// }
319/// ```
320pub fn export_dos_json(result: &DosResult) -> String {
321    let mut json = String::from("{");
322    json.push_str("\"energies\":[");
323    for (i, e) in result.energies.iter().enumerate() {
324        if i > 0 {
325            json.push(',');
326        }
327        json.push_str(&format!("{:.6}", e));
328    }
329    json.push_str("],\"total_dos\":[");
330    for (i, d) in result.total_dos.iter().enumerate() {
331        if i > 0 {
332            json.push(',');
333        }
334        json.push_str(&format!("{:.6}", d));
335    }
336    json.push_str(&format!("],\"sigma\":{:.6}", result.sigma));
337    if !result.pdos.is_empty() {
338        json.push_str(",\"pdos\":{");
339        for (a, pdos_a) in result.pdos.iter().enumerate() {
340            if a > 0 {
341                json.push(',');
342            }
343            json.push_str(&format!("\"{}\":[", a));
344            for (i, v) in pdos_a.iter().enumerate() {
345                if i > 0 {
346                    json.push(',');
347                }
348                json.push_str(&format!("{:.6}", v));
349            }
350            json.push(']');
351        }
352        json.push('}');
353    }
354    json.push('}');
355    json
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Index of the largest value in a non-empty curve (`total_cmp` never panics on NaN).
    fn argmax(values: &[f64]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i)
            .expect("curve must be non-empty")
    }

    /// Flatten `[[x, y, z], ...]` into the `[x0, y0, z0, ...]` layout the PDOS API takes.
    fn flatten(pos_arr: &[[f64; 3]]) -> Vec<f64> {
        pos_arr.iter().flat_map(|p| p.iter().copied()).collect()
    }

    #[test]
    fn test_dos_single_level() {
        // Single orbital at 0 eV → DOS should peak at the grid mid-point (index 100).
        let res = compute_dos(&[0.0], 0.1, -1.0, 1.0, 201);
        assert_eq!(res.energies.len(), 201);
        assert_eq!(res.total_dos.len(), 201);
        assert!(res.pdos.is_empty(), "total-only DOS must not carry PDOS");
        assert_eq!(argmax(&res.total_dos), 100);
    }

    #[test]
    fn test_dos_single_grid_point() {
        // A one-point grid sits at e_min.
        let res = compute_dos(&[0.0], 0.2, -1.0, 1.0, 1);
        assert_eq!(res.energies, vec![-1.0]);
        assert_eq!(res.total_dos.len(), 1);
    }

    #[test]
    fn test_dos_integral_approx_one() {
        // Integral of DOS for one level ≈ 1 (normalized Gaussian).
        let n_points = 1001;
        let res = compute_dos(&[0.0], 0.2, -3.0, 3.0, n_points);
        let de = (3.0 - (-3.0)) / (n_points - 1) as f64;
        let integral: f64 = res.total_dos.iter().sum::<f64>() * de;
        assert!((integral - 1.0).abs() < 0.01, "integral = {integral}");
    }

    #[test]
    fn test_dos_two_peaks() {
        let res = compute_dos(&[-5.0, 5.0], 0.3, -10.0, 10.0, 501);
        // Levels at ±5 eV land on grid indices 125 and 375; the midpoint (0 eV) is a valley.
        let mid = res.total_dos[250];
        let left_peak = res.total_dos[125];
        let right_peak = res.total_dos[375];
        assert!(left_peak > mid * 5.0);
        assert!(right_peak > mid * 5.0);
    }

    #[test]
    fn test_pdos_h2() {
        // H₂: two atoms should have symmetric PDOS.
        let elements = vec![1u8, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let positions = flatten(&pos_arr);
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.2,
            -20.0,
            5.0,
            201,
        );
        assert_eq!(res.pdos.len(), 2);
        assert_eq!(res.pdos[0].len(), 201);
        assert_eq!(res.pdos[1].len(), 201);
        // Both H atoms should have nearly equal PDOS wherever either curve
        // exceeds 10% of atom 0's peak.
        let peak_val = res.pdos[0].iter().copied().fold(0.0f64, f64::max);
        let threshold = peak_val * 0.1;
        for (i, (&p0, &p1)) in res.pdos[0].iter().zip(&res.pdos[1]).enumerate() {
            if p0.abs() > threshold || p1.abs() > threshold {
                let diff = (p0 - p1).abs();
                let avg = (p0.abs() + p1.abs()) / 2.0;
                assert!(
                    diff < avg * 0.05 + 1e-6,
                    "PDOS mismatch at grid point {i}: {p0} vs {p1} (peak={peak_val})"
                );
            }
        }
    }

    #[test]
    fn test_pdos_sums_to_total() {
        // Sum of PDOS over all atoms ≈ total DOS.
        let elements = vec![8u8, 1, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]];
        let positions = flatten(&pos_arr);
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.3,
            -30.0,
            5.0,
            201,
        );
        assert_eq!(res.pdos.len(), elements.len());
        assert_eq!(res.total_dos.len(), 201);
        for (i, &total) in res.total_dos.iter().enumerate() {
            let pdos_sum: f64 = res.pdos.iter().map(|p| p[i]).sum();
            let diff = (pdos_sum - total).abs();
            assert!(
                diff < total.abs() * 0.05 + 1e-10,
                "PDOS sum {pdos_sum} vs total {total} at grid {i}"
            );
        }
    }

    #[test]
    fn test_dos_mse_identical() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        assert!(dos_mse(&a, &a) < 1e-15);
    }

    #[test]
    fn test_dos_mse_known() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 1.9, 3.2];
        // MSE = (0.01 + 0.01 + 0.04) / 3 = 0.02
        assert!((dos_mse(&a, &b) - 0.02).abs() < 1e-10);
    }

    #[test]
    fn test_export_dos_json_roundtrip() {
        let res = compute_dos(&[0.0, -5.0], 0.3, -10.0, 5.0, 51);
        let json = export_dos_json(&res);

        // Should be valid JSON with one entry per grid point.
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("valid JSON");
        assert_eq!(parsed["energies"].as_array().map(Vec::len), Some(51));
        assert_eq!(parsed["total_dos"].as_array().map(Vec::len), Some(51));
        let sigma = parsed["sigma"].as_f64().expect("sigma is a number");
        assert!((sigma - 0.3).abs() < 1e-9);
        // A total-only result has no PDOS to export.
        assert!(parsed.get("pdos").is_none());
    }

    #[test]
    fn test_export_pdos_json() {
        let elements = vec![1u8, 1];
        let pos_arr = vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let positions = flatten(&pos_arr);
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        let res = compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            0.2,
            -20.0,
            5.0,
            51,
        );
        let json = export_dos_json(&res);
        let parsed: serde_json::Value = serde_json::from_str(&json).expect("valid JSON");
        assert!(parsed["pdos"].is_object());
        for atom in ["0", "1"] {
            assert_eq!(
                parsed["pdos"][atom].as_array().map(Vec::len),
                Some(51),
                "PDOS curve for atom {atom}"
            );
        }
    }
}