tiny_solver/optimizer/
gauss_newton_optimizer.rs

1use log::trace;
2use std::{collections::HashMap, time::Instant};
3
4use faer_ext::IntoNalgebra;
5
6use crate::common::OptimizerOptions;
7use crate::linear;
8use crate::optimizer;
9use crate::parameter_block::ParameterBlock;
10use crate::sparse::LinearSolverType;
11use crate::sparse::SparseLinearSolver;
12
/// Optimizer type with no fields; it holds no state between calls.
#[derive(Debug)]
pub struct GaussNewtonOptimizer {}

impl GaussNewtonOptimizer {
    /// Creates a new, field-less optimizer instance.
    pub fn new() -> Self {
        GaussNewtonOptimizer {}
    }
}
20impl Default for GaussNewtonOptimizer {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl optimizer::Optimizer for GaussNewtonOptimizer {
27    fn optimize(
28        &self,
29        problem: &crate::problem::Problem,
30        initial_values: &std::collections::HashMap<String, nalgebra::DVector<f64>>,
31        optimizer_option: Option<OptimizerOptions>,
32    ) -> Option<HashMap<String, nalgebra::DVector<f64>>> {
33        let mut parameter_blocks: HashMap<String, ParameterBlock> =
34            problem.initialize_parameter_blocks(initial_values);
35
36        let variable_name_to_col_idx_dict =
37            problem.get_variable_name_to_col_idx_dict(&parameter_blocks);
38        let total_variable_dimension = parameter_blocks.values().map(|p| p.tangent_size()).sum();
39
40        let opt_option = optimizer_option.unwrap_or_default();
41        let mut linear_solver: Box<dyn SparseLinearSolver> = match opt_option.linear_solver_type {
42            LinearSolverType::SparseCholesky => Box::new(linear::SparseCholeskySolver::new()),
43            LinearSolverType::SparseQR => Box::new(linear::SparseQRSolver::new()),
44        };
45
46        let symbolic_structure = problem.build_symbolic_structure(
47            &parameter_blocks,
48            total_variable_dimension,
49            &variable_name_to_col_idx_dict,
50        );
51
52        let mut last_err;
53        let mut current_error = self.compute_error(problem, &parameter_blocks);
54
55        for i in 0..opt_option.max_iteration {
56            last_err = current_error;
57            let mut start = Instant::now();
58
59            let (residuals, jac) = problem.compute_residual_and_jacobian(
60                &parameter_blocks,
61                &variable_name_to_col_idx_dict,
62                &symbolic_structure,
63            );
64            let residual_and_jacobian_duration = start.elapsed();
65
66            start = Instant::now();
67            let solving_duration;
68            if let Some(dx) = linear_solver.solve(&residuals, &jac) {
69                solving_duration = start.elapsed();
70                let dx_na = dx.as_ref().into_nalgebra().column(0).clone_owned();
71                self.apply_dx2(
72                    &dx_na,
73                    &mut parameter_blocks,
74                    &variable_name_to_col_idx_dict,
75                );
76            } else {
77                log::debug!("solve ax=b failed");
78                return None;
79            }
80
81            current_error = self.compute_error(problem, &parameter_blocks);
82            trace!(
83                "iter:{}, total err:{}, residual + jacobian duration: {:?}, solving duration: {:?}",
84                i,
85                current_error,
86                residual_and_jacobian_duration,
87                solving_duration
88            );
89
90            if current_error < opt_option.min_error_threshold {
91                trace!("error too low");
92                break;
93            } else if current_error.is_nan() {
94                log::debug!("solve ax=b failed, current error is nan");
95                return None;
96            }
97
98            if (last_err - current_error).abs() < opt_option.min_abs_error_decrease_threshold {
99                trace!("absolute error decrease low");
100                break;
101            } else if (last_err - current_error).abs() / last_err
102                < opt_option.min_rel_error_decrease_threshold
103            {
104                trace!("relative error decrease low");
105                break;
106            }
107        }
108        let params = parameter_blocks
109            .iter()
110            .map(|(k, v)| (k.to_owned(), v.params.clone()))
111            .collect();
112        Some(params)
113    }
114}