Skip to main content

sci_form/optimization/
bfgs.rs

//! BFGS quasi-Newton minimizer with backtracking Armijo line search.
//!
//! The optimizer maintains an inverse-Hessian approximation (dense, full-rank)
//! and applies gradient scaling borrowed from RDKit's embedding pipeline when
//! the initial gradient norm exceeds 1.0.

7use nalgebra::{DMatrix, DVector};
8
/// BFGS minimizer with configurable iteration limit, gradient tolerance,
/// and line-search depth.
///
/// Returns `(final_energy, converged)` from [`execute_minimization`](Self::execute_minimization).
pub struct RustBfgsEngine {
    /// Maximum number of BFGS iterations before giving up.
    pub iter_limit_max: usize,
    /// Convergence threshold compared against the gradient L2 norm.
    pub strict_grad_tolerance: f64,
    /// Maximum number of step-halvings per backtracking Armijo line search.
    pub backtracking_line_search_limit: usize,
}
18
19impl Default for RustBfgsEngine {
20    fn default() -> Self {
21        RustBfgsEngine {
22            iter_limit_max: 200,
23            strict_grad_tolerance: 3e-8,
24            backtracking_line_search_limit: 15,
25        }
26    }
27}
28
29impl RustBfgsEngine {
30    pub fn execute_minimization<F>(
31        &self,
32        global_coords: &mut [f64],
33        mut eval_lambda: F,
34    ) -> (f64, bool)
35    where
36        F: FnMut(&[f64], &mut [f64]) -> f64,
37    {
38        let dims = global_coords.len();
39        let mut local_gradient = vec![0.0; dims];
40        let mut current_energy = eval_lambda(global_coords, &mut local_gradient);
41
42        let g_norm = calculate_l2_norm(&local_gradient);
43        if g_norm < self.strict_grad_tolerance {
44            return (current_energy, true);
45        }
46
47        // RDKit Gradient Scaling Hack: 1.0 / sqrt(gradNorm) if gradNorm > 1.0
48        let g_scale = if g_norm > 1.0 {
49            1.0 / g_norm.sqrt()
50        } else {
51            1.0
52        };
53        for val in local_gradient.iter_mut() {
54            *val *= g_scale;
55        }
56
57        let mut hessian_inv_approx = DMatrix::<f64>::identity(dims, dims);
58
59        let mut state_vector_x = DVector::from_row_slice(global_coords);
60        let mut state_gradient_g = DVector::from_row_slice(&local_gradient);
61
62        for _iter_counter in 0..self.iter_limit_max {
63            let mut direction_p = -&hessian_inv_approx * &state_gradient_g;
64
65            // Clamp step magnitude to prevent unphysical displacements.
66            // Max step = 0.3 Bohr ≈ 0.16 Å per coordinate.
67            let step_norm = direction_p.norm();
68            const MAX_STEP: f64 = 0.3;
69            if step_norm > MAX_STEP {
70                direction_p *= MAX_STEP / step_norm;
71            }
72
73            let mut step_size_alpha = 1.0;
74            let armijo_constant_c1 = 1e-4;
75            let slope_derivative = state_gradient_g.dot(&direction_p);
76
77            if slope_derivative > 0.0 {
78                hessian_inv_approx = DMatrix::<f64>::identity(dims, dims);
79                continue;
80            }
81
82            let mut iter_x_next = state_vector_x.clone();
83            let mut iter_g_next = DVector::zeros(dims);
84            let mut next_hypothetical_energy = 0.0;
85
86            let mut success = false;
87            for _ in 0..self.backtracking_line_search_limit {
88                iter_x_next = &state_vector_x + step_size_alpha * &direction_p;
89                let mut tmp_grad = vec![0.0; dims];
90                next_hypothetical_energy = eval_lambda(iter_x_next.as_slice(), &mut tmp_grad);
91                iter_g_next = DVector::from_row_slice(&tmp_grad);
92
93                if next_hypothetical_energy
94                    <= current_energy + armijo_constant_c1 * step_size_alpha * slope_derivative
95                {
96                    success = true;
97                    break;
98                }
99                step_size_alpha *= 0.5;
100            }
101
102            if !success {
103                // If line search fails, we might be in a bad region, reset Hessian
104                hessian_inv_approx = DMatrix::<f64>::identity(dims, dims);
105                // Try one more time with SD direction if needed, but for now just break if it keeps failing
106            }
107
108            global_coords.copy_from_slice(iter_x_next.as_slice());
109
110            let residual_grad_norm = iter_g_next.norm();
111            if residual_grad_norm < self.strict_grad_tolerance {
112                return (next_hypothetical_energy, true);
113            }
114
115            let distance_diff_s = &iter_x_next - &state_vector_x;
116            let gradient_diff_y = &iter_g_next - &state_gradient_g;
117            let curvature_scalar_rho_inv = gradient_diff_y.dot(&distance_diff_s);
118
119            if curvature_scalar_rho_inv > 1e-10 {
120                let rho = 1.0 / curvature_scalar_rho_inv;
121                let id_mat = DMatrix::<f64>::identity(dims, dims);
122                let transform_1 =
123                    id_mat.clone() - rho * (&distance_diff_s * gradient_diff_y.transpose());
124                let transform_2 = id_mat - rho * (&gradient_diff_y * distance_diff_s.transpose());
125
126                hessian_inv_approx = transform_1 * &hessian_inv_approx * transform_2
127                    + rho * (&distance_diff_s * distance_diff_s.transpose());
128            }
129
130            state_vector_x = iter_x_next;
131            state_gradient_g = iter_g_next;
132            current_energy = next_hypothetical_energy;
133        }
134
135        (current_energy, false)
136    }
137}
138
/// Euclidean (L2) norm of a flat coordinate or gradient slice.
///
/// An empty slice yields 0.0.
fn calculate_l2_norm(vector_space: &[f64]) -> f64 {
    let sum_of_squares = vector_space.iter().fold(0.0, |acc, &component| {
        acc + component * component
    });
    sum_of_squares.sqrt()
}