Skip to main content

sci_form/optimization/
bfgs.rs

1use nalgebra::{DMatrix, DVector};
2
/// Tunable limits for the BFGS minimizer implemented by
/// `RustBfgsEngine::execute_minimization`.
#[derive(Debug, Clone, PartialEq)]
pub struct RustBfgsEngine {
    /// Upper bound on the number of outer (quasi-Newton) iterations.
    pub iter_limit_max: usize,
    /// The minimization reports convergence once the L2 norm of the gradient
    /// drops below this value.
    pub strict_grad_tolerance: f64,
    /// Upper bound on the number of step-halving trials in one line search.
    pub backtracking_line_search_limit: usize,
}
8
9impl Default for RustBfgsEngine {
10    fn default() -> Self {
11        RustBfgsEngine {
12            iter_limit_max: 200,
13            strict_grad_tolerance: 3e-8,
14            backtracking_line_search_limit: 15,
15        }
16    }
17}
18
19impl RustBfgsEngine {
20    pub fn execute_minimization<F>(
21        &self,
22        global_coords: &mut [f64],
23        mut eval_lambda: F,
24    ) -> (f64, bool)
25    where
26        F: FnMut(&[f64], &mut [f64]) -> f64,
27    {
28        let dims = global_coords.len();
29        let mut local_gradient = vec![0.0; dims];
30        let mut current_energy = eval_lambda(global_coords, &mut local_gradient);
31
32        let g_norm = calculate_l2_norm(&local_gradient);
33        if g_norm < self.strict_grad_tolerance {
34            return (current_energy, true);
35        }
36
37        // RDKit Gradient Scaling Hack: 1.0 / sqrt(gradNorm) if gradNorm > 1.0
38        let g_scale = if g_norm > 1.0 {
39            1.0 / g_norm.sqrt()
40        } else {
41            1.0
42        };
43        for val in local_gradient.iter_mut() {
44            *val *= g_scale;
45        }
46
47        let mut hessian_inv_approx = DMatrix::<f64>::identity(dims, dims);
48
49        let mut state_vector_x = DVector::from_row_slice(global_coords);
50        let mut state_gradient_g = DVector::from_row_slice(&local_gradient);
51
52        for _iter_counter in 0..self.iter_limit_max {
53            let direction_p = -&hessian_inv_approx * &state_gradient_g;
54
55            let mut step_size_alpha = 1.0;
56            let armijo_constant_c1 = 1e-4;
57            let slope_derivative = state_gradient_g.dot(&direction_p);
58
59            if slope_derivative > 0.0 {
60                hessian_inv_approx = DMatrix::<f64>::identity(dims, dims);
61                continue;
62            }
63
64            let mut iter_x_next = state_vector_x.clone();
65            let mut iter_g_next = DVector::zeros(dims);
66            let mut next_hypothetical_energy = 0.0;
67
68            let mut success = false;
69            for _ in 0..self.backtracking_line_search_limit {
70                iter_x_next = &state_vector_x + step_size_alpha * &direction_p;
71                let mut tmp_grad = vec![0.0; dims];
72                next_hypothetical_energy = eval_lambda(iter_x_next.as_slice(), &mut tmp_grad);
73                iter_g_next = DVector::from_row_slice(&tmp_grad);
74
75                if next_hypothetical_energy
76                    <= current_energy + armijo_constant_c1 * step_size_alpha * slope_derivative
77                {
78                    success = true;
79                    break;
80                }
81                step_size_alpha *= 0.5;
82            }
83
84            if !success {
85                // If line search fails, we might be in a bad region, reset Hessian
86                hessian_inv_approx = DMatrix::<f64>::identity(dims, dims);
87                // Try one more time with SD direction if needed, but for now just break if it keeps failing
88            }
89
90            global_coords.copy_from_slice(iter_x_next.as_slice());
91
92            let residual_grad_norm = iter_g_next.norm();
93            if residual_grad_norm < self.strict_grad_tolerance {
94                return (next_hypothetical_energy, true);
95            }
96
97            let distance_diff_s = &iter_x_next - &state_vector_x;
98            let gradient_diff_y = &iter_g_next - &state_gradient_g;
99            let curvature_scalar_rho_inv = gradient_diff_y.dot(&distance_diff_s);
100
101            if curvature_scalar_rho_inv > 1e-10 {
102                let rho = 1.0 / curvature_scalar_rho_inv;
103                let id_mat = DMatrix::<f64>::identity(dims, dims);
104                let transform_1 =
105                    id_mat.clone() - rho * (&distance_diff_s * gradient_diff_y.transpose());
106                let transform_2 = id_mat - rho * (&gradient_diff_y * distance_diff_s.transpose());
107
108                hessian_inv_approx = transform_1 * &hessian_inv_approx * transform_2
109                    + rho * (&distance_diff_s * distance_diff_s.transpose());
110            }
111
112            state_vector_x = iter_x_next;
113            state_gradient_g = iter_g_next;
114            current_energy = next_hypothetical_energy;
115        }
116
117        (current_energy, false)
118    }
119}
120
/// Returns the Euclidean (L2) norm of `vector_space`: the square root of the
/// sum of its squared components.
fn calculate_l2_norm(vector_space: &[f64]) -> f64 {
    let squared_total: f64 = vector_space
        .iter()
        .map(|&component| component * component)
        .sum();
    squared_total.sqrt()
}