tiny_solver/
residual_block.rs

1use nalgebra as na;
2use rayon::prelude::*;
3
4use crate::corrector::Corrector;
5use crate::factors::FactorImpl;
6use crate::loss_functions::Loss;
7use crate::parameter_block::ParameterBlock;
8
9pub struct ResidualBlock {
10    pub residual_block_id: usize,
11    pub dim_residual: usize,
12    pub residual_row_start_idx: usize,
13    pub variable_key_list: Vec<String>,
14    pub factor: Box<dyn FactorImpl + Send>,
15    pub loss_func: Option<Box<dyn Loss + Send>>,
16}
17impl ResidualBlock {
18    pub fn new(
19        residual_block_id: usize,
20        dim_residual: usize,
21        residual_row_start_idx: usize,
22        variable_key_size_list: &[&str],
23        factor: Box<dyn FactorImpl + Send>,
24        loss_func: Option<Box<dyn Loss + Send>>,
25    ) -> Self {
26        ResidualBlock {
27            residual_block_id,
28            dim_residual,
29            residual_row_start_idx,
30            variable_key_list: variable_key_size_list
31                .iter()
32                .map(|s| s.to_string())
33                .collect(),
34            factor,
35            loss_func,
36        }
37    }
38
39    pub fn residual(&self, params: &[&ParameterBlock], with_loss_fn: bool) -> na::DVector<f64> {
40        let param_vec: Vec<_> = params.iter().map(|p| p.params.clone()).collect();
41        let mut residual = self.factor.residual_func_f64(&param_vec);
42        let squared_norm = residual.norm_squared();
43        if with_loss_fn {
44            if let Some(loss_func) = self.loss_func.as_ref() {
45                let rho = loss_func.evaluate(squared_norm);
46                // let cost = 0.5 * rho[0];
47                let corrector = Corrector::new(squared_norm, &rho);
48                corrector.correct_residuals(&mut residual);
49            }
50        } else {
51            // let cost = 0.5 * squared_norm;
52        }
53        residual
54    }
55    pub fn residual_and_jacobian(
56        &self,
57        params: &[&ParameterBlock],
58    ) -> (na::DVector<f64>, na::DMatrix<f64>) {
59        let variable_rows: Vec<usize> = params.iter().map(|x| x.tangent_size()).collect();
60        let dim_variable = variable_rows.iter().sum::<usize>();
61        let variable_row_idx_vec = get_variable_rows(&variable_rows);
62        let indentity_mat = na::DMatrix::<f64>::identity(dim_variable, dim_variable);
63
64        // ambient size
65        let params_plus_tangent_dual: Vec<na::DVector<num_dual::DualDVec64>> = params
66            .par_iter()
67            .enumerate()
68            .map(|(param_idx, param)| {
69                let zeros_with_dual = na::DVector::from_row_iterator(
70                    param.tangent_size(),
71                    (0..param.tangent_size()).map(|j| {
72                        num_dual::DualDVec64::new(
73                            0.0,
74                            num_dual::Derivative::some(na::DVector::from(
75                                indentity_mat.column(variable_row_idx_vec[param_idx][j]),
76                            )),
77                        )
78                    }),
79                );
80                param.plus_dual(zeros_with_dual.as_view())
81            })
82            .collect();
83
84        // tangent size
85        let residual_with_jacobian = self.factor.residual_func_dual(&params_plus_tangent_dual);
86        let mut residual = residual_with_jacobian.map(|x| x.re);
87        let jacobian = residual_with_jacobian
88            .map(|x| x.eps.unwrap_generic(na::Dyn(dim_variable), na::Const::<1>));
89        let mut jacobian =
90            na::DMatrix::<f64>::from_fn(residual_with_jacobian.nrows(), dim_variable, |r, c| {
91                jacobian[r][c]
92            });
93        let squared_norm = residual.norm_squared();
94        if let Some(loss_func) = self.loss_func.as_ref() {
95            let rho = loss_func.evaluate(squared_norm);
96            // let cost = 0.5 * rho[0];
97            let corrector = Corrector::new(squared_norm, &rho);
98            corrector.correct_jacobian(&residual, &mut jacobian);
99            corrector.correct_residuals(&mut residual);
100        } else {
101            // let cost = 0.5 * squared_norm;
102        }
103        (residual, jacobian)
104    }
105}
106
/// Splits the index range `0..sum(variable_rows)` into consecutive runs.
///
/// Entry `i` of the result holds `variable_rows[i]` consecutive indices,
/// starting right after the last index of entry `i - 1` (or at `0` for the
/// first entry). A size of `0` yields an empty inner vector.
fn get_variable_rows(variable_rows: &[usize]) -> Vec<Vec<usize>> {
    variable_rows
        .iter()
        .scan(0usize, |offset, &width| {
            let begin = *offset;
            *offset = begin + width;
            Some((begin..*offset).collect::<Vec<usize>>())
        })
        .collect()
}