// tiny_solver/optimizer/levenberg_marquardt_optimizer.rs
use log::trace;
2use std::ops::Mul;
3use std::{collections::HashMap, time::Instant};
4
5use faer_ext::IntoNalgebra;
6
7use crate::common::OptimizerOptions;
8use crate::linear;
9use crate::optimizer;
10use crate::parameter_block::ParameterBlock;
11use crate::sparse::LinearSolverType;
12use crate::sparse::SparseLinearSolver;
13
/// Default lower clamp applied to the diagonal of J^T J before damping.
const DEFAULT_MIN_DIAGONAL: f64 = 1e-6;
/// Default upper clamp applied to the diagonal of J^T J before damping.
const DEFAULT_MAX_DIAGONAL: f64 = 1e32;
/// Default initial trust-region radius; the LM damping parameter `u`
/// starts at its reciprocal (see `optimize`).
const DEFAULT_INITIAL_TRUST_REGION_RADIUS: f64 = 1e4;
17
/// Configuration for the Levenberg-Marquardt trust-region optimizer.
///
/// All fields are plain `f64` tuning knobs, so the type is cheap to copy
/// and compare; derives added accordingly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LevenbergMarquardtOptimizer {
    /// Lower clamp for diagonal entries of J^T J when building the damped system.
    min_diagonal: f64,
    /// Upper clamp for diagonal entries of J^T J when building the damped system.
    max_diagonal: f64,
    /// Initial trust-region radius; damping starts at `1 / initial_trust_region_radius`.
    initial_trust_region_radius: f64,
}
24
25impl LevenbergMarquardtOptimizer {
26 pub fn new(min_diagonal: f64, max_diagonal: f64, initial_trust_region_radius: f64) -> Self {
27 Self {
28 min_diagonal,
29 max_diagonal,
30 initial_trust_region_radius,
31 }
32 }
33}
34
35impl Default for LevenbergMarquardtOptimizer {
36 fn default() -> Self {
37 Self {
38 min_diagonal: DEFAULT_MIN_DIAGONAL,
39 max_diagonal: DEFAULT_MAX_DIAGONAL,
40 initial_trust_region_radius: DEFAULT_INITIAL_TRUST_REGION_RADIUS,
41 }
42 }
43}
44
45impl optimizer::Optimizer for LevenbergMarquardtOptimizer {
46 fn optimize(
47 &self,
48 problem: &crate::problem::Problem,
49 initial_values: &std::collections::HashMap<String, nalgebra::DVector<f64>>,
50 optimizer_option: Option<OptimizerOptions>,
51 ) -> Option<HashMap<String, nalgebra::DVector<f64>>> {
52 let mut parameter_blocks: HashMap<String, ParameterBlock> =
53 problem.initialize_parameter_blocks(initial_values);
54
55 let variable_name_to_col_idx_dict =
56 problem.get_variable_name_to_col_idx_dict(¶meter_blocks);
57 let total_variable_dimension = parameter_blocks.values().map(|p| p.tangent_size()).sum();
58
59 let opt_option = optimizer_option.unwrap_or_default();
60 let mut linear_solver: Box<dyn SparseLinearSolver> = match opt_option.linear_solver_type {
61 LinearSolverType::SparseCholesky => Box::new(linear::SparseCholeskySolver::new()),
62 LinearSolverType::SparseQR => Box::new(linear::SparseQRSolver::new()),
63 };
64
65 let mut jacobi_scaling_diagonal: Option<faer::sparse::SparseColMat<usize, f64>> = None;
69
70 let symbolic_structure = problem.build_symbolic_structure(
71 ¶meter_blocks,
72 total_variable_dimension,
73 &variable_name_to_col_idx_dict,
74 );
75
76 let mut u = 1.0 / self.initial_trust_region_radius;
78
79 let mut last_err;
80 let mut current_error = self.compute_error(&problem, ¶meter_blocks);
81 for i in 0..opt_option.max_iteration {
82 last_err = current_error;
83
84 let (residuals, mut jac) = problem.compute_residual_and_jacobian(
85 ¶meter_blocks,
86 &variable_name_to_col_idx_dict,
87 &symbolic_structure,
88 );
89
90 if i == 0 {
91 let cols = jac.shape().1;
93 let jacobi_scaling_vec: Vec<(usize, usize, f64)> = (0..cols)
94 .map(|c| {
95 let v = jac
96 .values_of_col(c)
97 .iter()
98 .map(|&i| i * i)
99 .sum::<f64>()
100 .sqrt();
101 (c, c, 1.0 / (1.0 + v))
102 })
103 .collect();
104
105 jacobi_scaling_diagonal = Some(
106 faer::sparse::SparseColMat::<usize, f64>::try_new_from_triplets(
107 cols,
108 cols,
109 &jacobi_scaling_vec,
110 )
111 .unwrap(),
112 );
113 }
114
115 jac = jac * jacobi_scaling_diagonal.as_ref().unwrap();
117
118 let jtj = jac
120 .as_ref()
121 .transpose()
122 .to_col_major()
123 .unwrap()
124 .mul(jac.as_ref());
125
126 let jtr = jac.as_ref().transpose().mul(-&residuals);
128
129 let mut jtj_regularized = jtj.clone();
131 for i in 0..total_variable_dimension {
132 jtj_regularized[(i, i)] +=
133 u * (jtj[(i, i)].max(self.min_diagonal)).min(self.max_diagonal);
134 }
135
136 let start = Instant::now();
137 if let Some(lm_step) = linear_solver.solve_jtj(&jtr, &jtj_regularized) {
138 let duration = start.elapsed();
139 let dx = jacobi_scaling_diagonal.as_ref().unwrap() * &lm_step;
140
141 trace!("Time elapsed in solve Ax=b is: {:?}", duration);
142
143 let dx_na = dx.as_ref().into_nalgebra().column(0).clone_owned();
144
145 let mut new_param_blocks = parameter_blocks.clone();
146
147 self.apply_dx2(
148 &dx_na,
149 &mut new_param_blocks,
150 &variable_name_to_col_idx_dict,
151 );
152
153 let new_residuals = problem.compute_residuals(&new_param_blocks, true);
155
156 let actual_residual_change =
159 residuals.squared_norm_l2() - new_residuals.squared_norm_l2();
160 trace!("actual_residual_change {}", actual_residual_change);
161 let linear_residual_change: faer::Mat<f64> =
162 lm_step.transpose().mul(2.0 * &jtr - &jtj * &lm_step);
163 let rho = actual_residual_change / linear_residual_change[(0, 0)];
164
165 if rho > 0.0 {
166 parameter_blocks = new_param_blocks;
168
169 let tmp = 2.0 * rho - 1.0;
171 u *= (1.0_f64 / 3.0).max(1.0 - tmp * tmp * tmp);
172 } else {
173 u *= 2.0;
175 trace!("u {}", u);
176 }
177 } else {
178 log::debug!("solve ax=b failed");
179 return None;
180 }
181
182 current_error = self.compute_error(problem, ¶meter_blocks);
183 trace!("iter:{} total err:{}", i, current_error);
184
185 if current_error < opt_option.min_error_threshold {
186 trace!("error too low");
187 break;
188 } else if current_error.is_nan() {
189 log::debug!("solve ax=b failed, current error is nan");
190 return None;
191 }
192
193 if (last_err - current_error).abs() < opt_option.min_abs_error_decrease_threshold {
194 trace!("absolute error decrease low");
195 break;
196 } else if (last_err - current_error).abs() / last_err
197 < opt_option.min_rel_error_decrease_threshold
198 {
199 trace!("relative error decrease low");
200 break;
201 }
202 }
203 let params = parameter_blocks
204 .iter()
205 .map(|(k, v)| (k.to_owned(), v.params.clone()))
206 .collect();
207 Some(params)
208 }
209}