Skip to main content

sci_form/materials/
geometry_opt.rs

//! Framework geometry optimization under periodic boundary conditions.
//!
//! Optimizes atomic positions within a fixed unit cell, useful for
//! MOF linker optimization, crystal surface relaxation, and post-assembly
//! refinement of framework structures.
//!
//! Supports:
//! - Cartesian coordinate optimization, plus helpers to convert between
//!   Cartesian and fractional coordinates
//! - BFGS quasi-Newton with Armijo backtracking line search
//! - Steepest descent with an adaptive step length
//! - Fixed-atom constraints
//! - Wrapping atoms back into the unit cell after each step (any
//!   minimum-image treatment of forces is up to the energy/force callback)
13
14use serde::{Deserialize, Serialize};
15
/// Algorithm used by [`optimize_framework`] to relax atomic positions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptMethod {
    /// Steepest descent: follows the forces directly. Robust far from a
    /// minimum, but slow to converge close to one.
    SteepestDescent,
    /// BFGS quasi-Newton: builds an approximate inverse Hessian, which gives
    /// fast convergence near a minimum.
    Bfgs,
}
24
25/// Configuration for framework geometry optimization.
26#[derive(Debug, Clone)]
27pub struct FrameworkOptConfig {
28    /// Optimization method.
29    pub method: OptMethod,
30    /// Maximum number of iterations.
31    pub max_iter: usize,
32    /// Force convergence threshold (eV/Å).
33    pub force_tol: f64,
34    /// Energy convergence threshold (eV).
35    pub energy_tol: f64,
36    /// Maximum step size (Å).
37    pub max_step: f64,
38    /// Indices of atoms whose positions are fixed.
39    pub fixed_atoms: Vec<usize>,
40    /// Unit cell lattice vectors (3×3, row-major). If None, non-periodic.
41    pub lattice: Option<[[f64; 3]; 3]>,
42}
43
44impl Default for FrameworkOptConfig {
45    fn default() -> Self {
46        Self {
47            method: OptMethod::Bfgs,
48            max_iter: 200,
49            force_tol: 0.05,
50            energy_tol: 1e-6,
51            max_step: 0.2,
52            fixed_atoms: vec![],
53            lattice: None,
54        }
55    }
56}
57
58/// Result of framework geometry optimization.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct FrameworkOptResult {
61    /// Optimized positions (Cartesian, Å).
62    pub positions: Vec<[f64; 3]>,
63    /// Final energy (eV).
64    pub energy: f64,
65    /// Final forces (eV/Å).
66    pub forces: Vec<[f64; 3]>,
67    /// Maximum force magnitude at convergence (eV/Å).
68    pub max_force: f64,
69    /// Number of iterations performed.
70    pub n_iterations: usize,
71    /// Whether optimization converged.
72    pub converged: bool,
73    /// Energy trajectory.
74    pub energy_history: Vec<f64>,
75}
76
/// Callback that evaluates a structure.
///
/// It receives the atomic numbers and Cartesian positions (Å). It is expected
/// to return the total energy (eV) and one force vector (eV/Å) per atom, or an
/// error message.
pub type EnergyForceFn = dyn Fn(&[u8], &[[f64; 3]]) -> Result<(f64, Vec<[f64; 3]>), String>;
80
81/// Optimize framework geometry using the specified method.
82///
83/// # Arguments
84/// * `elements` - Atomic numbers
85/// * `initial_positions` - Starting coordinates (Å)
86/// * `energy_force_fn` - Function computing energy and forces
87/// * `config` - Optimization configuration
88pub fn optimize_framework(
89    elements: &[u8],
90    initial_positions: &[[f64; 3]],
91    energy_force_fn: &EnergyForceFn,
92    config: &FrameworkOptConfig,
93) -> Result<FrameworkOptResult, String> {
94    match config.method {
95        OptMethod::SteepestDescent => {
96            optimize_steepest_descent(elements, initial_positions, energy_force_fn, config)
97        }
98        OptMethod::Bfgs => optimize_bfgs(elements, initial_positions, energy_force_fn, config),
99    }
100}
101
102/// Steepest descent optimizer.
103fn optimize_steepest_descent(
104    elements: &[u8],
105    initial_positions: &[[f64; 3]],
106    energy_force_fn: &EnergyForceFn,
107    config: &FrameworkOptConfig,
108) -> Result<FrameworkOptResult, String> {
109    let n = elements.len();
110    let mut positions: Vec<[f64; 3]> = initial_positions.to_vec();
111    let mut energy_history = Vec::new();
112    let mut step_size = config.max_step;
113
114    let (mut energy, mut forces) = energy_force_fn(elements, &positions)?;
115    energy_history.push(energy);
116
117    let mut converged = false;
118    let mut n_iter = 0;
119
120    for iter in 0..config.max_iter {
121        n_iter = iter + 1;
122
123        // Zero forces on fixed atoms
124        zero_fixed_forces(&mut forces, &config.fixed_atoms);
125
126        let max_force = max_force_magnitude(&forces);
127        if max_force < config.force_tol {
128            converged = true;
129            break;
130        }
131
132        // Take step along force direction
133        let mut new_positions = positions.clone();
134        for i in 0..n {
135            if config.fixed_atoms.contains(&i) {
136                continue;
137            }
138            let f_mag = (forces[i][0].powi(2) + forces[i][1].powi(2) + forces[i][2].powi(2)).sqrt();
139            if f_mag < 1e-12 {
140                continue;
141            }
142            let scale = step_size / f_mag;
143            for d in 0..3 {
144                new_positions[i][d] += forces[i][d] * scale;
145            }
146        }
147
148        // Apply minimum image convention if periodic
149        if let Some(ref lattice) = config.lattice {
150            apply_pbc(&mut new_positions, lattice);
151        }
152
153        let (new_energy, new_forces) = energy_force_fn(elements, &new_positions)?;
154
155        if new_energy < energy {
156            positions = new_positions;
157            energy = new_energy;
158            forces = new_forces;
159            step_size = (step_size * 1.2).min(config.max_step);
160        } else {
161            step_size *= 0.5;
162            if step_size < 1e-10 {
163                break;
164            }
165        }
166
167        energy_history.push(energy);
168
169        if energy_history.len() > 1 {
170            let de = (energy_history[energy_history.len() - 2] - energy).abs();
171            if de < config.energy_tol {
172                converged = true;
173                break;
174            }
175        }
176    }
177
178    zero_fixed_forces(&mut forces, &config.fixed_atoms);
179    let max_force = max_force_magnitude(&forces);
180
181    Ok(FrameworkOptResult {
182        positions,
183        energy,
184        forces,
185        max_force,
186        n_iterations: n_iter,
187        converged,
188        energy_history,
189    })
190}
191
192/// BFGS quasi-Newton optimizer with approximate inverse Hessian.
193fn optimize_bfgs(
194    elements: &[u8],
195    initial_positions: &[[f64; 3]],
196    energy_force_fn: &EnergyForceFn,
197    config: &FrameworkOptConfig,
198) -> Result<FrameworkOptResult, String> {
199    let n = elements.len();
200    let ndim = n * 3;
201    let mut positions: Vec<[f64; 3]> = initial_positions.to_vec();
202    let mut energy_history = Vec::new();
203
204    // Initial evaluation
205    let (mut energy, mut forces) = energy_force_fn(elements, &positions)?;
206    zero_fixed_forces(&mut forces, &config.fixed_atoms);
207    energy_history.push(energy);
208
209    // Flatten gradient (negative force)
210    let mut grad = flatten_neg_forces(&forces);
211
212    // Initialize inverse Hessian as identity
213    let mut h_inv = vec![vec![0.0f64; ndim]; ndim];
214    for i in 0..ndim {
215        h_inv[i][i] = 1.0;
216    }
217
218    // Zero columns/rows for fixed atoms
219    for &fixed in &config.fixed_atoms {
220        for d in 0..3 {
221            let idx = fixed * 3 + d;
222            if idx < ndim {
223                for j in 0..ndim {
224                    h_inv[idx][j] = 0.0;
225                    h_inv[j][idx] = 0.0;
226                }
227            }
228        }
229    }
230
231    let mut converged = false;
232    let mut n_iter = 0;
233
234    for iter in 0..config.max_iter {
235        n_iter = iter + 1;
236
237        let max_force = max_force_magnitude(&forces);
238        if max_force < config.force_tol {
239            converged = true;
240            break;
241        }
242
243        // Search direction: p = -H_inv * grad
244        let mut p = vec![0.0f64; ndim];
245        for i in 0..ndim {
246            for j in 0..ndim {
247                p[i] -= h_inv[i][j] * grad[j];
248            }
249        }
250
251        // Limit step size
252        let p_norm: f64 = p.iter().map(|x| x * x).sum::<f64>().sqrt();
253        if p_norm > config.max_step {
254            let scale = config.max_step / p_norm;
255            for x in &mut p {
256                *x *= scale;
257            }
258        }
259
260        // Take step with Armijo backtracking line search
261        let directional_deriv: f64 = p.iter().zip(grad.iter()).map(|(a, b)| a * b).sum();
262        let c_armijo = 1e-4;
263        let mut alpha = 1.0;
264        let mut new_positions;
265        let mut new_energy;
266        let mut new_forces;
267
268        loop {
269            new_positions = positions.clone();
270            for i in 0..n {
271                if config.fixed_atoms.contains(&i) {
272                    continue;
273                }
274                for d in 0..3 {
275                    new_positions[i][d] += alpha * p[i * 3 + d];
276                }
277            }
278
279            if let Some(ref lattice) = config.lattice {
280                apply_pbc(&mut new_positions, lattice);
281            }
282
283            let result = energy_force_fn(elements, &new_positions)?;
284            new_energy = result.0;
285            new_forces = result.1;
286
287            // Armijo condition: f(x + α*p) <= f(x) + c * α * ∇f·p
288            if new_energy <= energy + c_armijo * alpha * directional_deriv || alpha < 0.1 {
289                break;
290            }
291            alpha *= 0.5;
292        }
293
294        zero_fixed_forces(&mut new_forces, &config.fixed_atoms);
295
296        let new_grad = flatten_neg_forces(&new_forces);
297
298        // BFGS update of inverse Hessian
299        let s: Vec<f64> = p; // step
300        let y: Vec<f64> = (0..ndim).map(|i| new_grad[i] - grad[i]).collect();
301
302        let sy: f64 = s.iter().zip(y.iter()).map(|(a, b)| a * b).sum();
303
304        if sy > 1e-12 {
305            // H_inv update: Sherman-Morrison-Woodbury
306            let mut hy = vec![0.0f64; ndim];
307            for i in 0..ndim {
308                for j in 0..ndim {
309                    hy[i] += h_inv[i][j] * y[j];
310                }
311            }
312
313            let yhy: f64 = y.iter().zip(hy.iter()).map(|(a, b)| a * b).sum();
314            let rho = 1.0 / sy;
315
316            for i in 0..ndim {
317                for j in 0..ndim {
318                    h_inv[i][j] +=
319                        rho * ((1.0 + yhy * rho) * s[i] * s[j] - hy[i] * s[j] - s[i] * hy[j]);
320                }
321            }
322
323            // Positive-definite check: if any diagonal becomes negative, reset to identity
324            let has_negative_diag = (0..ndim).any(|i| h_inv[i][i] <= 0.0);
325            if has_negative_diag {
326                for i in 0..ndim {
327                    for j in 0..ndim {
328                        h_inv[i][j] = if i == j { 1.0 } else { 0.0 };
329                    }
330                }
331            }
332        }
333
334        positions = new_positions;
335        energy = new_energy;
336        forces = new_forces;
337        grad = new_grad;
338        energy_history.push(energy);
339
340        if energy_history.len() > 1 {
341            let de = (energy_history[energy_history.len() - 2] - energy).abs();
342            if de < config.energy_tol {
343                converged = true;
344                break;
345            }
346        }
347    }
348
349    let max_force = max_force_magnitude(&forces);
350
351    Ok(FrameworkOptResult {
352        positions,
353        energy,
354        forces,
355        max_force,
356        n_iterations: n_iter,
357        converged,
358        energy_history,
359    })
360}
361
/// Overwrite the force on every fixed atom with zero.
///
/// Indices beyond the end of `forces` are silently skipped.
fn zero_fixed_forces(forces: &mut [[f64; 3]], fixed: &[usize]) {
    for &atom in fixed {
        if let Some(force) = forces.get_mut(atom) {
            *force = [0.0; 3];
        }
    }
}
369
/// Largest per-atom force magnitude `|F_i|` (0.0 for an empty slice).
///
/// If any magnitude is NaN, the result is NaN. `f64::max` would otherwise
/// discard NaN, so broken forces could read as 0.0 and falsely pass the
/// `max_force < force_tol` convergence test.
fn max_force_magnitude(forces: &[[f64; 3]]) -> f64 {
    forces
        .iter()
        .map(|f| (f[0] * f[0] + f[1] * f[1] + f[2] * f[2]).sqrt())
        .fold(0.0f64, |acc, m| {
            if acc.is_nan() || m.is_nan() {
                f64::NAN
            } else {
                acc.max(m)
            }
        })
}
376
/// Flatten per-atom forces into a gradient vector
/// `[-F0x, -F0y, -F0z, -F1x, ...]` (the gradient is the negated force).
fn flatten_neg_forces(forces: &[[f64; 3]]) -> Vec<f64> {
    forces
        .iter()
        .flat_map(|force| force.iter().map(|&component| -component))
        .collect()
}
386
387/// Apply periodic boundary conditions: wrap Cartesian coordinates back into the unit cell.
388fn apply_pbc(positions: &mut [[f64; 3]], lattice: &[[f64; 3]; 3]) {
389    // Compute inverse lattice matrix
390    let inv = invert_3x3_lattice(lattice);
391
392    for pos in positions.iter_mut() {
393        // Convert to fractional
394        let frac = [
395            inv[0][0] * pos[0] + inv[0][1] * pos[1] + inv[0][2] * pos[2],
396            inv[1][0] * pos[0] + inv[1][1] * pos[1] + inv[1][2] * pos[2],
397            inv[2][0] * pos[0] + inv[2][1] * pos[1] + inv[2][2] * pos[2],
398        ];
399
400        // Wrap to [0, 1)
401        let wrapped = [
402            frac[0] - frac[0].floor(),
403            frac[1] - frac[1].floor(),
404            frac[2] - frac[2].floor(),
405        ];
406
407        // Convert back to Cartesian
408        pos[0] =
409            lattice[0][0] * wrapped[0] + lattice[1][0] * wrapped[1] + lattice[2][0] * wrapped[2];
410        pos[1] =
411            lattice[0][1] * wrapped[0] + lattice[1][1] * wrapped[1] + lattice[2][1] * wrapped[2];
412        pos[2] =
413            lattice[0][2] * wrapped[0] + lattice[1][2] * wrapped[1] + lattice[2][2] * wrapped[2];
414    }
415}
416
/// Invert the 3×3 matrix `m`, whose rows are the lattice vectors **a**, **b**, **c**.
///
/// Uses the adjugate: the columns of `m⁻¹` are `b×c`, `c×a` and `a×b`, each
/// divided by `det(m) = a·(b×c)`. If `|det| < 1e-30` the matrix is treated as
/// singular and the all-zero matrix is returned.
fn invert_3x3_lattice(m: &[[f64; 3]; 3]) -> [[f64; 3]; 3] {
    let cross = |u: &[f64; 3], v: &[f64; 3]| -> [f64; 3] {
        [
            u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0],
        ]
    };

    let (a, b, c) = (&m[0], &m[1], &m[2]);
    let bc = cross(b, c);
    let ca = cross(c, a);
    let ab = cross(a, b);

    let det = a[0] * bc[0] + a[1] * bc[1] + a[2] * bc[2];
    if det.abs() < 1e-30 {
        return [[0.0; 3]; 3];
    }

    let inv_det = 1.0 / det;
    let mut inv = [[0.0; 3]; 3];
    for (row, out) in inv.iter_mut().enumerate() {
        *out = [bc[row] * inv_det, ca[row] * inv_det, ab[row] * inv_det];
    }
    inv
}
445
/// Convert fractional coordinates to Cartesian, given lattice vectors stored as
/// rows (`lattice[0]` is **a**, `lattice[1]` is **b**, `lattice[2]` is **c**).
///
/// Returns `frac[0]·a + frac[1]·b + frac[2]·c`.
pub fn frac_to_cart(frac: &[f64; 3], lattice: &[[f64; 3]; 3]) -> [f64; 3] {
    let [a, b, c] = lattice;
    let mut cart = [0.0; 3];
    for (axis, out) in cart.iter_mut().enumerate() {
        *out = a[axis] * frac[0] + b[axis] * frac[1] + c[axis] * frac[2];
    }
    cart
}
454
/// Convert Cartesian coordinates to fractional, given lattice vectors stored as
/// rows (`lattice[0]` is **a**, `lattice[1]` is **b**, `lattice[2]` is **c**).
///
/// This is the exact inverse of [`frac_to_cart`], i.e. it solves
/// `cart = frac[0]·a + frac[1]·b + frac[2]·c`, for any cell shape (including
/// triclinic and hexagonal). Each component is a projection onto a reciprocal
/// vector, e.g. `frac[0] = cart·(b×c) / (a·(b×c))`.
///
/// Returns `[0.0; 3]` for a degenerate lattice (`|a·(b×c)| < 1e-30`).
pub fn cart_to_frac(cart: &[f64; 3], lattice: &[[f64; 3]; 3]) -> [f64; 3] {
    let cross = |u: &[f64; 3], v: &[f64; 3]| -> [f64; 3] {
        [
            u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0],
        ]
    };
    let dot = |u: &[f64; 3], v: &[f64; 3]| -> f64 { u[0] * v[0] + u[1] * v[1] + u[2] * v[2] };

    let [a, b, c] = lattice;
    let bc = cross(b, c);
    let volume = dot(a, &bc);
    if volume.abs() < 1e-30 {
        return [0.0; 3];
    }
    let ca = cross(c, a);
    let ab = cross(a, b);

    [
        dot(cart, &bc) / volume,
        dot(cart, &ca) / volume,
        dot(cart, &ab) / volume,
    ]
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Independent harmonic well `E = ½|r|²` for every atom, centred at the
    /// origin, with the matching analytic forces `F = -r`.
    fn simple_harmonic_energy(
        _elements: &[u8],
        positions: &[[f64; 3]],
    ) -> Result<(f64, Vec<[f64; 3]>), String> {
        let mut energy = 0.0;
        let mut forces = Vec::with_capacity(positions.len());
        for pos in positions {
            let r2 = pos[0] * pos[0] + pos[1] * pos[1] + pos[2] * pos[2];
            energy += 0.5 * r2;
            forces.push([-pos[0], -pos[1], -pos[2]]); // F = -grad(E)
        }
        Ok((energy, forces))
    }

    #[test]
    fn test_steepest_descent() {
        let elements = vec![6u8];
        let initial = vec![[1.0, 0.5, 0.2]];
        let config = FrameworkOptConfig {
            method: OptMethod::SteepestDescent,
            max_iter: 100,
            force_tol: 0.01,
            ..Default::default()
        };

        let result =
            optimize_framework(&elements, &initial, &simple_harmonic_energy, &config).unwrap();
        assert!(result.converged);
        for coord in result.positions[0].iter() {
            assert!(coord.abs() < 0.1, "coordinate {} not near minimum", coord);
        }
    }

    #[test]
    fn test_bfgs() {
        let elements = vec![6u8];
        let initial = vec![[1.0, 0.5, 0.2]];
        let config = FrameworkOptConfig {
            method: OptMethod::Bfgs,
            max_iter: 50,
            force_tol: 0.01,
            ..Default::default()
        };

        let result =
            optimize_framework(&elements, &initial, &simple_harmonic_energy, &config).unwrap();
        assert!(result.converged);
        assert!(result.max_force < 0.01);
        // BFGS should converge faster
        assert!(result.n_iterations < 20);
    }

    #[test]
    fn test_fixed_atoms() {
        let elements = vec![6u8, 8u8];
        let initial = vec![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let config = FrameworkOptConfig {
            method: OptMethod::Bfgs,
            max_iter: 50,
            force_tol: 0.01,
            fixed_atoms: vec![0], // Fix first atom
            ..Default::default()
        };

        let result =
            optimize_framework(&elements, &initial, &simple_harmonic_energy, &config).unwrap();
        // The fixed atom must not move at all and reports zero force.
        assert_eq!(result.positions[0], [1.0, 0.0, 0.0]);
        assert_eq!(result.forces[0], [0.0, 0.0, 0.0]);
        // Second atom should move toward origin
        assert!(result.positions[1][1].abs() < 0.2);
    }

    #[test]
    fn test_frac_cart_conversion() {
        let lattice = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]];
        let frac = [0.5, 0.25, 0.1];
        let cart = frac_to_cart(&frac, &lattice);
        let expected_cart = [5.0, 2.5, 1.0];
        for (got, want) in cart.iter().zip(expected_cart.iter()) {
            assert!((got - want).abs() < 1e-10);
        }

        let back = cart_to_frac(&cart, &lattice);
        for (got, want) in back.iter().zip(frac.iter()) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_apply_pbc_wraps_into_cubic_cell() {
        let lattice = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]];
        let mut positions = vec![[12.0, -3.0, 5.0]];
        apply_pbc(&mut positions, &lattice);
        let expected = [2.0, 7.0, 5.0];
        for (got, want) in positions[0].iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "wrapped {} != {}", got, want);
        }
    }
}