tiny_solver/optimizer/levenberg_marquardt_optimizer.rs

use log::trace;
use std::ops::Mul;
use std::{collections::HashMap, time::Instant};

use faer_ext::IntoNalgebra;

use crate::common::OptimizerOptions;
use crate::linear;
use crate::optimizer;
use crate::parameter_block::ParameterBlock;
use crate::sparse::LinearSolverType;
use crate::sparse::SparseLinearSolver;

const DEFAULT_MIN_DIAGONAL: f64 = 1e-6;
const DEFAULT_MAX_DIAGONAL: f64 = 1e32;
const DEFAULT_INITIAL_TRUST_REGION_RADIUS: f64 = 1e4;

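// These defaults mirror Ceres Solver's trust-region defaults: the min/max
// diagonal values bound the J^T*J diagonal entries used for damping, and the
// damping parameter u starts at 1 / DEFAULT_INITIAL_TRUST_REGION_RADIUS.

/// Levenberg-Marquardt optimizer with a Ceres-style trust-region loop. Each
/// iteration solves the damped, Jacobi-scaled normal equations
/// `(J^T J + u * clamp(diag(J^T J))) * dx = -J^T r` and accepts or rejects
/// the step based on the gain ratio `rho`.
///
/// A minimal usage sketch (assuming a constructed `Problem`, an
/// `initial_values` map, and the `Optimizer` trait in scope):
///
/// ```ignore
/// let lm = LevenbergMarquardtOptimizer::default();
/// let solution = lm.optimize(&problem, &initial_values, None);
/// ```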
#[derive(Debug)]
pub struct LevenbergMarquardtOptimizer {
    min_diagonal: f64,
    max_diagonal: f64,
    initial_trust_region_radius: f64,
}

impl LevenbergMarquardtOptimizer {
    pub fn new(min_diagonal: f64, max_diagonal: f64, initial_trust_region_radius: f64) -> Self {
        Self {
            min_diagonal,
            max_diagonal,
            initial_trust_region_radius,
        }
    }
}

impl Default for LevenbergMarquardtOptimizer {
    fn default() -> Self {
        Self {
            min_diagonal: DEFAULT_MIN_DIAGONAL,
            max_diagonal: DEFAULT_MAX_DIAGONAL,
            initial_trust_region_radius: DEFAULT_INITIAL_TRUST_REGION_RADIUS,
        }
    }
}

impl optimizer::Optimizer for LevenbergMarquardtOptimizer {
    fn optimize(
        &self,
        problem: &crate::problem::Problem,
        initial_values: &std::collections::HashMap<String, nalgebra::DVector<f64>>,
        optimizer_option: Option<OptimizerOptions>,
    ) -> Option<HashMap<String, nalgebra::DVector<f64>>> {
        let mut parameter_blocks: HashMap<String, ParameterBlock> =
            problem.initialize_parameter_blocks(initial_values);

        let variable_name_to_col_idx_dict =
            problem.get_variable_name_to_col_idx_dict(&parameter_blocks);
        let total_variable_dimension = parameter_blocks.values().map(|p| p.tangent_size()).sum();

        let opt_option = optimizer_option.unwrap_or_default();
        let mut linear_solver: Box<dyn SparseLinearSolver> = match opt_option.linear_solver_type {
            LinearSolverType::SparseCholesky => Box::new(linear::SparseCholeskySolver::new()),
            LinearSolverType::SparseQR => Box::new(linear::SparseQRSolver::new()),
        };

        // Jacobi scaling: on the first iteration we build a diagonal matrix from
        // the column norms of the Jacobian and use it to normalize the columns
        // of J. Its shape is (total_variable_dimension, total_variable_dimension).
        // Separately, LM damping means that rather than solving A * dx = b for
        // dx, we solve (A + u * diag(A)) * dx = b; see the regularized solve
        // below.
        let mut jacobi_scaling_diagonal: Option<faer::sparse::SparseColMat<usize, f64>> = None;

        let symbolic_structure = problem.build_symbolic_structure(
            &parameter_blocks,
            total_variable_dimension,
            &variable_name_to_col_idx_dict,
        );

        // Damping parameter (aka lambda, the Marquardt parameter), initialized
        // as the inverse of the initial trust-region radius.
        let mut u = 1.0 / self.initial_trust_region_radius;

        let mut last_err;
        let mut current_error = self.compute_error(problem, &parameter_blocks);
        for i in 0..opt_option.max_iteration {
            last_err = current_error;

            let (residuals, mut jac) = problem.compute_residual_and_jacobian(
                &parameter_blocks,
                &variable_name_to_col_idx_dict,
                &symbolic_structure,
            );

            if i == 0 {
                // On the first iteration, build the Jacobi scaling matrix: each
                // column c of J is scaled by 1 / (1 + ||J_c||), where ||J_c|| is
                // the column's L2 norm; the +1 guards against division by zero
                // for all-zero columns.
                let cols = jac.shape().1;
                let jacobi_scaling_vec: Vec<(usize, usize, f64)> = (0..cols)
                    .map(|c| {
                        let norm = jac
                            .values_of_col(c)
                            .iter()
                            .map(|&x| x * x)
                            .sum::<f64>()
                            .sqrt();
                        (c, c, 1.0 / (1.0 + norm))
                    })
                    .collect();

                jacobi_scaling_diagonal = Some(
                    faer::sparse::SparseColMat::<usize, f64>::try_new_from_triplets(
                        cols,
                        cols,
                        &jacobi_scaling_vec,
                    )
                    .unwrap(),
                );
            }
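            // From here on we work in the scaled variables: J' = J * S, where S
            // is the Jacobi scaling diagonal. The solved step dx' is mapped back
            // to the original variables as dx = S * dx' (see `dx` below).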

            // Scale the columns of the current Jacobian by the diagonal matrix.
            jac = jac * jacobi_scaling_diagonal.as_ref().unwrap();

            // J^T * J = Matrix of shape (total_variable_dimension, total_variable_dimension)
            let jtj = jac
                .as_ref()
                .transpose()
                .to_col_major()
                .unwrap()
                .mul(jac.as_ref());

            // J^T * -r = Matrix of shape (total_variable_dimension, 1)
            let jtr = jac.as_ref().transpose().mul(-&residuals);

            // Add the damping term: clamp each diagonal entry of jtj to
            // [min_diagonal, max_diagonal], scale it by u, and add it back to
            // the diagonal.
            let mut jtj_regularized = jtj.clone();
            for i in 0..total_variable_dimension {
                jtj_regularized[(i, i)] +=
                    u * (jtj[(i, i)].max(self.min_diagonal)).min(self.max_diagonal);
            }
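            // The damped normal equations solved below are
            //   (J'^T J' + u * clamp(diag(J'^T J'))) * dx' = -J'^T r,
            // which is the Levenberg-Marquardt step in the scaled variables.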

            let start = Instant::now();
            if let Some(lm_step) = linear_solver.solve_jtj(&jtr, &jtj_regularized) {
                let duration = start.elapsed();
                let dx = jacobi_scaling_diagonal.as_ref().unwrap() * &lm_step;

                trace!("Time elapsed in solve Ax=b is: {:?}", duration);

                let dx_na = dx.as_ref().into_nalgebra().column(0).clone_owned();

                let mut new_param_blocks = parameter_blocks.clone();

                self.apply_dx2(
                    &dx_na,
                    &mut new_param_blocks,
                    &variable_name_to_col_idx_dict,
                );

                // Compute residuals of (x + dx)
                let new_residuals = problem.compute_residuals(&new_param_blocks, true);

                // rho is the ratio between the actual reduction in error and the
                // reduction in error if the problem were linear.
                let actual_residual_change =
                    residuals.squared_norm_l2() - new_residuals.squared_norm_l2();
                trace!("actual_residual_change {}", actual_residual_change);
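                // For the linear model L(dx) = ||r + J*dx||^2, the predicted
                // reduction is
                //   L(0) - L(dx) = -2*r^T*J*dx - dx^T*J^T*J*dx
                //                = dx^T * (2*jtr - jtj*dx),
                // evaluated here as a 1x1 matrix.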
                let linear_residual_change: faer::Mat<f64> =
                    lm_step.transpose().mul(2.0 * &jtr - &jtj * &lm_step);
                let rho = actual_residual_change / linear_residual_change[(0, 0)];

                if rho > 0.0 {
                    // The linear model appears to be fitting, so accept (x + dx) as the new x.
                    parameter_blocks = new_param_blocks;

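                    // Trust-region update used by Ceres (after Nielsen): scale u
                    // by max(1/3, 1 - (2*rho - 1)^3). For good steps (rho near 1)
                    // u shrinks and the trust region grows; for barely-accepted
                    // steps (rho < 0.5) u grows and the trust region shrinks.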
                    let tmp = 2.0 * rho - 1.0;
                    u *= (1.0_f64 / 3.0).max(1.0 - tmp * tmp * tmp);
                } else {
                    // The step increased the error, so reject it: keep the current
                    // parameters, shrink the trust region by increasing u, and
                    // retry on the next iteration.
                    u *= 2.0;
                    trace!("u {}", u);
                    continue;
                }
            } else {
                log::debug!("solve Ax=b failed");
                return None;
            }

            current_error = self.compute_error(problem, &parameter_blocks);
            trace!("iter:{} total err:{}", i, current_error);

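            // Termination tests: stop on a sufficiently small error, bail out on
            // NaN, and otherwise stop when the absolute or relative error
            // decrease falls below its threshold.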
            if current_error < opt_option.min_error_threshold {
                trace!("error below min_error_threshold");
                break;
            } else if current_error.is_nan() {
                log::debug!("optimization diverged, current error is NaN");
                return None;
            }

            if (last_err - current_error).abs() < opt_option.min_abs_error_decrease_threshold {
                trace!("absolute error decrease below threshold");
                break;
            } else if (last_err - current_error).abs() / last_err
                < opt_option.min_rel_error_decrease_threshold
            {
                trace!("relative error decrease below threshold");
                break;
            }
        }
        let params = parameter_blocks
            .iter()
            .map(|(k, v)| (k.to_owned(), v.params.clone()))
            .collect();
        Some(params)
    }
}