1use super::dos::compute_dos;
7use serde::{Deserialize, Serialize};
8
/// Selects which electronic-structure solver supplies the orbital energies
/// for a density-of-states calculation.
///
/// Each variant names the solver that `compute_dos_multimethod` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DosMethod {
    /// Orbital energies come from `crate::eht::solve_eht`.
    Eht,
    /// Orbital energies come from `crate::compute_pm3`.
    Pm3,
    /// Orbital energies come from `crate::xtb::solve_xtb`.
    Xtb,
    /// Orbital energies come from `crate::xtb::gfn1::solve_gfn1`.
    Gfn1,
    /// Orbital energies come from `crate::xtb::gfn2::solve_gfn2`.
    Gfn2,
    /// Orbital energies come from `crate::hf::solve_hf3c`, run with the
    /// default `crate::hf::HfConfig`.
    Hf3c,
}
25
/// Output of `compute_dos_multimethod`: the DOS curve together with the
/// frontier-orbital data of the solver that produced it.
///
/// NOTE(review): the HF-3c path scales its orbital energies by 27.2114, which
/// suggests all energies here are meant to be in eV — confirm for the other
/// solvers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiMethodDosResult {
    /// Energy grid of the DOS, exactly as returned by `compute_dos`.
    pub energies: Vec<f64>,
    /// Total DOS values, exactly as returned by `compute_dos`
    /// (presumably one value per entry of `energies` — confirm against `compute_dos`).
    pub total_dos: Vec<f64>,
    /// Method that was requested and used to obtain the orbital energies.
    pub method: DosMethod,
    /// HOMO energy reported by the selected solver (derived from the orbital
    /// energies for `DosMethod::Hf3c`).
    pub homo_energy: f64,
    /// LUMO energy reported by the selected solver (derived from the orbital
    /// energies for `DosMethod::Hf3c`).
    pub lumo_energy: f64,
    /// HOMO-LUMO gap reported by the selected solver; for `DosMethod::Hf3c`
    /// it is computed as `lumo_energy - homo_energy`.
    pub gap: f64,
    /// Orbital energies that were passed to `compute_dos`.
    pub orbital_energies: Vec<f64>,
    /// The `sigma` argument that was passed to `compute_dos`, echoed back unchanged.
    pub sigma: f64,
}
46
47pub fn compute_dos_multimethod(
52 elements: &[u8],
53 positions: &[[f64; 3]],
54 method: DosMethod,
55 sigma: f64,
56 e_min: f64,
57 e_max: f64,
58 n_points: usize,
59) -> Result<MultiMethodDosResult, String> {
60 let (orbital_energies, homo_energy, lumo_energy, gap) = match method {
62 DosMethod::Eht => {
63 let eht_result = crate::eht::solve_eht(elements, positions, None)?;
64 let homo = eht_result.homo_energy;
65 let lumo = eht_result.lumo_energy;
66 (eht_result.energies, homo, lumo, eht_result.gap)
67 }
68 DosMethod::Pm3 => {
69 let pm3 = crate::compute_pm3(elements, positions)?;
70 let homo = pm3.homo_energy;
71 let lumo = pm3.lumo_energy;
72 (pm3.orbital_energies, homo, lumo, pm3.gap)
73 }
74 DosMethod::Xtb => {
75 let xtb = crate::xtb::solve_xtb(elements, positions)?;
76 let homo = xtb.homo_energy;
77 let lumo = xtb.lumo_energy;
78 (xtb.orbital_energies, homo, lumo, xtb.gap)
79 }
80 DosMethod::Gfn1 => {
81 let gfn1 = crate::xtb::gfn1::solve_gfn1(elements, positions)?;
82 let homo = gfn1.homo_energy;
83 let lumo = gfn1.lumo_energy;
84 (gfn1.orbital_energies, homo, lumo, gfn1.gap)
85 }
86 DosMethod::Gfn2 => {
87 let gfn2 = crate::xtb::gfn2::solve_gfn2(elements, positions)?;
88 let homo = gfn2.homo_energy;
89 let lumo = gfn2.lumo_energy;
90 (gfn2.orbital_energies, homo, lumo, gfn2.gap)
91 }
92 DosMethod::Hf3c => {
93 let config = crate::hf::HfConfig::default();
94 let hf = crate::hf::solve_hf3c(elements, positions, &config)?;
95 let total_e: usize = elements.iter().map(|&z| z as usize).sum();
97 let n_occ = total_e / 2;
98 let homo = if n_occ > 0 && n_occ <= hf.orbital_energies.len() {
99 hf.orbital_energies[n_occ - 1] * 27.2114 } else {
101 0.0
102 };
103 let lumo = if n_occ < hf.orbital_energies.len() {
104 hf.orbital_energies[n_occ] * 27.2114
105 } else {
106 0.0
107 };
108 let gap = lumo - homo;
109 let oe_ev: Vec<f64> = hf.orbital_energies.iter().map(|e| e * 27.2114).collect();
111 (oe_ev, homo, lumo, gap)
112 }
113 };
114
115 let dos = compute_dos(&orbital_energies, sigma, e_min, e_max, n_points);
117
118 Ok(MultiMethodDosResult {
119 energies: dos.energies,
120 total_dos: dos.total_dos,
121 method,
122 homo_energy,
123 lumo_energy,
124 gap,
125 orbital_energies,
126 sigma,
127 })
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the EHT path on a three-atom O/H/H geometry and checks that the
    /// DOS grid has the requested length and the gap is positive.
    #[test]
    fn test_dos_method_eht() {
        // Atomic numbers: one oxygen followed by two hydrogens.
        let atomic_numbers = [8u8, 1, 1];
        let coords = [
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];

        let outcome = compute_dos_multimethod(
            &atomic_numbers,
            &coords,
            DosMethod::Eht,
            0.3,
            -30.0,
            5.0,
            100,
        );
        assert!(outcome.is_ok());

        let dos = outcome.unwrap();
        assert_eq!(dos.energies.len(), 100);
        assert!(dos.gap > 0.0);
    }
}