sci_form/eht/
gradients.rs1use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct EhtGradient {
15 pub gradients: Vec<[f64; 3]>,
17 pub rms_gradient: f64,
19 pub max_gradient: f64,
21 pub energy: f64,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct EhtOptResult {
28 pub positions: Vec<[f64; 3]>,
30 pub energy: f64,
32 pub n_steps: usize,
34 pub converged: bool,
36 pub rms_gradient: f64,
38 pub energies: Vec<f64>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct EhtOptConfig {
45 pub max_steps: usize,
47 pub grad_threshold: f64,
49 pub energy_threshold: f64,
51 pub step_size: f64,
53}
54
55impl Default for EhtOptConfig {
56 fn default() -> Self {
57 Self {
58 max_steps: 200,
59 grad_threshold: 0.01,
60 energy_threshold: 1e-6,
61 step_size: 0.05,
62 }
63 }
64}
65
66pub fn compute_eht_gradient(
74 elements: &[u8],
75 positions: &[[f64; 3]],
76) -> Result<EhtGradient, String> {
77 let delta = 1e-5; let n_atoms = elements.len();
79
80 let eht_ref = crate::eht::solve_eht(elements, positions, None)?;
82 let n_occ = eht_ref.n_electrons.div_ceil(2); let is_odd = eht_ref.n_electrons % 2 == 1;
84 let e0: f64 = eht_ref
85 .energies
86 .iter()
87 .take(n_occ)
88 .enumerate()
89 .map(|(i, &e)| {
90 if is_odd && i == n_occ - 1 {
91 e } else {
93 2.0 * e
94 }
95 })
96 .sum();
97
98 let mut gradients = vec![[0.0f64; 3]; n_atoms];
99
100 for atom in 0..n_atoms {
101 for coord in 0..3 {
102 let mut pos_plus = positions.to_vec();
103 let mut pos_minus = positions.to_vec();
104 pos_plus[atom][coord] += delta;
105 pos_minus[atom][coord] -= delta;
106
107 let eht_plus = crate::eht::solve_eht(elements, &pos_plus, None)?;
108 let eht_minus = crate::eht::solve_eht(elements, &pos_minus, None)?;
109
110 let e_plus: f64 = eht_plus
111 .energies
112 .iter()
113 .take(n_occ)
114 .enumerate()
115 .map(|(i, &e)| if is_odd && i == n_occ - 1 { e } else { 2.0 * e })
116 .sum();
117 let e_minus: f64 = eht_minus
118 .energies
119 .iter()
120 .take(n_occ)
121 .enumerate()
122 .map(|(i, &e)| if is_odd && i == n_occ - 1 { e } else { 2.0 * e })
123 .sum();
124
125 gradients[atom][coord] = (e_plus - e_minus) / (2.0 * delta);
126 }
127 }
128
129 let rms = (gradients
130 .iter()
131 .flat_map(|g| g.iter())
132 .map(|x| x * x)
133 .sum::<f64>()
134 / (3 * n_atoms) as f64)
135 .sqrt();
136 let max = gradients
137 .iter()
138 .flat_map(|g| g.iter())
139 .map(|x| x.abs())
140 .fold(0.0f64, |a, b| a.max(b));
141
142 Ok(EhtGradient {
143 gradients,
144 rms_gradient: rms,
145 max_gradient: max,
146 energy: e0,
147 })
148}
149
150pub fn optimize_geometry_eht(
152 elements: &[u8],
153 initial_positions: &[[f64; 3]],
154 config: Option<EhtOptConfig>,
155) -> Result<EhtOptResult, String> {
156 let cfg = config.unwrap_or_default();
157 let n_atoms = elements.len();
158 let mut positions = initial_positions.to_vec();
159 let mut energies = Vec::new();
160 let mut converged = false;
161 let mut step_size = cfg.step_size;
162
163 let mut prev_energy = f64::MAX;
164 let mut last_rms = f64::MAX;
165
166 for step in 0..cfg.max_steps {
167 let grad = compute_eht_gradient(elements, &positions)?;
168 energies.push(grad.energy);
169 last_rms = grad.rms_gradient;
170
171 if grad.rms_gradient < cfg.grad_threshold {
173 converged = true;
174 return Ok(EhtOptResult {
175 positions,
176 energy: grad.energy,
177 n_steps: step + 1,
178 converged,
179 rms_gradient: grad.rms_gradient,
180 energies,
181 });
182 }
183
184 if step > 0 {
185 let de = (grad.energy - prev_energy).abs();
186 if de < cfg.energy_threshold {
187 converged = true;
188 return Ok(EhtOptResult {
189 positions,
190 energy: grad.energy,
191 n_steps: step + 1,
192 converged,
193 rms_gradient: grad.rms_gradient,
194 energies,
195 });
196 }
197
198 if grad.energy > prev_energy {
200 step_size *= 0.5;
201 } else {
202 step_size *= 1.1;
203 step_size = step_size.min(0.2);
204 }
205 }
206
207 prev_energy = grad.energy;
208
209 for atom in 0..n_atoms {
211 for coord in 0..3 {
212 positions[atom][coord] -= step_size * grad.gradients[atom][coord];
213 }
214 }
215 }
216
217 Ok(EhtOptResult {
218 positions,
219 energy: prev_energy,
220 n_steps: cfg.max_steps,
221 converged,
222 rms_gradient: last_rms,
223 energies,
224 })
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test on a two-hydrogen system separated by 0.74 along x: the gradient
    /// must be computed successfully, contain one entry per atom, and have a
    /// finite RMS value.
    #[test]
    fn test_eht_gradient_h2() {
        let atomic_numbers = [1u8, 1];
        let coords = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];

        let outcome = compute_eht_gradient(&atomic_numbers, &coords);
        assert!(outcome.is_ok());

        let EhtGradient {
            gradients,
            rms_gradient,
            ..
        } = outcome.unwrap();
        assert_eq!(gradients.len(), 2);
        assert!(rms_gradient.is_finite());
    }
}