Skip to main content

uff_relax/
optimizer.rs

1use tracing;
2use crate::forcefield::System;
3use glam::DVec3;
4use web_time::Instant;
5use std::collections::VecDeque;
6
/// Signature of the per-iteration progress callback.
///
/// Arguments are `(iteration, f_max, energy)`: the zero-based iteration index, the
/// largest per-atom force magnitude (kcal/mol/Å) and the total energy (kcal/mol) of
/// that iteration. The callback must be shareable across threads.
pub type StepHook = dyn Fn(usize, f64, f64) + Sync + Send;
9
10/// Optimizer for molecular structures using the FIRE (Fast Iterative Relaxation Engine) algorithm.
11pub struct UffOptimizer {
12    /// Maximum number of iterations to perform.
13    pub max_iterations: usize,
14    /// Threshold for the maximum force on any atom (kcal/mol/Å).
15    pub force_threshold: f64,
16    /// Whether to print optimization progress to tracing logs.
17    pub verbose: bool,
18    /// Number of threads to use. 0 means automatic based on system size.
19    pub num_threads: usize,
20    /// Cutoff distance for non-bonded interactions (Å).
21    pub cutoff: f64,
22    /// Number of steps to average for convergence criteria.
23    pub history_size: usize,
24    /// Maximum distance an atom can move in a single step (Å).
25    pub max_displacement: f64,
26    /// Optional hook called after each iteration.
27    pub step_hook: Option<std::sync::Arc<StepHook>>,
28    /// Optional flag to cancel optimization from another thread/context.
29    pub cancel_flag: Option<std::sync::Arc<std::sync::atomic::AtomicBool>>,
30}
31
32impl UffOptimizer {
33    /// Creates a new optimizer with default settings.
34    pub fn new(max_iterations: usize, force_threshold: f64) -> Self {
35        Self {
36            max_iterations,
37            force_threshold,
38            verbose: false,
39            num_threads: 0,
40            cutoff: 6.0,
41            history_size: 10,
42            max_displacement: 0.2,
43            step_hook: None,
44            cancel_flag: None,
45        }
46    }
47
48    pub fn with_max_displacement(mut self, max: f64) -> Self {
49        self.max_displacement = max;
50        self
51    }
52
53    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
54        self.num_threads = num_threads;
55        self
56    }
57
58    pub fn with_cutoff(mut self, cutoff: f64) -> Self {
59        self.cutoff = cutoff;
60        self
61    }
62
63    pub fn with_history_size(mut self, size: usize) -> Self {
64        self.history_size = size;
65        self
66    }
67
68    pub fn with_verbose(mut self, verbose: bool) -> Self {
69        self.verbose = verbose;
70        self
71    }
72
73    pub fn with_step_hook<F>(mut self, hook: F) -> Self 
74    where F: Fn(usize, f64, f64) + Send + Sync + 'static {
75        self.step_hook = Some(std::sync::Arc::new(hook));
76        self
77    }
78
79    pub fn with_cancel_flag(mut self, flag: std::sync::Arc<std::sync::atomic::AtomicBool>) -> Self {
80        self.cancel_flag = Some(flag);
81        self
82    }
83
84    /// Optimized structural geometry using the FIRE algorithm.
85    pub fn optimize(&self, system: &mut System) {
86        let n = system.atoms.len();
87        if n == 0 { return; }
88        
89        // Initial wrap only if periodic boundary conditions exist
90        if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
91            for atom in &mut system.atoms {
92                atom.position = system.cell.wrap_vector(atom.position);
93            }
94        }
95
96        let mut velocities = vec![DVec3::ZERO; n];
97        let mut dt = 0.01;
98        let dt_max = 0.05;
99        let mut n_pos = 0;
100        let mut alpha = 0.15;
101        let alpha_start = 0.15;
102
103        let mut fmax_history = VecDeque::with_capacity(self.history_size);
104        let mut frms_history = VecDeque::with_capacity(self.history_size);
105        let mut ediff_history = VecDeque::with_capacity(self.history_size);
106        let mut last_energy: Option<f64> = None;
107        
108        let start_time = Instant::now();
109
110        if self.verbose {
111            self.print_header(system);
112        }
113
114        let mut final_iter = 0;
115        let mut final_status = "Max-Iter";
116
117        for iter in 0..self.max_iterations {
118            final_iter = iter;
119
120            // Check for cancellation
121            if let Some(ref cancel) = self.cancel_flag {
122                if cancel.load(std::sync::atomic::Ordering::SeqCst) {
123                    final_status = "Cancelled";
124                    break;
125                }
126            }
127
128            #[cfg(target_arch = "wasm32")]
129            let energy = system.compute_forces_with_threads(1, self.cutoff);
130            #[cfg(not(target_arch = "wasm32"))]
131            let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
132            
133            let (f_max, f_rms) = self.calculate_force_metrics(system);
134
135            // Update history
136            if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
137            fmax_history.push_back(f_max);
138            if frms_history.len() >= self.history_size { frms_history.pop_front(); }
139            frms_history.push_back(f_rms);
140            if let Some(prev_e) = last_energy {
141                if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
142                ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
143            }
144            last_energy = Some(energy.total);
145
146            // Convergence Check
147            let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
148            
149            if self.verbose && (iter % 10 == 0 || converged) {
150                if energy.total.abs() >= 1e10 {
151                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
152                } else {
153                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
154                }
155            }
156
157            if let Some(ref hook) = self.step_hook {
158                hook(iter, f_max, energy.total);
159            }
160
161            if converged {
162                final_status = status;
163                break;
164            }
165
166            self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
167        }
168
169        if self.verbose {
170            self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
171        }
172    }
173
174    /// Asynchronous version of the optimizer for non-blocking environments (Wasm/UIs).
175    pub async fn optimize_async(&self, system: &mut System) {
176        let n = system.atoms.len();
177        if n == 0 { return; }
178        
179        if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
180            for atom in &mut system.atoms {
181                atom.position = system.cell.wrap_vector(atom.position);
182            }
183        }
184
185        let mut velocities = vec![DVec3::ZERO; n];
186        let mut dt = 0.01;
187        let dt_max = 0.05;
188        let mut n_pos = 0;
189        let mut alpha = 0.15;
190        let alpha_start = 0.15;
191
192        let mut fmax_history = VecDeque::with_capacity(self.history_size);
193        let mut frms_history = VecDeque::with_capacity(self.history_size);
194        let mut ediff_history = VecDeque::with_capacity(self.history_size);
195        let mut last_energy: Option<f64> = None;
196        
197        let start_time = Instant::now();
198
199        if self.verbose {
200            self.print_header(system);
201        }
202
203        let mut final_iter = 0;
204        let mut final_status = "Max-Iter";
205
206        for iter in 0..self.max_iterations {
207            final_iter = iter;
208
209            // Check for cancellation
210            if let Some(ref cancel) = self.cancel_flag {
211                if cancel.load(std::sync::atomic::Ordering::SeqCst) {
212                    final_status = "Cancelled";
213                    break;
214                }
215            }
216            
217            // In Wasm, num_threads should be 1 as Rayon is not supported easily
218            #[cfg(target_arch = "wasm32")]
219            let energy = system.compute_forces_with_threads(1, self.cutoff);
220            #[cfg(not(target_arch = "wasm32"))]
221            let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
222            
223            let (f_max, f_rms) = self.calculate_force_metrics(system);
224
225            if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
226            fmax_history.push_back(f_max);
227            if frms_history.len() >= self.history_size { frms_history.pop_front(); }
228            frms_history.push_back(f_rms);
229            if let Some(prev_e) = last_energy {
230                if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
231                ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
232            }
233            last_energy = Some(energy.total);
234
235            let (converged, status) = self.check_convergence(f_max, f_rms, &fmax_history, &frms_history, &ediff_history);
236            
237            if self.verbose && (iter % 10 == 0 || converged) {
238                if energy.total.abs() >= 1e10 {
239                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4e} | {:<10}", iter, f_max, f_rms, energy.total, status);
240                } else {
241                    tracing::info!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
242                }
243            }
244
245            if let Some(ref hook) = self.step_hook {
246                hook(iter, f_max, energy.total);
247            }
248
249            if converged {
250                final_status = status;
251                break;
252            }
253
254            self.fire_update(system, &mut velocities, &mut dt, dt_max, &mut n_pos, &mut alpha, alpha_start);
255
256            // Yield control back to the environment periodically
257            if iter % 5 == 0 {
258                self.yield_now().await;
259            }
260        }
261
262        if self.verbose {
263            self.print_footer(system, final_status, start_time, final_iter, &fmax_history, &frms_history);
264        }
265    }
266
267    async fn yield_now(&self) {
268        #[cfg(feature = "wasm")]
269        {
270            let promise = js_sys::Promise::new(&mut |resolve, _| {
271                if let Some(window) = web_sys::window() {
272                    window.set_timeout_with_callback_and_timeout_and_arguments_0(&resolve, 0).unwrap();
273                }
274            });
275            let _ = wasm_bindgen_futures::JsFuture::from(promise).await;
276        }
277    }
278
279    fn calculate_force_metrics(&self, system: &System) -> (f64, f64) {
280        let n = system.atoms.len();
281        let mut max_f_sq: f64 = 0.0;
282        let mut sum_f_sq: f64 = 0.0;
283        for atom in &system.atoms {
284            let f_sq = atom.force.length_squared();
285            max_f_sq = f64::max(max_f_sq, f_sq);
286            sum_f_sq += f_sq;
287        }
288        (max_f_sq.sqrt(), (sum_f_sq / (3.0 * n as f64)).sqrt())
289    }
290
291    fn check_convergence(&self, _f_max: f64, _f_rms: f64, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>, ediff_hist: &VecDeque<f64>) -> (bool, &'static str) {
292        if fmax_hist.len() < self.history_size {
293            return (false, "");
294        }
295        let avg_fmax: f64 = fmax_hist.iter().sum::<f64>() / self.history_size as f64;
296        let avg_frms: f64 = frms_hist.iter().sum::<f64>() / self.history_size as f64;
297        let avg_ediff: f64 = if ediff_hist.is_empty() { 1.0 } else { ediff_hist.iter().sum::<f64>() / ediff_hist.len() as f64 };
298
299        if avg_fmax < self.force_threshold {
300            (true, "Fmax-Conv")
301        } else if avg_fmax < self.force_threshold * 2.0 && avg_frms < self.force_threshold * 0.5 {
302            (true, "FRMS-Conv")
303        } else if !ediff_hist.is_empty() && avg_ediff < 1e-7 {
304            (true, "E-Stalled")
305        } else {
306            (false, "")
307        }
308    }
309
310    fn fire_update(&self, system: &mut System, velocities: &mut [DVec3], dt: &mut f64, dt_max: f64, n_pos: &mut usize, alpha: &mut f64, alpha_start: f64) {
311        let n = system.atoms.len();
312        let mut p = 0.0;
313        for i in 0..n {
314            p += velocities[i].dot(system.atoms[i].force);
315        }
316
317        for i in 0..n {
318            let f_norm = system.atoms[i].force.length();
319            let v_norm = velocities[i].length();
320            if f_norm > 1e-9 {
321                velocities[i] = (1.0 - *alpha) * velocities[i] + *alpha * (system.atoms[i].force / f_norm) * v_norm;
322            }
323        }
324
325        if p > 0.0 {
326            *n_pos += 1;
327            if *n_pos > 5 {
328                *dt = f64::min(*dt * 1.05, dt_max);
329                *alpha *= 0.99;
330            }
331        } else {
332            *n_pos = 0;
333            *dt *= 0.5;
334            *alpha = alpha_start;
335            for v in velocities.iter_mut() {
336                *v = DVec3::ZERO;
337            }
338        }
339
340        for i in 0..n {
341            velocities[i] += system.atoms[i].force * (*dt);
342            let mut move_vec = velocities[i] * (*dt);
343            let move_len = move_vec.length();
344            if move_len > self.max_displacement {
345                move_vec *= self.max_displacement / move_len;
346                velocities[i] = move_vec / (*dt);
347            }
348            let new_pos = system.atoms[i].position + move_vec;
349            system.atoms[i].position = system.cell.wrap_vector(new_pos);
350        }
351    }
352
353    fn print_header(&self, system: &System) {
354        let n_atoms = system.atoms.len();
355        let n_bonds = system.bonds.len();
356        let has_charges = system.atoms.iter().any(|a| a.charge.abs() > 1e-12);
357        
358        // Determine actual threads used
359        #[cfg(target_arch = "wasm32")]
360        let actual_threads = 1;
361        
362        #[cfg(not(target_arch = "wasm32"))]
363        let actual_threads = if self.num_threads == 1 {
364            1
365        } else if self.num_threads > 1 {
366            self.num_threads
367        } else if n_atoms >= 1000 { // PARALLEL_THRESHOLD
368            std::env::var("RAYON_NUM_THREADS")
369                .ok()
370                .and_then(|s| s.parse().ok())
371                .unwrap_or(4)
372        } else {
373            1
374        };
375
376        let version_str = format!(" uff-relax v{} ", env!("CARGO_PKG_VERSION"));
377        tracing::info!("\n{:=^80}", version_str);
378        tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Atoms:", n_atoms, "Bonds:", n_bonds);
379        tracing::info!("{:<10} {:<10.1} | {:<10} {:<10.4} kcal/mol/Å", "Cutoff:", self.cutoff, "Threshold:", self.force_threshold);
380        tracing::info!("{:<10} {:<10} | {:<10} {:<10}", 
381            "Threads:", actual_threads, 
382            "Charges:", if has_charges { "Active (Wolf)" } else { "Inactive" }
383        );
384        tracing::info!("{:<10} {:<10} | {:<10} {:<10}", "Max Iter:", self.max_iterations, "", "");
385        tracing::info!("{:-<80}", "");
386        tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "", "Fmax", "FRMS", "Total E", "");
387        tracing::info!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "Iter", "(kcal/mol/Å)", "(kcal/mol/Å)", "(kcal/mol)", "Status");
388        tracing::info!("{:-<80}", "");
389    }
390
391    fn print_footer(&self, system: &mut System, final_status: &str, start_time: Instant, final_iter: usize, fmax_hist: &VecDeque<f64>, frms_hist: &VecDeque<f64>) {
392        let n = system.atoms.len();
393        let duration = start_time.elapsed();
394        let final_energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
395
396        let mut min_dist = f64::MAX;
397        let mut min_pair = (0, 0);
398        for i in 0..n {
399            for j in i + 1..n {
400                let d = system.cell.distance_vector(system.atoms[i].position, system.atoms[j].position).length();
401                if d < min_dist { 
402                    min_dist = d;
403                    min_pair = (i, j);
404                }
405            }
406        }
407
408        tracing::info!("{:-<80}", "");
409        tracing::info!("=== Optimization Finished ===");
410        tracing::info!("Reason: {:<20}", final_status);
411        tracing::info!("Total Time: {:<10.3?} (Avg: {:.3?} / step)", duration, duration / (final_iter + 1) as u32);
412        if final_energy.total.abs() >= 1e10 {
413            tracing::info!("Final Energy: {:<15.4e} kcal/mol", final_energy.total);
414        } else {
415            tracing::info!("Final Energy: {:<15.4} kcal/mol", final_energy.total);
416        }
417        tracing::info!("Final Fmax:   {:<15.4} kcal/mol/Å", fmax_hist.back().unwrap_or(&0.0));
418        tracing::info!("Final FRMS:   {:<15.4} kcal/mol/Å", frms_hist.back().unwrap_or(&0.0));
419        tracing::info!("Min Distance: {:<15.4} Å (Atoms {} and {})", min_dist, min_pair.0 + 1, min_pair.1 + 1);
420        tracing::info!("{:>80}", "(c) 2026 Forblaze Project");
421        tracing::info!("{:-<80}\n", "");
422    }
423}