// tiny_solver/problem.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::{Arc, Mutex};
3
4use faer::sparse::{Argsort, Pair, SparseColMat, SymbolicSparseColMat};
5use faer_ext::IntoFaer;
6use nalgebra as na;
7use rayon::prelude::*;
8
9use crate::manifold::Manifold;
10use crate::parameter_block::ParameterBlock;
11use crate::{factors, loss_functions, residual_block};
12
13type ResidualBlockId = usize;
14
15pub struct Problem {
16    pub total_residual_dimension: usize,
17    residual_id_count: usize,
18    residual_blocks: HashMap<ResidualBlockId, residual_block::ResidualBlock>,
19    pub fixed_variable_indexes: HashMap<String, HashSet<usize>>,
20    pub variable_bounds: HashMap<String, HashMap<usize, (f64, f64)>>,
21    pub variable_manifold: HashMap<String, Arc<dyn Manifold + Sync + Send>>,
22}
23impl Default for Problem {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29pub struct SymbolicStructure {
30    pattern: SymbolicSparseColMat<usize>,
31    order: Argsort<usize>,
32}
33
34type JacobianValue = f64;
35
36impl Problem {
37    pub fn new() -> Problem {
38        Problem {
39            total_residual_dimension: 0,
40            residual_id_count: 0,
41            residual_blocks: HashMap::new(),
42            fixed_variable_indexes: HashMap::new(),
43            variable_bounds: HashMap::new(),
44            variable_manifold: HashMap::new(),
45        }
46    }
47
48    pub fn build_symbolic_structure(
49        &self,
50        parameter_blocks: &HashMap<String, ParameterBlock>,
51        total_variable_dimension: usize,
52        variable_name_to_col_idx_dict: &HashMap<String, usize>,
53    ) -> SymbolicStructure {
54        let mut indices = Vec::<Pair<usize, usize>>::new();
55
56        self.residual_blocks.iter().for_each(|(_, residual_block)| {
57            let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
58            let mut count_variable_local_idx: usize = 0;
59            for var_key in &residual_block.variable_key_list {
60                if let Some(param) = parameter_blocks.get(var_key) {
61                    variable_local_idx_size_list
62                        .push((count_variable_local_idx, param.tangent_size()));
63                    count_variable_local_idx += param.tangent_size();
64                };
65            }
66            for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
67                if let Some(variable_global_idx) = variable_name_to_col_idx_dict.get(var_key) {
68                    let (_, var_size) = variable_local_idx_size_list[i];
69                    for row_idx in 0..residual_block.dim_residual {
70                        for col_idx in 0..var_size {
71                            let global_row_idx = residual_block.residual_row_start_idx + row_idx;
72                            let global_col_idx = variable_global_idx + col_idx;
73                            indices.push(Pair::new(global_row_idx, global_col_idx));
74                        }
75                    }
76                }
77            }
78        });
79        let start = std::time::Instant::now();
80        let (s, o) = SymbolicSparseColMat::try_new_from_indices(
81            self.total_residual_dimension,
82            total_variable_dimension,
83            &indices,
84        )
85        .unwrap();
86        log::trace!("Built symbolic matrix: {:?}", start.elapsed());
87        SymbolicStructure {
88            pattern: s,
89            order: o,
90        }
91    }
92
93    pub fn get_variable_name_to_col_idx_dict(
94        &self,
95        parameter_blocks: &HashMap<String, ParameterBlock>,
96    ) -> HashMap<String, usize> {
97        let mut count_col_idx = 0;
98        let mut variable_name_to_col_idx_dict = HashMap::new();
99        parameter_blocks
100            .iter()
101            .for_each(|(param_name, param_block)| {
102                variable_name_to_col_idx_dict.insert(param_name.to_owned(), count_col_idx);
103                count_col_idx += param_block.tangent_size();
104            });
105        variable_name_to_col_idx_dict
106    }
107    pub fn add_residual_block(
108        &mut self,
109        dim_residual: usize,
110        variable_key_size_list: &[&str],
111        factor: Box<dyn factors::FactorImpl + Send>,
112        loss_func: Option<Box<dyn loss_functions::Loss + Send>>,
113    ) -> ResidualBlockId {
114        self.residual_blocks.insert(
115            self.residual_id_count,
116            residual_block::ResidualBlock::new(
117                self.residual_id_count,
118                dim_residual,
119                self.total_residual_dimension,
120                variable_key_size_list,
121                factor,
122                loss_func,
123            ),
124        );
125        let block_id = self.residual_id_count;
126        self.residual_id_count += 1;
127
128        self.total_residual_dimension += dim_residual;
129
130        block_id
131    }
132    pub fn remove_residual_block(
133        &mut self,
134        block_id: ResidualBlockId,
135    ) -> Option<residual_block::ResidualBlock> {
136        if let Some(residual_block) = self.residual_blocks.remove(&block_id) {
137            self.total_residual_dimension -= residual_block.dim_residual;
138            Some(residual_block)
139        } else {
140            None
141        }
142    }
143    pub fn fix_variable(&mut self, var_to_fix: &str, idx: usize) {
144        if let Some(var_mut) = self.fixed_variable_indexes.get_mut(var_to_fix) {
145            var_mut.insert(idx);
146        } else {
147            self.fixed_variable_indexes
148                .insert(var_to_fix.to_owned(), HashSet::from([idx]));
149        }
150    }
151    pub fn unfix_variable(&mut self, var_to_unfix: &str) {
152        self.fixed_variable_indexes.remove(var_to_unfix);
153    }
154    pub fn set_variable_bounds(
155        &mut self,
156        var_to_bound: &str,
157        idx: usize,
158        lower_bound: f64,
159        upper_bound: f64,
160    ) {
161        if lower_bound > upper_bound {
162            log::error!("lower bound is larger than upper bound");
163        } else if let Some(var_mut) = self.variable_bounds.get_mut(var_to_bound) {
164            var_mut.insert(idx, (lower_bound, upper_bound));
165        } else {
166            self.variable_bounds.insert(
167                var_to_bound.to_owned(),
168                HashMap::from([(idx, (lower_bound, upper_bound))]),
169            );
170        }
171    }
172    pub fn set_variable_manifold(
173        &mut self,
174        var_name: &str,
175        manifold: Arc<dyn Manifold + Sync + Send>,
176    ) {
177        self.variable_manifold
178            .insert(var_name.to_string(), manifold);
179    }
180    pub fn remove_variable_bounds(&mut self, var_to_unbound: &str) {
181        self.variable_bounds.remove(var_to_unbound);
182    }
183    pub fn initialize_parameter_blocks(
184        &self,
185        initial_values: &HashMap<String, na::DVector<f64>>,
186    ) -> HashMap<String, ParameterBlock> {
187        let parameter_blocks: HashMap<String, ParameterBlock> = initial_values
188            .iter()
189            .map(|(k, v)| {
190                let mut p_block = ParameterBlock::from_vec(v.clone());
191                if let Some(indexes) = self.fixed_variable_indexes.get(k) {
192                    p_block.fixed_variables = indexes.clone();
193                }
194                if let Some(bounds) = self.variable_bounds.get(k) {
195                    p_block.variable_bounds = bounds.clone();
196                }
197                if let Some(manifold) = self.variable_manifold.get(k) {
198                    p_block.manifold = Some(manifold.clone())
199                }
200
201                (k.to_owned(), p_block)
202            })
203            .collect();
204        parameter_blocks
205    }
206
207    pub fn compute_residuals(
208        &self,
209        parameter_blocks: &HashMap<String, ParameterBlock>,
210        with_loss_fn: bool,
211    ) -> faer::Mat<f64> {
212        let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
213            self.total_residual_dimension,
214        )));
215        self.residual_blocks
216            .par_iter()
217            .for_each(|(_, residual_block)| {
218                self.compute_residual_impl(
219                    residual_block,
220                    parameter_blocks,
221                    &total_residual,
222                    with_loss_fn,
223                )
224            });
225        let total_residual = Arc::try_unwrap(total_residual)
226            .unwrap()
227            .into_inner()
228            .unwrap();
229
230        total_residual.view_range(.., ..).into_faer().to_owned()
231    }
232
233    pub fn compute_residual_and_jacobian(
234        &self,
235        parameter_blocks: &HashMap<String, ParameterBlock>,
236        variable_name_to_col_idx_dict: &HashMap<String, usize>,
237        symbolic_structure: &SymbolicStructure,
238    ) -> (faer::Mat<f64>, SparseColMat<usize, f64>) {
239        // multi
240        let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
241            self.total_residual_dimension,
242        )));
243
244        let jacobian_lists: Vec<JacobianValue> = self
245            .residual_blocks
246            .par_iter()
247            .map(|(_, residual_block)| {
248                self.compute_residual_and_jacobian_impl(
249                    residual_block,
250                    parameter_blocks,
251                    variable_name_to_col_idx_dict,
252                    &total_residual,
253                )
254            })
255            .flatten()
256            .collect();
257
258        let total_residual = Arc::try_unwrap(total_residual)
259            .unwrap()
260            .into_inner()
261            .unwrap();
262
263        let residual_faer = total_residual.view_range(.., ..).into_faer().to_owned();
264        let jacobian_faer = SparseColMat::new_from_argsort(
265            symbolic_structure.pattern.clone(),
266            &symbolic_structure.order,
267            jacobian_lists.as_slice(),
268        )
269        .unwrap();
270        (residual_faer, jacobian_faer)
271    }
272
273    fn compute_residual_impl(
274        &self,
275        residual_block: &crate::ResidualBlock,
276        parameter_blocks: &HashMap<String, ParameterBlock>,
277        total_residual: &Arc<Mutex<na::DVector<f64>>>,
278        with_loss_fn: bool,
279    ) {
280        let mut params = Vec::new();
281        for var_key in &residual_block.variable_key_list {
282            if let Some(param) = parameter_blocks.get(var_key) {
283                params.push(param);
284            };
285        }
286        let res = residual_block.residual(&params, with_loss_fn);
287
288        {
289            let mut total_residual = total_residual.lock().unwrap();
290            total_residual
291                .rows_mut(
292                    residual_block.residual_row_start_idx,
293                    residual_block.dim_residual,
294                )
295                .copy_from(&res);
296        }
297    }
298
299    fn compute_residual_and_jacobian_impl(
300        &self,
301        residual_block: &crate::ResidualBlock,
302        parameter_blocks: &HashMap<String, ParameterBlock>,
303        variable_name_to_col_idx_dict: &HashMap<String, usize>,
304        total_residual: &Arc<Mutex<na::DVector<f64>>>,
305    ) -> Vec<JacobianValue> {
306        let mut params = Vec::new();
307        let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
308        let mut count_variable_local_idx: usize = 0;
309        for var_key in &residual_block.variable_key_list {
310            if let Some(param) = parameter_blocks.get(var_key) {
311                params.push(param);
312                variable_local_idx_size_list.push((count_variable_local_idx, param.tangent_size()));
313                count_variable_local_idx += param.tangent_size();
314            };
315        }
316        let (res, jac) = residual_block.residual_and_jacobian(&params);
317        {
318            let mut total_residual = total_residual.lock().unwrap();
319            total_residual
320                .rows_mut(
321                    residual_block.residual_row_start_idx,
322                    residual_block.dim_residual,
323                )
324                .copy_from(&res);
325        }
326
327        let mut local_jacobian_list = Vec::new();
328
329        for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
330            if variable_name_to_col_idx_dict.contains_key(var_key) {
331                let (variable_local_idx, var_size) = variable_local_idx_size_list[i];
332                let variable_jac = jac.view((0, variable_local_idx), (jac.shape().0, var_size));
333                for row_idx in 0..jac.shape().0 {
334                    for col_idx in 0..var_size {
335                        local_jacobian_list.push(variable_jac[(row_idx, col_idx)]);
336                    }
337                }
338            } else {
339                panic!(
340                    "Missing key {} in variable-to-column-index mapping",
341                    var_key
342                );
343            }
344        }
345
346        local_jacobian_list
347    }
348}