Skip to main content

uff_relax/
optimizer.rs

1use crate::forcefield::System;
2use glam::DVec3;
3
/// Optimizer for molecular structures using the FIRE (Fast Inertial Relaxation Engine)
/// algorithm of Bitzek et al. (2006).
///
/// Construct with `UffOptimizer::new` and adjust settings with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct UffOptimizer {
    /// Maximum number of iterations to perform.
    pub max_iterations: usize,
    /// Threshold for the maximum force on any atom (kcal/mol/Å).
    pub force_threshold: f64,
    /// Whether to print optimization progress to stdout.
    pub verbose: bool,
    /// Number of threads to use. 0 means automatic based on system size.
    pub num_threads: usize,
    /// Cutoff distance for non-bonded interactions (Å).
    pub cutoff: f64,
    /// Number of most recent iterations averaged for the convergence criteria.
    pub history_size: usize,
    /// Maximum distance an atom can move in a single step (Å).
    pub max_displacement: f64,
}
21
22impl UffOptimizer {
23    /// Creates a new optimizer with default settings.
24    pub fn new(max_iterations: usize, force_threshold: f64) -> Self {
25        Self {
26            max_iterations,
27            force_threshold,
28            verbose: false,
29            num_threads: 0,
30            cutoff: 6.0,
31            history_size: 10,
32            max_displacement: 0.5,
33        }
34    }
35
36    pub fn with_max_displacement(mut self, max: f64) -> Self {
37        self.max_displacement = max;
38        self
39    }
40
41    pub fn with_num_threads(mut self, num_threads: usize) -> Self {
42        self.num_threads = num_threads;
43        self
44    }
45
46    pub fn with_cutoff(mut self, cutoff: f64) -> Self {
47        self.cutoff = cutoff;
48        self
49    }
50
51    pub fn with_history_size(mut self, size: usize) -> Self {
52        self.history_size = size;
53        self
54    }
55
56    pub fn with_verbose(mut self, verbose: bool) -> Self {
57        self.verbose = verbose;
58        self
59    }
60
61    /// Optimized structural geometry using the FIRE algorithm.
62    pub fn optimize(&self, system: &mut System) {
63        let n = system.atoms.len();
64        
65        // Initial wrap only if periodic boundary conditions exist
66        if !matches!(system.cell.cell_type, crate::cell::CellType::None) {
67            for atom in &mut system.atoms {
68                atom.position = system.cell.wrap_vector(atom.position);
69            }
70        }
71
72        let mut velocities = vec![DVec3::ZERO; n];
73        
74        let mut dt = 0.02;
75        let dt_max = 0.2;
76        let mut n_pos = 0;
77        let mut alpha = 0.1;
78        let alpha_start = 0.1;
79
80        // Convergence history
81        let mut fmax_history = std::collections::VecDeque::with_capacity(self.history_size);
82        let mut frms_history = std::collections::VecDeque::with_capacity(self.history_size);
83        let mut ediff_history = std::collections::VecDeque::with_capacity(self.history_size);
84        let mut last_energy: Option<f64> = None;
85        
86        let start_time = std::time::Instant::now();
87
88        if self.verbose {
89            let version_str = format!(" uff-relax v{} ", env!("CARGO_PKG_VERSION"));
90            println!("\n{:=^80}", version_str);
91            println!("{:<10} {:<10} | {:<10} {:<10}", "Atoms:", n, "Bonds:", system.bonds.len());
92            println!("{:<10} {:<10.1} | {:<10} {:<10.4} kcal/mol/Å", "Cutoff:", self.cutoff, "Threshold:", self.force_threshold);
93            println!("{:<10} {:<10} | {:<10} {:<10}", "Max Iter:", self.max_iterations, "Threads:", if self.num_threads == 0 { "Auto".to_string() } else { self.num_threads.to_string() });
94            println!("{:-<80}", "");
95            println!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "", "Fmax", "FRMS", "Total E", "");
96            println!("{:<6} | {:<14} | {:<14} | {:<16} | {:<10}", "Iter", "(kcal/mol/Å)", "(kcal/mol/Å)", "(kcal/mol)", "Status");
97            println!("{:-<80}", "");
98        }
99
100        let mut final_iter = 0;
101        let mut final_status = "Max-Iter";
102
103        for iter in 0..self.max_iterations {
104            final_iter = iter;
105            let energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
106            
107            // Calculate Fmax and FRMS
108            let mut max_f_sq: f64 = 0.0;
109            let mut sum_f_sq: f64 = 0.0;
110            for atom in &system.atoms {
111                let f_sq = atom.force.length_squared();
112                max_f_sq = f64::max(max_f_sq, f_sq);
113                sum_f_sq += f_sq;
114            }
115            let f_max = max_f_sq.sqrt();
116            let f_rms = (sum_f_sq / (3.0 * n as f64)).sqrt();
117
118            // Update history
119            if fmax_history.len() >= self.history_size { fmax_history.pop_front(); }
120            fmax_history.push_back(f_max);
121            
122            if frms_history.len() >= self.history_size { frms_history.pop_front(); }
123            frms_history.push_back(f_rms);
124
125            if let Some(prev_e) = last_energy {
126                if ediff_history.len() >= self.history_size { ediff_history.pop_front(); }
127                ediff_history.push_back((energy.total - prev_e).abs() / n as f64);
128            }
129            last_energy = Some(energy.total);
130
131            // Convergence Check
132            let mut converged = false;
133            let mut status = "";
134            if fmax_history.len() >= self.history_size {
135                let avg_fmax: f64 = fmax_history.iter().sum::<f64>() / self.history_size as f64;
136                let avg_frms: f64 = frms_history.iter().sum::<f64>() / self.history_size as f64;
137                let avg_ediff: f64 = if ediff_history.is_empty() { 1.0 } else { ediff_history.iter().sum::<f64>() / ediff_history.len() as f64 };
138
139                if avg_fmax < self.force_threshold {
140                    converged = true;
141                    status = "Fmax-Conv";
142                } else if avg_fmax < self.force_threshold * 2.0 && avg_frms < self.force_threshold * 0.5 {
143                    converged = true;
144                    status = "FRMS-Conv";
145                } else if !ediff_history.is_empty() && avg_ediff < 1e-7 {
146                    converged = true;
147                    status = "E-Stalled";
148                }
149            }
150            
151            if self.verbose && (iter % 10 == 0 || converged) {
152                println!("{:>6} | {:>14.4} | {:>14.4} | {:>16.4} | {:<10}", iter, f_max, f_rms, energy.total, status);
153            }
154
155            if converged {
156                final_status = status;
157                break;
158            }
159
160            // FIRE logic
161            let mut p = 0.0;
162            for i in 0..n {
163                p += velocities[i].dot(system.atoms[i].force);
164            }
165
166            for i in 0..n {
167                let f_norm = system.atoms[i].force.length();
168                let v_norm = velocities[i].length();
169                if f_norm > 1e-9 {
170                    velocities[i] = (1.0 - alpha) * velocities[i] + alpha * (system.atoms[i].force / f_norm) * v_norm;
171                }
172            }
173
174            if p > 0.0 {
175                n_pos += 1;
176                if n_pos > 5 {
177                    dt = f64::min(dt * 1.1, dt_max);
178                    alpha *= 0.99;
179                }
180            } else {
181                n_pos = 0;
182                dt *= 0.5;
183                alpha = alpha_start;
184                for v in &mut velocities {
185                    *v = DVec3::ZERO;
186                }
187            }
188
189            // Semi-implicit Euler integration (Standard for FIRE)
190            for i in 0..n {
191                velocities[i] += system.atoms[i].force * dt;
192                let mut move_vec = velocities[i] * dt;
193                
194                // Displacement clamping
195                let move_len = move_vec.length();
196                if move_len > self.max_displacement {
197                    move_vec *= self.max_displacement / move_len;
198                    velocities[i] = move_vec / dt; // Sync velocity
199                }
200
201                let new_pos = system.atoms[i].position + move_vec;
202                system.atoms[i].position = system.cell.wrap_vector(new_pos);
203            }
204        }
205
206                if self.verbose {
207
208                    let duration = start_time.elapsed();
209
210                    let final_energy = system.compute_forces_with_threads(self.num_threads, self.cutoff);
211
212                    
213
214                                // Calculate minimum interatomic distance
215
216                    
217
218                                let mut min_dist = f64::MAX;
219
220                    
221
222                                let mut min_pair = (0, 0);
223
224                    
225
226                                for i in 0..n {
227
228                    
229
230                                    for j in i + 1..n {
231
232                    
233
234                                        let d = system.cell.distance_vector(system.atoms[i].position, system.atoms[j].position).length();
235
236                    
237
238                                        if d < min_dist { 
239
240                    
241
242                                            min_dist = d;
243
244                    
245
246                                            min_pair = (i, j);
247
248                    
249
250                                        }
251
252                    
253
254                                    }
255
256                    
257
258                                }
259
260                    
261
262                    
263
264                    
265
266                                println!("{:-<80}", "");
267
268                    
269
270                                println!("=== Optimization Finished ===");
271
272                    
273
274                                println!("Reason: {:<20}", final_status);
275
276                    
277
278                                println!("Total Time: {:<10.3?} (Avg: {:.3?} / step)", duration, duration / (final_iter + 1) as u32);
279
280                    
281
282                                println!("Final Energy: {:<15.4} kcal/mol", final_energy.total);
283
284                    
285
286                                println!("Final Fmax:   {:<15.4} kcal/mol/Å", fmax_history.back().unwrap_or(&0.0));
287
288                    
289
290                                println!("Final FRMS:   {:<15.4} kcal/mol/Å", frms_history.back().unwrap_or(&0.0));
291
292                    
293
294                                println!("Min Distance: {:<15.4} Å (Atoms {} and {})", min_dist, min_pair.0 + 1, min_pair.1 + 1);
295
296                    
297
298                    
299
300                    println!("{:>80}", "(c) 2026 Forblaze Project");
301
302                    println!("{:-<80}\n", "");
303
304                }
305
306            }
307
308        }
309
310