tiny_solver/optimizer/gauss_newton_optimizer.rs

1use log::trace;
2use std::{collections::HashMap, time::Instant};
3
4use faer_ext::IntoNalgebra;
5
6use crate::common::OptimizerOptions;
7use crate::linear;
8use crate::optimizer;
9use crate::parameter_block::ParameterBlock;
10use crate::sparse::LinearSolverType;
11use crate::sparse::SparseLinearSolver;
12
/// Gauss-Newton optimizer.
///
/// The type carries no state: every per-run setting is supplied to `optimize` through
/// `OptimizerOptions`. Because it has no fields, it is trivially copyable and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GaussNewtonOptimizer {}
15impl GaussNewtonOptimizer {
16    pub fn new() -> Self {
17        Self {}
18    }
19}
20impl Default for GaussNewtonOptimizer {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl optimizer::Optimizer for GaussNewtonOptimizer {
27    fn optimize(
28        &self,
29        problem: &crate::problem::Problem,
30        initial_values: &std::collections::HashMap<String, nalgebra::DVector<f64>>,
31        optimizer_option: Option<OptimizerOptions>,
32    ) -> Option<HashMap<String, nalgebra::DVector<f64>>> {
33        let mut parameter_blocks: HashMap<String, ParameterBlock> =
34            problem.initialize_parameter_blocks(initial_values);
35
36        let variable_name_to_col_idx_dict =
37            problem.get_variable_name_to_col_idx_dict(&parameter_blocks);
38        let total_variable_dimension = parameter_blocks
39            .values()
40            .map(|p| {
41                if p.manifold.is_some() {
42                    p.tangent_size()
43                } else {
44                    p.tangent_size() - p.fixed_variables.len()
45                }
46            })
47            .sum();
48
49        let opt_option = optimizer_option.unwrap_or_default();
50        let mut linear_solver: Box<dyn SparseLinearSolver> = match opt_option.linear_solver_type {
51            LinearSolverType::SparseCholesky => Box::new(linear::SparseCholeskySolver::new()),
52            LinearSolverType::SparseQR => Box::new(linear::SparseQRSolver::new()),
53        };
54
55        let symbolic_structure = problem.build_symbolic_structure(
56            &parameter_blocks,
57            total_variable_dimension,
58            &variable_name_to_col_idx_dict,
59        );
60
61        let mut last_err;
62        let mut current_error = self.compute_error(problem, &parameter_blocks);
63
64        for i in 0..opt_option.max_iteration {
65            last_err = current_error;
66            let mut start = Instant::now();
67
68            let (residuals, jac) = problem.compute_residual_and_jacobian(
69                &parameter_blocks,
70                &variable_name_to_col_idx_dict,
71                &symbolic_structure,
72            );
73            let residual_and_jacobian_duration = start.elapsed();
74
75            start = Instant::now();
76            let solving_duration;
77            if let Some(dx) = linear_solver.solve(&residuals, &jac) {
78                solving_duration = start.elapsed();
79                let dx_na = dx.as_ref().into_nalgebra().column(0).clone_owned();
80                self.apply_dx2(
81                    &dx_na,
82                    &mut parameter_blocks,
83                    &variable_name_to_col_idx_dict,
84                );
85            } else {
86                log::debug!("solve ax=b failed");
87                return None;
88            }
89
90            current_error = self.compute_error(problem, &parameter_blocks);
91            trace!(
92                "iter:{}, total err:{}, residual + jacobian duration: {:?}, solving duration: {:?}",
93                i, current_error, residual_and_jacobian_duration, solving_duration
94            );
95
96            if current_error < opt_option.min_error_threshold {
97                trace!("error too low");
98                break;
99            } else if current_error.is_nan() {
100                log::debug!("solve ax=b failed, current error is nan");
101                return None;
102            }
103
104            if (last_err - current_error).abs() < opt_option.min_abs_error_decrease_threshold {
105                trace!("absolute error decrease low");
106                break;
107            } else if last_err > 0.0
108                && (last_err - current_error).abs() / last_err
109                    < opt_option.min_rel_error_decrease_threshold
110            {
111                trace!("relative error decrease low");
112                break;
113            }
114        }
115        let params = parameter_blocks
116            .iter()
117            .map(|(k, v)| (k.to_owned(), v.params.clone()))
118            .collect();
119        Some(params)
120    }
121}