1use tracing;
2use crate::forcefield::System;
3use crate::forcefield::interactions::get_bond_equilibrium_dist;
4use glam::DVec3;
5use web_time::Instant;
6use std::collections::VecDeque;
7
/// Per-iteration progress callback, invoked as `hook(iteration, f_max, total_energy)`
/// after every force evaluation (forces in kcal/mol/Å, energy in kcal/mol).
///
/// The callback must be both `Sync` and `Send` so an optimizer holding it can be shared
/// across threads.
pub type StepHook = dyn Fn(usize, f64, f64) + Sync + Send;
10
11pub struct UffOptimizer {
13 pub max_iterations: usize,
15 pub force_threshold: f64,
17 pub verbose: bool,
19 pub num_threads: usize,
21 pub cutoff: f64,
23 pub history_size: usize,
25 pub max_displacement: f64,
27 pub step_hook: Option<std::sync::Arc<StepHook>>,
29 pub cancel_flag: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
31}
32
33impl UffOptimizer {
34 pub fn new(max_iterations: usize, force_threshold: f64) -> Self {
36 Self {
37 max_iterations,
38 force_threshold,
39 verbose: false,
40 num_threads: 0,
41 cutoff: 6.0,
42 history_size: 10,
43 max_displacement: 0.2,
44 step_hook: None,
45 cancel_flag: None,
46 }
47 }
48
49 pub fn with_max_displacement(mut self, max: f64) -> Self {
50 self.max_displacement = max;
51 self
52 }
53
54 pub fn with_num_threads(mut self, num_threads: usize) -> Self {
55 self.num_threads = num_threads;
56 self
57 }
58
59 pub fn with_cutoff(mut self, cutoff: f64) -> Self {
60 self.cutoff = cutoff;
61 self
62 }
63
64 pub fn with_history_size(mut self, size: usize) -> Self {
65 self.history_size = size;
66 self
67 }
68
69 pub fn with_verbose(mut self, verbose: bool) -> Self {
70 self.verbose = verbose;
71 self
72 }
73
74 pub fn with_step_hook<F>(mut self, hook: F) -> Self
75 where F: Fn(usize, f64, f64) + Send + Sync + 'static {
76 self.step_hook = Some(std::sync::Arc::new(hook));
77 self
78 }
79
80 pub fn with_cancel_flag(mut self, flag: std::sync::Arc<std::sync::atomic::AtomicBool>) -> Self {
81 self.cancel_flag = Some(flag);
82 self
83 }
84
85 pub fn optimize(&self, system: &mut System) {
87 let n = system.atoms.len();
88 if n == 0 { return; }
89
90 if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
92 for atom in &mut system.atoms {
93 atom.position = system.cell.wrap_vector(atom.position);
94 }
95 }
96
97 let mut velocities = vec![DVec3::ZERO; n];
98 let mut dt = 0.01;
99 let dt_max = 0.05;
100 let mut n_pos = 0;
101 let mut alpha = 0.15;
102 let alpha_start = 0.15;
103
104 let mut fmax_history = VecDeque::with_capacity(self.history_size);
105 let mut frms_history = VecDeque::with_capacity(self.history_size);
106 let mut ediff_history = VecDeque::with_capacity(self.history_size);
107 let mut last_energy: Option<f64> = None;
108
109 let start_time = Instant::now();
110
111 if self.verbose {
112 self.print_header(system);
113 }
114
115 let mut final_iter = 0;
116 let mut final_status = "Max-Iter";
117
118 for iter in 0..self.max_iterations {
119 final_iter = iter;
120
121 if let Some(ref cancel) = self.cancel_flag {
123 if cancel.load(std::sync::atomic::Ordering::SeqCst) {
124 final_status = "Cancelled";
125 break;
126 }
127 }
128
129 #[cfg(target_arch = "wasm32")]
130 let energy = system.compute_forces_with_threads(1, self.cutoff);
131 #[cfg(not(target_arch = "wasm32"))]
132 let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
133
134 let (f_max, f_rms) = self.calculate_force_metrics(system);
135
136 if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
138 fmax_history.push_back(f_max);
139 if frms_history.len() >= self.history_size { frms_history.pop_front(); }
140 frms_history.push_back(f_rms);
141 if let Some(prev_e) = last_energy {
142 if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
143 ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
144 }
145 last_energy = Some(energy.total);
146
147 let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
149
150 if self.verbose && (iter % 10 == 0 || converged) {
151 if energy.total.abs() >= 1e10 {
152 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
153 } else {
154 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
155 }
156 }
157
158 if let Some(ref hook) = self.step_hook {
159 hook(iter, f_max, energy.total);
160 }
161
162 if converged {
163 final_status = status;
164 break;
165 }
166
167 self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
168 }
169
170 if self.verbose {
171 self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
172 }
173 }
174
175 pub async fn optimize_async(&self, system: &mut System) {
177 let n = system.atoms.len();
178 if n == 0 { return; }
179
180 if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
181 for atom in &mut system.atoms {
182 atom.position = system.cell.wrap_vector(atom.position);
183 }
184 }
185
186 let mut velocities = vec![DVec3::ZERO; n];
187 let mut dt = 0.01;
188 let dt_max = 0.05;
189 let mut n_pos = 0;
190 let mut alpha = 0.15;
191 let alpha_start = 0.15;
192
193 let mut fmax_history = VecDeque::with_capacity(self.history_size);
194 let mut frms_history = VecDeque::with_capacity(self.history_size);
195 let mut ediff_history = VecDeque::with_capacity(self.history_size);
196 let mut last_energy: Option<f64> = None;
197
198 let start_time = Instant::now();
199
200 if self.verbose {
201 self.print_header(system);
202 }
203
204 let mut final_iter = 0;
205 let mut final_status = "Max-Iter";
206
207 for iter in 0..self.max_iterations {
208 final_iter = iter;
209
210 if let Some(ref cancel) = self.cancel_flag {
212 if cancel.load(std::sync::atomic::Ordering::SeqCst) {
213 final_status = "Cancelled";
214 break;
215 }
216 }
217
218 #[cfg(target_arch = "wasm32")]
220 let energy = system.compute_forces_with_threads(1, self.cutoff);
221 #[cfg(not(target_arch = "wasm32"))]
222 let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
223
224 let (f_max, f_rms) = self.calculate_force_metrics(system);
225
226 if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
227 fmax_history.push_back(f_max);
228 if frms_history.len() >= self.history_size { frms_history.pop_front(); }
229 frms_history.push_back(f_rms);
230 if let Some(prev_e) = last_energy {
231 if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
232 ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
233 }
234 last_energy = Some(energy.total);
235
236 let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
237
238 if self.verbose && (iter % 10 == 0 || converged) {
239 if energy.total.abs() >= 1e10 {
240 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
241 } else {
242 tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
243 }
244 }
245
246 if let Some(ref hook) = self.step_hook {
247 hook(iter, f_max, energy.total);
248 }
249
250 if converged {
251 final_status = status;
252 break;
253 }
254
255 self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
256
257 if iter % 5 == 0 {
259 self.yield_now().await;
260 }
261 }
262
263 if self.verbose {
264 self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
265 }
266 }
267
268 async fn yield_now(&self) {
269 #[cfg(feature = "wasm")]
270 {
271 let promise = js_sys::Promise::new(&mut |resolve, _| {
272 if let Some(window) = web_sys::window() {
273 window.set_timeout_with_callback_and_timeout_and_arguments_0(&resolve, 0).unwrap();
274 }
275 });
276 let _ = wasm_bindgen_futures::JsFuture::from(promise).await;
277 }
278 }
279
280 fn calculate_force_metrics(&self, system: &System) -> (f64, f64) {
281 let n = system.atoms.len();
282 let mut max_f_sq: f64 = 0.0;
283 let mut sum_f_sq: f64 = 0.0;
284 for atom in &system.atoms {
285 let f_sq = atom.force.length_squared();
286 max_f_sq = f64::max(max_f_sq, f_sq);
287 sum_f_sq += f_sq;
288 }
289 (max_f_sq.sqrt(), (sum_f_sq / (3.0 * n as f64)).sqrt())
290 }
291
292 fn check_convergence(&self, _f_max: f64, _f_rms: f64, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>, ediff_hist: &VecDeque<f64>) -> (bool, &'static str) {
293 if fmax_hist.len() < self.history_size {
294 return (false, "");
295 }
296 let avg_fmax: f64 = fmax_hist.iter().sum::<f64>() / self.history_size as f64;
297 let avg_frms: f64 = frms_hist.iter().sum::<f64>() / self.history_size as f64;
298 let avg_ediff: f64 = if ediff_hist.is_empty() { 1.0 } else { ediff_hist.iter().sum::<f64>() / ediff_hist.len() as f64 };
299
300 if avg_fmax < self.force_threshold {
301 (true, "Fmax-Conv")
302 } else if avg_fmax < self.force_threshold * 2.0 && avg_frms < self.force_threshold * 0.5 {
303 (true, "FRMS-Conv")
304 } else if !ediff_hist.is_empty() && avg_ediff < 1e-7 {
305 (true, "E-Stalled")
306 } else {
307 (false, "")
308 }
309 }
310
311 fn fire_update(&self, system: &mut System, velocities: &mut [DVec3], dt: &mut f64, dt_max: f64, n_pos: &mut usize, alpha: &mut f64, alpha_start: f64) {
312 let n = system.atoms.len();
313 let mut p = 0.0;
314 for i in 0..n {
315 p += velocities[i].dot(system.atoms[i].force);
316 }
317
318 for i in 0..n {
319 let f_norm = system.atoms[i].force.length();
320 let v_norm = velocities[i].length();
321 if f_norm > 1e-9 {
322 velocities[i] = (1.0 - *alpha) * velocities[i] + *alpha * (system.atoms[i].force / f_norm) * v_norm;
323 }
324 }
325
326 if p > 0.0 {
327 *n_pos += 1;
328 if *n_pos > 5 {
329 *dt = f64::min(*dt * 1.05, dt_max);
330 *alpha *= 0.99;
331 }
332 } else {
333 *n_pos = 0;
334 *dt *= 0.5;
335 *alpha = alpha_start;
336 for v in velocities.iter_mut() {
337 *v = DVec3::ZERO;
338 }
339 }
340
341 for i in 0..n {
342 velocities[i] += system.atoms[i].force * (*dt);
343 let mut move_vec = velocities[i] * (*dt);
344 let move_len = move_vec.length();
345 if move_len > self.max_displacement {
346 move_vec *= self.max_displacement / move_len;
347 velocities[i] = move_vec / (*dt);
348 }
349 let new_pos = system.atoms[i].position + move_vec;
350 system.atoms[i].position = system.cell.wrap_vector(new_pos);
351 }
352 }
353
354 fn print_header(&self, system: &System) {
355 let n_atoms = system.atoms.len();
356 let n_bonds = system.bonds.len();
357 let has_charges = system.atoms.iter().any(|a| a.charge.abs() > 1e-12);
358
359 #[cfg(target_arch = "wasm32")]
361 let actual_threads = 1;
362
363 #[cfg(not(target_arch = "wasm32"))]
364 let actual_threads = if self.num_threads == 1 {
365 1
366 } else if self.num_threads > 1 {
367 self.num_threads
368 } else if n_atoms >= 1000 { std::env::var("RAYON_NUM_THREADS")
370 .ok()
371 .and_then(|s| s.parse().ok())
372 .unwrap_or(4)
373 } else {
374 1
375 };
376
377 let version_str = format!(" uff-relax v{} ", env!("CARGO_PKG_VERSION"));
378 tracing::info!("\n{:=^80}", version_str);
379 tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Atoms:", n_atoms, "Bonds:", n_bonds);
380 tracing::info!("{:<10} {:<10.1} | {:<10} {:<10.4} kcal/mol/Å", "Cutoff:", self.cutoff, "Threshold:", self.force_threshold);
381 tracing::info!("{:<10} {:<10} | {:<10} {:<10}",
382 "Threads:", actual_threads,
383 "Charges:", if has_charges { "Active (Wolf)" } else { "Inactive" }
384 );
385 tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Max Iter:", self.max_iterations, "", "");
386 tracing::info!("{:-<80}", "");
387 tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "", "Fmax", "FRMS", "Total E", "");
388 tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "Iter", "(kcal/mol/Å)", "(kcal/mol/Å)", "(kcal/mol)", "Status");
389 tracing::info!("{:-<80}", "");
390 }
391
392 fn print_footer(&self, system: &mut System, final_status: &str, start_time: Instant, final_iter: usize, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>) {
393 let n = system.atoms.len();
394 let duration = start_time.elapsed();
395 let final_energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
396
397 let mut min_dist = f64::MAX;
398 let mut min_pair = (0, 0);
399 for i in 0..n {
400 for j in i + 1..n {
401 let d = system.cell.distance_vector(system.atoms[i].position, system.atoms[j].position).length();
402 if d < min_dist {
403 min_dist = d;
404 min_pair = (i, j);
405 }
406 }
407 }
408
409 let mut abnormal_bonds = Vec::new();
411 for bond in &system.bonds {
412 let (i, j) = bond.atom_indices;
413 let current_dist = system.cell.distance_vector(system.atoms[i].position, system.atoms[j].position).length();
414 if let Some(r0) = get_bond_equilibrium_dist(&system.atoms[i].uff_type, &system.atoms[j].uff_type, bond.order) {
415 if current_dist > r0 * 1.3 {
416 abnormal_bonds.push((i, j, current_dist, r0));
417 }
418 }
419 }
420
421 tracing::info!("{:-<80}", "");
422 tracing::info!("=== Optimization Finished ===");
423 tracing::info!("Reason: {:<20}", final_status);
424 tracing::info!("Total Time: {:<10.3?} (Avg: {:.3?} / step)", duration, duration / (final_iter + 1) as u32);
425 if final_energy.total.abs() >= 1e10 {
426 tracing::info!("Final Energy: {:<15.4e} kcal/mol", final_energy.total);
427 } else {
428 tracing::info!("Final Energy: {:<15.4} kcal/mol", final_energy.total);
429 }
430 tracing::info!("Final Fmax: {:<15.4} kcal/mol/Å", fmax_hist.back().unwrap_or(&0.0));
431 tracing::info!("Final FRMS: {:<15.4} kcal/mol/Å", frms_hist.back().unwrap_or(&0.0));
432 tracing::info!("Min Distance: {:<15.4} Å (Atoms {} and {})", min_dist, min_pair.0 + 1, min_pair.1 + 1);
433
434 if !abnormal_bonds.is_empty() {
435 tracing::warn!("!!! ABNORMAL BONDS DETECTED ({} total) !!!", abnormal_bonds.len());
436 for (i, j, dist, r0) in abnormal_bonds.iter().take(3) {
437 tracing::warn!(" Bond {}-{} : Length {:.4} Å (Equiv: {:.4} Å, Dev: {:.1}%)",
438 i+1, j+1, dist, r0, (dist/r0 - 1.0)*100.0);
439 }
440 if abnormal_bonds.len() > 3 {
441 tracing::warn!(" ... and {} more abnormal bonds.", abnormal_bonds.len() - 3);
442 }
443 }
444
445 tracing::info!("{:>80}", "(c) 2026 Forblaze Project");
446 tracing::info!("{:-<80}\n", "");
447 }
448}