1use tracing;
2use crate::forcefield::System;
3use glam::DVec3;
4use web_time::Instant;
5use std::collections::VecDeque;
6
/// Progress callback invoked after every force evaluation with
/// `(iteration, fmax, total_energy)`.
pub type StepHook = dyn Fn(usize, f64, f64) + Sync + Send;

/// Configuration for a FIRE-style geometry relaxation of a [`System`].
///
/// Build one with [`UffOptimizer::new`] and adjust it with the `with_*` builder
/// methods; all fields are public so they can also be set directly.
pub struct UffOptimizer {
    /// Upper bound on the number of optimizer iterations.
    pub max_iterations: usize,
    /// Force convergence threshold (kcal/mol/Å), compared against the
    /// window-averaged Fmax / FRMS values.
    pub force_threshold: f64,
    /// When `true`, a progress table is emitted through `tracing::info!`.
    pub verbose: bool,
    /// Thread count handed to the force evaluation; `0` leaves the choice to
    /// `System::compute_forces_with_threads`. Forced to 1 on wasm32.
    pub num_threads: usize,
    /// Interaction cutoff passed to the force evaluation
    /// (presumably in Å — TODO confirm).
    pub cutoff: f64,
    /// Length of the sliding window of Fmax / FRMS / ΔE samples that the
    /// convergence test averages over; no convergence is declared until the
    /// window has filled.
    pub history_size: usize,
    /// Largest distance a single atom may move in one step (position units, Å).
    pub max_displacement: f64,
    /// Optional per-iteration progress callback.
    pub step_hook: Option<std::sync::Arc<StepHook>>,
    /// Optional cooperative cancellation flag, checked before every iteration.
    pub cancel_flag: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
}
31
32impl UffOptimizer {
33 pub fn new(max_iterations: usize, force_threshold: f64) -> Self {
35 Self {
36 max_iterations,
37 force_threshold,
38 verbose: false,
39 num_threads: 0,
40 cutoff: 6.0,
41 history_size: 10,
42 max_displacement: 0.2,
43 step_hook: None,
44 cancel_flag: None,
45 }
46 }
47
48 pub fn with_max_displacement(mut self, max: f64) -> Self {
49 self.max_displacement = max;
50 self
51 }
52
53 pub fn with_num_threads(mut self, num_threads: usize) -> Self {
54 self.num_threads = num_threads;
55 self
56 }
57
58 pub fn with_cutoff(mut self, cutoff: f64) -> Self {
59 self.cutoff = cutoff;
60 self
61 }
62
63 pub fn with_history_size(mut self, size: usize) -> Self {
64 self.history_size = size;
65 self
66 }
67
68 pub fn with_verbose(mut self, verbose: bool) -> Self {
69 self.verbose = verbose;
70 self
71 }
72
73 pub fn with_step_hook<F>(mut self, hook: F) -> Self
74 where F: Fn(usize, f64, f64) + Send + Sync + 'static {
75 self.step_hook = Some(std::sync::Arc::new(hook));
76 self
77 }
78
79 pub fn with_cancel_flag(mut self, flag: std::sync::Arc<std::sync::atomic::AtomicBool>) -> Self {
80 self.cancel_flag = Some(flag);
81 self
82 }
83
84 pub fn optimize(&self, system: &mut System) {
86 let n = system.atoms.len();
87 if n == 0 { return; }
88
89 if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
91 for atom in &mut system.atoms {
92 atom.position = system.cell.wrap_vector(atom.position);
93 }
94 }
95
96 let mut velocities = vec![DVec3::ZERO; n];
97 let mut dt = 0.01;
98 let dt_max = 0.05;
99 let mut n_pos = 0;
100 let mut alpha = 0.15;
101 let alpha_start = 0.15;
102
103 let mut fmax_history = VecDeque::with_capacity(self.history_size);
104 let mut frms_history = VecDeque::with_capacity(self.history_size);
105 let mut ediff_history = VecDeque::with_capacity(self.history_size);
106 let mut last_energy: Option<f64> = None;
107
108 let start_time = Instant::now();
109
110 if self.verbose {
111 self.print_header(system);
112 }
113
114 let mut final_iter = 0;
115 let mut final_status = "Max-Iter";
116
117 for iter in 0..self.max_iterations {
118 final_iter = iter;
119
120 if let Some(ref cancel) = self.cancel_flag {
122 if cancel.load(std::sync::atomic::Ordering::SeqCst) {
123 final_status = "Cancelled";
124 break;
125 }
126 }
127
128 #[cfg(target_arch = "wasm32")]
129 let energy = system.compute_forces_with_threads(1, self.cutoff);
130 #[cfg(not(target_arch = "wasm32"))]
131 let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
132
133 let (f_max, f_rms) = self.calculate_force_metrics(system);
134
135 if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
137 fmax_history.push_back(f_max);
138 if frms_history.len() >= self.history_size { frms_history.pop_front(); }
139 frms_history.push_back(f_rms);
140 if let Some(prev_e) = last_energy {
141 if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
142 ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
143 }
144 last_energy = Some(energy.total);
145
146 let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
148
149 if self.verbose && (iter % 10 == 0 || converged) {
150 if energy.total.abs() >= 1e10 {
151 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
152 } else {
153 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
154 }
155 }
156
157 if let Some(ref hook) = self.step_hook {
158 hook(iter, f_max, energy.total);
159 }
160
161 if converged {
162 final_status = status;
163 break;
164 }
165
166 self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
167 }
168
169 if self.verbose {
170 self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
171 }
172 }
173
174 pub async fn optimize_async(&self, system: &mut System) {
176 let n = system.atoms.len();
177 if n == 0 { return; }
178
179 if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
180 for atom in &mut system.atoms {
181 atom.position = system.cell.wrap_vector(atom.position);
182 }
183 }
184
185 let mut velocities = vec![DVec3::ZERO; n];
186 let mut dt = 0.01;
187 let dt_max = 0.05;
188 let mut n_pos = 0;
189 let mut alpha = 0.15;
190 let alpha_start = 0.15;
191
192 let mut fmax_history = VecDeque::with_capacity(self.history_size);
193 let mut frms_history = VecDeque::with_capacity(self.history_size);
194 let mut ediff_history = VecDeque::with_capacity(self.history_size);
195 let mut last_energy: Option<f64> = None;
196
197 let start_time = Instant::now();
198
199 if self.verbose {
200 self.print_header(system);
201 }
202
203 let mut final_iter = 0;
204 let mut final_status = "Max-Iter";
205
206 for iter in 0..self.max_iterations {
207 final_iter = iter;
208
209 if let Some(ref cancel) = self.cancel_flag {
211 if cancel.load(std::sync::atomic::Ordering::SeqCst) {
212 final_status = "Cancelled";
213 break;
214 }
215 }
216
217 #[cfg(target_arch = "wasm32")]
219 let energy = system.compute_forces_with_threads(1, self.cutoff);
220 #[cfg(not(target_arch = "wasm32"))]
221 let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
222
223 let (f_max, f_rms) = self.calculate_force_metrics(system);
224
225 if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
226 fmax_history.push_back(f_max);
227 if frms_history.len() >= self.history_size { frms_history.pop_front(); }
228 frms_history.push_back(f_rms);
229 if let Some(prev_e) = last_energy {
230 if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
231 ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
232 }
233 last_energy = Some(energy.total);
234
235 let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
236
237 if self.verbose && (iter % 10 == 0 || converged) {
238 if energy.total.abs() >= 1e10 {
239 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
240 } else {
241 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
242 }
243 }
244
245 if let Some(ref hook) = self.step_hook {
246 hook(iter, f_max, energy.total);
247 }
248
249 if converged {
250 final_status = status;
251 break;
252 }
253
254 self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
255
256 if iter % 5 == 0 {
258 self.yield_now().await;
259 }
260 }
261
262 if self.verbose {
263 self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
264 }
265 }
266
267 async fn yield_now(&self) {
268 #[cfg(feature = "wasm")]
269 {
270 let promise = js_sys::Promise::new(&mut |resolve, _| {
271 if let Some(window) = web_sys::window() {
272 window.set_timeout_with_callback_and_timeout_and_arguments_0(&resolve, 0).unwrap();
273 }
274 });
275 let _ = wasm_bindgen_futures::JsFuture::from(promise).await;
276 }
277 }
278
279 fn calculate_force_metrics(&self, system: &System) -> (f64, f64) {
280 let n = system.atoms.len();
281 let mut max_f_sq: f64 = 0.0;
282 let mut sum_f_sq: f64 = 0.0;
283 for atom in &system.atoms {
284 let f_sq = atom.force.length_squared();
285 max_f_sq = f64::max(max_f_sq, f_sq);
286 sum_f_sq += f_sq;
287 }
288 (max_f_sq.sqrt(), (sum_f_sq / (3.0 * n as f64)).sqrt())
289 }
290
291 fn check_convergence(&self, _f_max: f64, _f_rms: f64, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>, ediff_hist: &VecDeque<f64>) -> (bool, &'static str) {
292 if fmax_hist.len() < self.history_size {
293 return (false, "");
294 }
295 let avg_fmax: f64 = fmax_hist.iter().sum::<f64>() / self.history_size as f64;
296 let avg_frms: f64 = frms_hist.iter().sum::<f64>() / self.history_size as f64;
297 let avg_ediff: f64 = if ediff_hist.is_empty() { 1.0 } else { ediff_hist.iter().sum::<f64>() / ediff_hist.len() as f64 };
298
299 if avg_fmax < self.force_threshold {
300 (true, "Fmax-Conv")
301 } else if avg_fmax < self.force_threshold * 2.0 && avg_frms < self.force_threshold * 0.5 {
302 (true, "FRMS-Conv")
303 } else if !ediff_hist.is_empty() && avg_ediff < 1e-7 {
304 (true, "E-Stalled")
305 } else {
306 (false, "")
307 }
308 }
309
310 fn fire_update(&self, system: &mut System, velocities: &mut [DVec3], dt: &mut f64, dt_max: f64, n_pos: &mut usize, alpha: &mut f64, alpha_start: f64) {
311 let n = system.atoms.len();
312 let mut p = 0.0;
313 for i in 0..n {
314 p += velocities[i].dot(system.atoms[i].force);
315 }
316
317 for i in 0..n {
318 let f_norm = system.atoms[i].force.length();
319 let v_norm = velocities[i].length();
320 if f_norm > 1e-9 {
321 velocities[i] = (1.0 - *alpha) * velocities[i] + *alpha * (system.atoms[i].force / f_norm) * v_norm;
322 }
323 }
324
325 if p > 0.0 {
326 *n_pos += 1;
327 if *n_pos > 5 {
328 *dt = f64::min(*dt * 1.05, dt_max);
329 *alpha *= 0.99;
330 }
331 } else {
332 *n_pos = 0;
333 *dt *= 0.5;
334 *alpha = alpha_start;
335 for v in velocities.iter_mut() {
336 *v = DVec3::ZERO;
337 }
338 }
339
340 for i in 0..n {
341 velocities[i] += system.atoms[i].force * (*dt);
342 let mut move_vec = velocities[i] * (*dt);
343 let move_len = move_vec.length();
344 if move_len > self.max_displacement {
345 move_vec *= self.max_displacement / move_len;
346 velocities[i] = move_vec / (*dt);
347 }
348 let new_pos = system.atoms[i].position + move_vec;
349 system.atoms[i].position = system.cell.wrap_vector(new_pos);
350 }
351 }
352
353 fn print_header(&self, system: &System) {
354 let n_atoms = system.atoms.len();
355 let n_bonds = system.bonds.len();
356 let has_charges = system.atoms.iter().any(|a| a.charge.abs() > 1e-12);
357
358 #[cfg(target_arch = "wasm32")]
360 let actual_threads = 1;
361
362 #[cfg(not(target_arch = "wasm32"))]
363 let actual_threads = if self.num_threads == 1 {
364 1
365 } else if self.num_threads > 1 {
366 self.num_threads
367 } else if n_atoms >= 1000 { std::env::var("RAYON_NUM_THREADS")
369 .ok()
370 .and_then(|s| s.parse().ok())
371 .unwrap_or(4)
372 } else {
373 1
374 };
375
376 let version_str = format!(" uff-relax v{} ", env!("CARGO_PKG_VERSION"));
377 tracing::info!("\n{:=^80}", version_str);
378 tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Atoms:", n_atoms, "Bonds:", n_bonds);
379 tracing::info!("{:<10} {:<10.1} | {:<10} {:<10.4} kcal/mol/Å", "Cutoff:", self.cutoff, "Threshold:", self.force_threshold);
380 tracing::info!("{:<10} {:<10} | {:<10} {:<10}",
381 "Threads:", actual_threads,
382 "Charges:", if has_charges { "Active (Wolf)" } else { "Inactive" }
383 );
384 tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Max Iter:", self.max_iterations, "", "");
385 tracing::info!("{:-<80}", "");
386 tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "", "Fmax", "FRMS", "Total E", "");
387 tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "Iter", "(kcal/mol/Å)", "(kcal/mol/Å)", "(kcal/mol)", "Status");
388 tracing::info!("{:-<80}", "");
389 }
390
391 fn print_footer(&self, system: &mut System, final_status: &str, start_time: Instant, final_iter: usize, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>) {
392 let n = system.atoms.len();
393 let duration = start_time.elapsed();
394 let final_energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
395
396 let mut min_dist = f64::MAX;
397 let mut min_pair = (0, 0);
398 for i in 0..n {
399 for j in i + 1..n {
400 let d = system.cell.distance_vector(system.atoms[i].position, system.atoms[j].position).length();
401 if d < min_dist {
402 min_dist = d;
403 min_pair = (i, j);
404 }
405 }
406 }
407
408 tracing::info!("{:-<80}", "");
409 tracing::info!("=== Optimization Finished ===");
410 tracing::info!("Reason: {:<20}", final_status);
411 tracing::info!("Total Time: {:<10.3?} (Avg: {:.3?} / step)", duration, duration / (final_iter + 1) as u32);
412 if final_energy.total.abs() >= 1e10 {
413 tracing::info!("Final Energy: {:<15.4e} kcal/mol", final_energy.total);
414 } else {
415 tracing::info!("Final Energy: {:<15.4} kcal/mol", final_energy.total);
416 }
417 tracing::info!("Final Fmax: {:<15.4} kcal/mol/Å", fmax_hist.back().unwrap_or(&0.0));
418 tracing::info!("Final FRMS: {:<15.4} kcal/mol/Å", frms_hist.back().unwrap_or(&0.0));
419 tracing::info!("Min Distance: {:<15.4} Å (Atoms {} and {})", min_dist, min_pair.0 + 1, min_pair.1 + 1);
420 tracing::info!("{:>80}", "(c) 2026 Forblaze Project");
421 tracing::info!("{:-<80}\n", "");
422 }
423}