1use serde::{Deserialize, Serialize};
11
12use crate::eht;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16pub enum MeshMethod {
17 Eht,
19 Pm3,
21 Xtb,
23 Hf3c,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct OrbitalMeshResult {
30 pub mesh: eht::IsosurfaceMesh,
32 pub grid: eht::VolumetricGrid,
34 pub method: MeshMethod,
36 pub mo_index: usize,
38 pub homo_index: usize,
40 pub orbital_energies: Vec<f64>,
42 pub gap: f64,
44}
45
46fn compute_eht_mesh(
50 elements: &[u8],
51 positions: &[[f64; 3]],
52 mo_index: usize,
53 spacing: f64,
54 padding: f64,
55 isovalue: f32,
56) -> Result<OrbitalMeshResult, String> {
57 let result = eht::solve_eht(elements, positions, None)?;
58 let basis = eht::basis::build_basis(elements, positions);
59 let grid = eht::evaluate_orbital_on_grid(
60 &basis,
61 &result.coefficients,
62 mo_index,
63 positions,
64 spacing,
65 padding,
66 );
67 let mesh = eht::marching_cubes(&grid, isovalue);
68
69 Ok(OrbitalMeshResult {
70 mesh,
71 grid,
72 method: MeshMethod::Eht,
73 mo_index,
74 homo_index: result.homo_index,
75 orbital_energies: result.energies.clone(),
76 gap: result.gap,
77 })
78}
79
80fn compute_pm3_mesh(
83 elements: &[u8],
84 positions: &[[f64; 3]],
85 mo_index: usize,
86 spacing: f64,
87 padding: f64,
88 isovalue: f32,
89) -> Result<OrbitalMeshResult, String> {
90 let pm3 = crate::pm3::solve_pm3(elements, positions)?;
91 let eht_result = eht::solve_eht(elements, positions, None)?;
92 let basis = eht::basis::build_basis(elements, positions);
93 let grid = eht::evaluate_orbital_on_grid(
94 &basis,
95 &eht_result.coefficients,
96 mo_index,
97 positions,
98 spacing,
99 padding,
100 );
101 let mesh = eht::marching_cubes(&grid, isovalue);
102
103 let homo_idx = if pm3.n_electrons > 0 {
104 pm3.n_electrons / 2 - 1
105 } else {
106 0
107 };
108
109 Ok(OrbitalMeshResult {
110 mesh,
111 grid,
112 method: MeshMethod::Pm3,
113 mo_index,
114 homo_index: homo_idx,
115 orbital_energies: pm3.orbital_energies,
116 gap: pm3.gap,
117 })
118}
119
120fn compute_xtb_mesh(
123 elements: &[u8],
124 positions: &[[f64; 3]],
125 mo_index: usize,
126 spacing: f64,
127 padding: f64,
128 isovalue: f32,
129) -> Result<OrbitalMeshResult, String> {
130 let xtb = crate::xtb::solve_xtb(elements, positions)?;
131 let eht_result = eht::solve_eht(elements, positions, None)?;
132 let basis = eht::basis::build_basis(elements, positions);
133 let grid = eht::evaluate_orbital_on_grid(
134 &basis,
135 &eht_result.coefficients,
136 mo_index,
137 positions,
138 spacing,
139 padding,
140 );
141 let mesh = eht::marching_cubes(&grid, isovalue);
142
143 let homo_idx = if xtb.n_electrons > 0 {
144 xtb.n_electrons / 2 - 1
145 } else {
146 0
147 };
148
149 Ok(OrbitalMeshResult {
150 mesh,
151 grid,
152 method: MeshMethod::Xtb,
153 mo_index,
154 homo_index: homo_idx,
155 orbital_energies: xtb.orbital_energies,
156 gap: xtb.gap,
157 })
158}
159
160fn compute_hf3c_mesh(
163 elements: &[u8],
164 positions: &[[f64; 3]],
165 mo_index: usize,
166 spacing: f64,
167 padding: f64,
168 isovalue: f32,
169) -> Result<OrbitalMeshResult, String> {
170 let config = crate::hf::HfConfig::default();
171 let hf = crate::hf::api::solve_hf3c(elements, positions, &config)?;
172 let eht_result = eht::solve_eht(elements, positions, None)?;
173 let basis = eht::basis::build_basis(elements, positions);
174 let grid = eht::evaluate_orbital_on_grid(
175 &basis,
176 &eht_result.coefficients,
177 mo_index,
178 positions,
179 spacing,
180 padding,
181 );
182 let mesh = eht::marching_cubes(&grid, isovalue);
183
184 let n_electrons: usize = elements.iter().map(|&z| z as usize).sum();
186 let homo_idx = if n_electrons > 0 {
187 n_electrons / 2 - 1
188 } else {
189 0
190 };
191
192 let gap = if hf.orbital_energies.len() > homo_idx + 1 {
193 hf.orbital_energies[homo_idx + 1] - hf.orbital_energies[homo_idx]
194 } else {
195 0.0
196 };
197
198 Ok(OrbitalMeshResult {
199 mesh,
200 grid,
201 method: MeshMethod::Hf3c,
202 mo_index,
203 homo_index: homo_idx,
204 orbital_energies: hf.orbital_energies,
205 gap,
206 })
207}
208
209pub fn compute_orbital_mesh(
219 elements: &[u8],
220 positions: &[[f64; 3]],
221 method: MeshMethod,
222 mo_index: usize,
223 spacing: f64,
224 padding: f64,
225 isovalue: f32,
226) -> Result<OrbitalMeshResult, String> {
227 match method {
228 MeshMethod::Eht => {
229 compute_eht_mesh(elements, positions, mo_index, spacing, padding, isovalue)
230 }
231 MeshMethod::Pm3 => {
232 compute_pm3_mesh(elements, positions, mo_index, spacing, padding, isovalue)
233 }
234 MeshMethod::Xtb => {
235 compute_xtb_mesh(elements, positions, mo_index, spacing, padding, isovalue)
236 }
237 MeshMethod::Hf3c => {
238 compute_hf3c_mesh(elements, positions, mo_index, spacing, padding, isovalue)
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Water fixture shared by every test: element list plus one position per atom.
    fn water_positions() -> (Vec<u8>, Vec<[f64; 3]>) {
        let elements = vec![8, 1, 1];
        let positions = vec![[0.0, 0.0, 0.0], [0.0, 0.757, 0.587], [0.0, -0.757, 0.587]];
        (elements, positions)
    }

    /// Runs `compute_orbital_mesh` on the water fixture for orbital 0 and checks
    /// that it succeeds, yields a non-empty grid and reports `method` back.
    fn check_water_mesh(method: MeshMethod) {
        let (elements, positions) = water_positions();
        let result = compute_orbital_mesh(&elements, &positions, method, 0, 0.4, 3.0, 0.02);
        assert!(result.is_ok());
        let r = result.unwrap();
        assert!(r.grid.num_points() > 0);
        assert_eq!(r.method, method);
    }

    #[test]
    fn test_eht_mesh_water() {
        check_water_mesh(MeshMethod::Eht);
    }

    #[test]
    fn test_pm3_mesh_water() {
        check_water_mesh(MeshMethod::Pm3);
    }

    #[test]
    fn test_xtb_mesh_water() {
        check_water_mesh(MeshMethod::Xtb);
    }

    #[test]
    fn test_hf3c_mesh_water() {
        check_water_mesh(MeshMethod::Hf3c);
    }
}