Skip to main content

sci_form/dos/
dos.rs

1//! Density of States (DOS) and Projected DOS (PDOS).
2//!
3//! Computes total DOS by Gaussian-smearing EHT orbital energies
4//! and atom-projected DOS by weighting with Mulliken orbital populations.
5
6use crate::eht::basis::{build_basis, AtomicOrbital};
7use crate::eht::overlap::build_overlap_matrix;
8use serde::{Deserialize, Serialize};
9
/// Gaussian of height `norm` centred on `center`, evaluated at `energy`.
///
/// `norm` is the prefactor and `inv_2s2` is `1 / (2σ²)`; callers compute both once
/// per grid so each evaluation here costs a single `exp`.
// `allow(dead_code)`: some callers are compiled only with the `parallel` feature, so
// builds without it must not warn if no caller is compiled in.
#[allow(dead_code)]
fn gaussian_value(energy: f64, center: f64, norm: f64, inv_2s2: f64) -> f64 {
    let offset = energy - center;
    let exponent = -offset.powi(2) * inv_2s2;
    norm * exponent.exp()
}
14
15/// Result of a DOS/PDOS calculation.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DosResult {
18    /// Energy grid values (eV).
19    pub energies: Vec<f64>,
20    /// Total DOS values (states/eV).
21    pub total_dos: Vec<f64>,
22    /// Per-atom PDOS: pdos\[atom_idx\]\[grid_idx\].
23    pub pdos: Vec<Vec<f64>>,
24    /// Smearing width used (eV).
25    pub sigma: f64,
26}
27
28/// Compute total density of states from EHT orbital energies.
29///
30/// `orbital_energies`: eigenvalues from EHT (eV).
31/// `sigma`: Gaussian smearing width (eV).
32/// `e_min`, `e_max`: energy window.
33/// `n_points`: grid resolution.
34pub fn compute_dos(
35    orbital_energies: &[f64],
36    sigma: f64,
37    e_min: f64,
38    e_max: f64,
39    n_points: usize,
40) -> DosResult {
41    let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
42    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
43
44    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
45    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
46
47    let total_dos: Vec<f64> = energies
48        .iter()
49        .map(|&e| {
50            orbital_energies
51                .iter()
52                .map(|&ei| norm * (-(e - ei).powi(2) * inv_2s2).exp())
53                .sum()
54        })
55        .collect();
56
57    DosResult {
58        energies,
59        total_dos,
60        pdos: Vec::new(),
61        sigma,
62    }
63}
64
/// Compute total DOS using rayon over the energy grid.
///
/// Parallel counterpart of [`compute_dos`] (requires the `parallel` feature). It
/// takes the same arguments and builds the same grid and per-point Gaussian sum;
/// only the grid points are evaluated concurrently.
///
/// `orbital_energies`: eigenvalues from EHT (eV).
/// `sigma`: Gaussian smearing width (eV); must be positive.
/// `e_min`, `e_max`: energy window.
/// `n_points`: grid resolution; `0` yields empty curves.
#[cfg(feature = "parallel")]
pub fn compute_dos_parallel(
    orbital_energies: &[f64],
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    // `saturating_sub` keeps `n_points == 0` from underflowing (a debug-build panic).
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    // Gaussian prefactor 1/(σ√(2π)) and exponent scale 1/(2σ²), hoisted out of the loops.
    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    let total_dos: Vec<f64> = energies
        .par_iter()
        .map(|&energy| {
            orbital_energies
                .iter()
                .map(|&center| gaussian_value(energy, center, norm, inv_2s2))
                .sum()
        })
        .collect();

    DosResult {
        energies,
        total_dos,
        pdos: Vec::new(),
        sigma,
    }
}
99
100/// Compute atom-projected DOS from EHT results.
101///
102/// `elements`: atomic numbers per atom.
103/// `positions`: flat \[x0,y0,z0, x1,y1,z1,...\] in Å.
104/// `orbital_energies`: eigenvalues from EHT (eV).
105/// `coefficients`: coefficients\[orbital\]\[basis\] from EHT.
106/// `n_electrons`: number of electrons.
107/// `sigma`: Gaussian smearing width (eV).
108/// `e_min`, `e_max`: energy window.
109/// `n_points`: grid resolution.
110#[allow(clippy::too_many_arguments)]
111pub fn compute_pdos(
112    elements: &[u8],
113    positions: &[f64],
114    orbital_energies: &[f64],
115    coefficients: &[Vec<f64>],
116    n_electrons: usize,
117    sigma: f64,
118    e_min: f64,
119    e_max: f64,
120    n_points: usize,
121) -> DosResult {
122    let n_atoms = elements.len();
123    let pos_arr: Vec<[f64; 3]> = positions
124        .chunks_exact(3)
125        .map(|c| [c[0], c[1], c[2]])
126        .collect();
127    let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
128    let overlap = build_overlap_matrix(&basis);
129    let n_basis = basis.len();
130
131    // Build density-like Mulliken weight per (orbital, atom):
132    // w_{k,A} = Σ_{μ∈A} Σ_ν c_{k,μ} S_{μν} c_{k,ν}
133    let n_orb = orbital_energies.len().min(coefficients.len());
134    let mut orbital_atom_weight = vec![vec![0.0f64; n_atoms]; n_orb];
135
136    for k in 0..n_orb {
137        for mu in 0..n_basis {
138            if coefficients.len() <= mu || coefficients[mu].len() <= k {
139                continue;
140            }
141            let atom_mu = basis[mu].atom_index;
142            let mut w = 0.0;
143            for nu in 0..n_basis {
144                if coefficients.len() <= nu || coefficients[nu].len() <= k {
145                    continue;
146                }
147                // coefficients[mu][k] = AO mu, MO k
148                w += coefficients[mu][k] * overlap[(mu, nu)] * coefficients[nu][k];
149            }
150            orbital_atom_weight[k][atom_mu] += w;
151        }
152        // Normalize weights so they sum to 1 per orbital
153        let total_w: f64 = orbital_atom_weight[k].iter().sum();
154        if total_w.abs() > 1e-12 {
155            for a in 0..n_atoms {
156                orbital_atom_weight[k][a] /= total_w;
157            }
158        }
159    }
160
161    let step = (e_max - e_min) / (n_points - 1).max(1) as f64;
162    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();
163
164    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
165    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);
166
167    // Total DOS
168    let total_dos: Vec<f64> = energies
169        .iter()
170        .map(|&e| {
171            (0..n_orb)
172                .map(|k| norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp())
173                .sum()
174        })
175        .collect();
176
177    // Per-atom PDOS
178    let mut pdos = vec![vec![0.0f64; n_points]; n_atoms];
179    for a in 0..n_atoms {
180        for (gi, &e) in energies.iter().enumerate() {
181            let mut val = 0.0;
182            for k in 0..n_orb {
183                let gauss = norm * (-(e - orbital_energies[k]).powi(2) * inv_2s2).exp();
184                val += orbital_atom_weight[k][a] * gauss;
185            }
186            pdos[a][gi] = val;
187        }
188    }
189
190    let _ = n_electrons; // used contextually; weight already normalized via Mulliken
191
192    DosResult {
193        energies,
194        total_dos,
195        pdos,
196        sigma,
197    }
198}
199
/// Compute atom-projected DOS using rayon (requires the `parallel` feature).
///
/// Parallel counterpart of [`compute_pdos`] with the same arguments and result
/// layout. Mulliken orbital weights are computed in parallel over orbitals. The DOS
/// curves are computed in parallel over grid points, where each orbital's Gaussian is
/// evaluated once and shared between the total DOS and all atoms.
///
/// `elements`: atomic numbers per atom.
/// `positions`: flat `[x0,y0,z0, x1,y1,z1,...]` in Å.
/// `orbital_energies`: eigenvalues from EHT (eV).
/// `coefficients`: MO coefficients indexed `coefficients[basis][orbital]`
///   (row = atomic orbital μ, column = molecular orbital k), as read below.
/// `n_electrons`: number of electrons; not used by the calculation.
/// `sigma`: Gaussian smearing width (eV); must be positive.
/// `e_min`, `e_max`: energy window.
/// `n_points`: grid resolution; `0` yields empty curves.
#[cfg(feature = "parallel")]
#[allow(clippy::too_many_arguments)]
pub fn compute_pdos_parallel(
    elements: &[u8],
    positions: &[f64],
    orbital_energies: &[f64],
    coefficients: &[Vec<f64>],
    n_electrons: usize,
    sigma: f64,
    e_min: f64,
    e_max: f64,
    n_points: usize,
) -> DosResult {
    use rayon::prelude::*;

    let n_atoms = elements.len();
    let pos_arr: Vec<[f64; 3]> = positions
        .chunks_exact(3)
        .map(|c| [c[0], c[1], c[2]])
        .collect();
    let basis: Vec<AtomicOrbital> = build_basis(elements, &pos_arr);
    let overlap = build_overlap_matrix(&basis);
    let n_basis = basis.len();
    let n_orb = orbital_energies.len().min(coefficients.len());

    // Whether AO row `mu` holds a coefficient for MO `k` (missing/short rows are skipped).
    let has_coeff = move |mu: usize, k: usize| mu < coefficients.len() && k < coefficients[mu].len();

    // Mulliken weight per (orbital, atom), normalised per orbital.
    let orbital_atom_weight: Vec<Vec<f64>> = (0..n_orb)
        .into_par_iter()
        .map(|k| {
            let mut weights = vec![0.0f64; n_atoms];
            for mu in (0..n_basis).filter(|&mu| has_coeff(mu, k)) {
                let mut weight = 0.0;
                for nu in (0..n_basis).filter(|&nu| has_coeff(nu, k)) {
                    weight += coefficients[mu][k] * overlap[(mu, nu)] * coefficients[nu][k];
                }
                weights[basis[mu].atom_index] += weight;
            }

            let total_weight: f64 = weights.iter().sum();
            if total_weight.abs() > 1e-12 {
                for weight in &mut weights {
                    *weight /= total_weight;
                }
            }
            weights
        })
        .collect();

    // `saturating_sub` keeps `n_points == 0` from underflowing (a debug-build panic).
    let step = (e_max - e_min) / n_points.saturating_sub(1).max(1) as f64;
    let energies: Vec<f64> = (0..n_points).map(|i| e_min + i as f64 * step).collect();

    // Gaussian prefactor 1/(σ√(2π)) and exponent scale 1/(2σ²), hoisted out of the loops.
    let norm = 1.0 / (sigma * (2.0 * std::f64::consts::PI).sqrt());
    let inv_2s2 = 1.0 / (2.0 * sigma * sigma);

    // One task per grid point yields (total DOS, per-atom PDOS) at that energy.
    let (total_dos, per_point_atoms): (Vec<f64>, Vec<Vec<f64>>) = energies
        .par_iter()
        .map(|&energy| {
            let mut total = 0.0f64;
            let mut per_atom = vec![0.0f64; n_atoms];
            for (weights, &center) in orbital_atom_weight.iter().zip(orbital_energies) {
                let gauss = gaussian_value(energy, center, norm, inv_2s2);
                total += gauss;
                for (acc, &w) in per_atom.iter_mut().zip(weights) {
                    *acc += w * gauss;
                }
            }
            (total, per_atom)
        })
        .collect::<Vec<_>>()
        .into_iter()
        .unzip();

    // Transpose [grid][atom] into the [atom][grid] layout of `DosResult::pdos`.
    let pdos: Vec<Vec<f64>> = (0..n_atoms)
        .map(|atom| per_point_atoms.iter().map(|row| row[atom]).collect())
        .collect();

    // `n_electrons` is part of the public signature but does not enter the result.
    let _ = n_electrons;

    DosResult {
        energies,
        total_dos,
        pdos,
        sigma,
    }
}
296
/// Compute mean-squared error between two DOS curves.
///
/// Useful for comparing against a reference DOS (e.g. Multiwfn output) sampled on
/// the same grid. Two empty curves have an MSE of `0.0`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn dos_mse(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "DOS curves must have same length");
    if a.is_empty() {
        // Dividing by the zero length would give NaN; there is no error over no points.
        return 0.0;
    }
    let sum_sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum();
    sum_sq / a.len() as f64
}
310
311/// Serialize DOS/PDOS result to JSON for web visualization.
312///
313/// Format:
314/// ```json
315/// {
316///   "energies": [...],
317///   "total_dos": [...],
318///   "sigma": 0.3,
319///   "pdos": { "0": [...], "1": [...], ... }
320/// }
321/// ```
322pub fn export_dos_json(result: &DosResult) -> String {
323    let mut json = String::from("{");
324    json.push_str("\"energies\":[");
325    for (i, e) in result.energies.iter().enumerate() {
326        if i > 0 {
327            json.push(',');
328        }
329        json.push_str(&format!("{:.6}", e));
330    }
331    json.push_str("],\"total_dos\":[");
332    for (i, d) in result.total_dos.iter().enumerate() {
333        if i > 0 {
334            json.push(',');
335        }
336        json.push_str(&format!("{:.6}", d));
337    }
338    json.push_str(&format!("],\"sigma\":{:.6}", result.sigma));
339    if !result.pdos.is_empty() {
340        json.push_str(",\"pdos\":{");
341        for (a, pdos_a) in result.pdos.iter().enumerate() {
342            if a > 0 {
343                json.push(',');
344            }
345            json.push_str(&format!("\"{}\":[", a));
346            for (i, v) in pdos_a.iter().enumerate() {
347                if i > 0 {
348                    json.push(',');
349                }
350                json.push_str(&format!("{:.6}", v));
351            }
352            json.push(']');
353        }
354        json.push('}');
355    }
356    json.push('}');
357    json
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Solve EHT for the given geometry and project its orbitals onto the atoms.
    fn pdos_from_eht(
        elements: Vec<u8>,
        pos_arr: Vec<[f64; 3]>,
        sigma: f64,
        e_min: f64,
        e_max: f64,
        n_points: usize,
    ) -> DosResult {
        let positions: Vec<f64> = pos_arr.iter().flatten().copied().collect();
        let eht = crate::eht::solve_eht(&elements, &pos_arr, None).unwrap();
        compute_pdos(
            &elements,
            &positions,
            &eht.energies,
            &eht.coefficients,
            eht.n_electrons,
            sigma,
            e_min,
            e_max,
            n_points,
        )
    }

    #[test]
    fn test_dos_single_level() {
        // A lone level at 0 eV on the symmetric window [-1, 1] eV must peak on the
        // middle grid point.
        let res = compute_dos(&[0.0], 0.1, -1.0, 1.0, 201);
        assert_eq!(res.energies.len(), 201);
        assert_eq!(res.total_dos.len(), 201);
        // Keep the last index holding the maximum (same tie rule as `Iterator::max_by`).
        let mut peak_idx = 0;
        let mut peak_val = f64::NEG_INFINITY;
        for (i, &v) in res.total_dos.iter().enumerate() {
            if v >= peak_val {
                peak_idx = i;
                peak_val = v;
            }
        }
        assert_eq!(peak_idx, 100);
    }

    #[test]
    fn test_dos_integral_approx_one() {
        // Riemann sum of one unit-area Gaussian over ±15σ should come out ≈ 1.
        let (lo, hi, n) = (-3.0, 3.0, 1001);
        let res = compute_dos(&[0.0], 0.2, lo, hi, n);
        let de = (hi - lo) / (n - 1) as f64;
        let integral = res.total_dos.iter().sum::<f64>() * de;
        assert!((integral - 1.0).abs() < 0.01, "integral = {integral}");
    }

    #[test]
    fn test_dos_two_peaks() {
        // Levels at ±5 eV on [-10, 10] eV with 501 points fall on indices 125 and 375;
        // index 250 (0 eV) sits deep in both tails.
        let res = compute_dos(&[-5.0, 5.0], 0.3, -10.0, 10.0, 501);
        let (left_peak, mid, right_peak) =
            (res.total_dos[125], res.total_dos[250], res.total_dos[375]);
        assert!(left_peak > mid * 5.0);
        assert!(right_peak > mid * 5.0);
    }

    #[test]
    fn test_pdos_h2() {
        // The two equivalent H atoms of H₂ must carry (nearly) identical PDOS.
        let res = pdos_from_eht(
            vec![1u8, 1],
            vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]],
            0.2,
            -20.0,
            5.0,
            201,
        );
        assert_eq!(res.pdos.len(), 2);
        let (first, second) = (&res.pdos[0], &res.pdos[1]);
        let peak_val = first.iter().copied().fold(0.0f64, f64::max);
        // Only compare where either curve exceeds 10% of the first atom's peak.
        let threshold = peak_val * 0.1;
        for (i, (&p0, &p1)) in first.iter().zip(second).enumerate() {
            let in_peak_region = p0.abs() > threshold || p1.abs() > threshold;
            if !in_peak_region {
                continue;
            }
            let diff = (p0 - p1).abs();
            let avg = (p0.abs() + p1.abs()) / 2.0;
            assert!(
                diff < avg * 0.05 + 1e-6,
                "PDOS mismatch at grid point {i}: {} vs {} (peak={})",
                p0,
                p1,
                peak_val
            );
        }
    }

    #[test]
    fn test_pdos_sums_to_total() {
        // Water: the atom projections must add back up to the total DOS.
        let res = pdos_from_eht(
            vec![8u8, 1, 1],
            vec![[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
            0.3,
            -30.0,
            5.0,
            201,
        );
        for (i, &total) in res.total_dos.iter().enumerate() {
            let pdos_sum: f64 = res.pdos.iter().map(|p| p[i]).sum();
            let diff = (pdos_sum - total).abs();
            assert!(
                diff < total.abs() * 0.05 + 1e-10,
                "PDOS sum {pdos_sum} vs total {} at grid {i}",
                total
            );
        }
    }

    #[test]
    fn test_dos_mse_identical() {
        let curve = [1.0, 2.0, 3.0, 4.0];
        assert!(dos_mse(&curve, &curve) < 1e-15);
    }

    #[test]
    fn test_dos_mse_known() {
        // Squared errors 0.01, 0.01 and 0.04 average to 0.02.
        let err = dos_mse(&[1.0, 2.0, 3.0], &[1.1, 1.9, 3.2]);
        assert!((err - 0.02).abs() < 1e-10);
    }

    #[test]
    fn test_export_dos_json_roundtrip() {
        let res = compute_dos(&[0.0, -5.0], 0.3, -10.0, 5.0, 51);
        let parsed: serde_json::Value =
            serde_json::from_str(&export_dos_json(&res)).expect("valid JSON");
        for key in ["energies", "total_dos"] {
            assert!(parsed[key].is_array());
            assert_eq!(parsed[key].as_array().unwrap().len(), 51);
        }
    }

    #[test]
    fn test_export_pdos_json() {
        let res = pdos_from_eht(
            vec![1u8, 1],
            vec![[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]],
            0.2,
            -20.0,
            5.0,
            51,
        );
        let parsed: serde_json::Value =
            serde_json::from_str(&export_dos_json(&res)).expect("valid JSON");
        let pdos = &parsed["pdos"];
        assert!(pdos.is_object());
        assert!(pdos["0"].is_array());
        assert!(pdos["1"].is_array());
    }
}