tiny_solver/
problem.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::{Arc, Mutex};
3
4use faer::sparse::{Argsort, Pair, SparseColMat, SymbolicSparseColMat};
5use faer_ext::IntoFaer;
6use nalgebra as na;
7use rayon::prelude::*;
8
9use crate::manifold::Manifold;
10use crate::parameter_block::ParameterBlock;
11use crate::{factors, loss_functions, residual_block};
12
/// Identifier returned by `Problem::add_residual_block` and accepted by
/// `Problem::remove_residual_block`; ids are handed out in increasing order.
type ResidualBlockId = usize;
14
15pub struct Problem {
16    pub total_residual_dimension: usize,
17    residual_id_count: usize,
18    residual_blocks: HashMap<ResidualBlockId, residual_block::ResidualBlock>,
19    pub fixed_variable_indexes: HashMap<String, HashSet<usize>>,
20    pub variable_bounds: HashMap<String, HashMap<usize, (f64, f64)>>,
21    pub variable_manifold: HashMap<String, Arc<dyn Manifold + Sync + Send>>,
22}
23impl Default for Problem {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29pub struct SymbolicStructure {
30    pattern: SymbolicSparseColMat<usize>,
31    order: Argsort<usize>,
32}
33
/// Scalar type of a single Jacobian entry.
type JacobianValue = f64;
35
36impl Problem {
37    pub fn new() -> Problem {
38        Problem {
39            total_residual_dimension: 0,
40            residual_id_count: 0,
41            residual_blocks: HashMap::new(),
42            fixed_variable_indexes: HashMap::new(),
43            variable_bounds: HashMap::new(),
44            variable_manifold: HashMap::new(),
45        }
46    }
47
48    pub fn build_symbolic_structure(
49        &self,
50        parameter_blocks: &HashMap<String, ParameterBlock>,
51        total_variable_dimension: usize,
52        variable_name_to_col_idx_dict: &HashMap<String, usize>,
53    ) -> SymbolicStructure {
54        let mut indices = Vec::<Pair<usize, usize>>::new();
55
56        self.residual_blocks.iter().for_each(|(_, residual_block)| {
57            let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
58            let mut count_variable_local_idx: usize = 0;
59            for var_key in &residual_block.variable_key_list {
60                if let Some(param) = parameter_blocks.get(var_key) {
61                    variable_local_idx_size_list
62                        .push((count_variable_local_idx, param.tangent_size()));
63                    count_variable_local_idx += param.tangent_size();
64                };
65            }
66            for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
67                if let Some(variable_global_idx) = variable_name_to_col_idx_dict.get(var_key) {
68                    let (_, var_size) = variable_local_idx_size_list[i];
69                    for row_idx in 0..residual_block.dim_residual {
70                        let mut current_var_col_offset = 0;
71                        for col_idx in 0..var_size {
72                            if let Some(param) = parameter_blocks.get(var_key)
73                                && param.manifold.is_none()
74                                && param.fixed_variables.contains(&col_idx)
75                            {
76                                continue;
77                            }
78                            let global_row_idx = residual_block.residual_row_start_idx + row_idx;
79                            let global_col_idx = variable_global_idx + current_var_col_offset;
80                            indices.push(Pair::new(global_row_idx, global_col_idx));
81                            current_var_col_offset += 1;
82                        }
83                    }
84                }
85            }
86        });
87        let start = std::time::Instant::now();
88        let (s, o) = SymbolicSparseColMat::try_new_from_indices(
89            self.total_residual_dimension,
90            total_variable_dimension,
91            &indices,
92        )
93        .unwrap();
94        log::trace!("Built symbolic matrix: {:?}", start.elapsed());
95        SymbolicStructure {
96            pattern: s,
97            order: o,
98        }
99    }
100
101    pub fn get_variable_name_to_col_idx_dict(
102        &self,
103        parameter_blocks: &HashMap<String, ParameterBlock>,
104    ) -> HashMap<String, usize> {
105        let mut count_col_idx = 0;
106        let mut variable_name_to_col_idx_dict = HashMap::new();
107        parameter_blocks
108            .iter()
109            .for_each(|(param_name, param_block)| {
110                variable_name_to_col_idx_dict.insert(param_name.to_owned(), count_col_idx);
111                let effective_size = if param_block.manifold.is_some() {
112                    param_block.tangent_size()
113                } else {
114                    param_block.tangent_size() - param_block.fixed_variables.len()
115                };
116                count_col_idx += effective_size;
117            });
118        variable_name_to_col_idx_dict
119    }
120    pub fn add_residual_block(
121        &mut self,
122        dim_residual: usize,
123        variable_key_size_list: &[&str],
124        factor: Box<dyn factors::FactorImpl + Send>,
125        loss_func: Option<Box<dyn loss_functions::Loss + Send>>,
126    ) -> ResidualBlockId {
127        self.residual_blocks.insert(
128            self.residual_id_count,
129            residual_block::ResidualBlock::new(
130                self.residual_id_count,
131                dim_residual,
132                self.total_residual_dimension,
133                variable_key_size_list,
134                factor,
135                loss_func,
136            ),
137        );
138        let block_id = self.residual_id_count;
139        self.residual_id_count += 1;
140
141        self.total_residual_dimension += dim_residual;
142
143        block_id
144    }
145    pub fn remove_residual_block(
146        &mut self,
147        block_id: ResidualBlockId,
148    ) -> Option<residual_block::ResidualBlock> {
149        if let Some(residual_block) = self.residual_blocks.remove(&block_id) {
150            self.total_residual_dimension -= residual_block.dim_residual;
151            Some(residual_block)
152        } else {
153            None
154        }
155    }
156    pub fn fix_variable(&mut self, var_to_fix: &str, idx: usize) {
157        if let Some(var_mut) = self.fixed_variable_indexes.get_mut(var_to_fix) {
158            var_mut.insert(idx);
159        } else {
160            self.fixed_variable_indexes
161                .insert(var_to_fix.to_owned(), HashSet::from([idx]));
162        }
163    }
164    pub fn unfix_variable(&mut self, var_to_unfix: &str) {
165        self.fixed_variable_indexes.remove(var_to_unfix);
166    }
167    pub fn set_variable_bounds(
168        &mut self,
169        var_to_bound: &str,
170        idx: usize,
171        lower_bound: f64,
172        upper_bound: f64,
173    ) {
174        if lower_bound > upper_bound {
175            log::error!("lower bound is larger than upper bound");
176        } else if let Some(var_mut) = self.variable_bounds.get_mut(var_to_bound) {
177            var_mut.insert(idx, (lower_bound, upper_bound));
178        } else {
179            self.variable_bounds.insert(
180                var_to_bound.to_owned(),
181                HashMap::from([(idx, (lower_bound, upper_bound))]),
182            );
183        }
184    }
185    pub fn set_variable_manifold(
186        &mut self,
187        var_name: &str,
188        manifold: Arc<dyn Manifold + Sync + Send>,
189    ) {
190        self.variable_manifold
191            .insert(var_name.to_string(), manifold);
192    }
193    pub fn remove_variable_bounds(&mut self, var_to_unbound: &str) {
194        self.variable_bounds.remove(var_to_unbound);
195    }
196    pub fn initialize_parameter_blocks(
197        &self,
198        initial_values: &HashMap<String, na::DVector<f64>>,
199    ) -> HashMap<String, ParameterBlock> {
200        let parameter_blocks: HashMap<String, ParameterBlock> = initial_values
201            .iter()
202            .map(|(k, v)| {
203                let mut p_block = ParameterBlock::from_vec(v.clone());
204                if let Some(indexes) = self.fixed_variable_indexes.get(k) {
205                    p_block.fixed_variables = indexes.clone();
206                }
207                if let Some(bounds) = self.variable_bounds.get(k) {
208                    p_block.variable_bounds = bounds.clone();
209                }
210                if let Some(manifold) = self.variable_manifold.get(k) {
211                    p_block.manifold = Some(manifold.clone())
212                }
213
214                (k.to_owned(), p_block)
215            })
216            .collect();
217        parameter_blocks
218    }
219
220    pub fn compute_residuals(
221        &self,
222        parameter_blocks: &HashMap<String, ParameterBlock>,
223        with_loss_fn: bool,
224    ) -> faer::Mat<f64> {
225        let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
226            self.total_residual_dimension,
227        )));
228        self.residual_blocks
229            .par_iter()
230            .for_each(|(_, residual_block)| {
231                self.compute_residual_impl(
232                    residual_block,
233                    parameter_blocks,
234                    &total_residual,
235                    with_loss_fn,
236                )
237            });
238        let total_residual = Arc::try_unwrap(total_residual)
239            .unwrap()
240            .into_inner()
241            .unwrap();
242
243        total_residual.view_range(.., ..).into_faer().to_owned()
244    }
245
246    pub fn compute_residual_and_jacobian(
247        &self,
248        parameter_blocks: &HashMap<String, ParameterBlock>,
249        variable_name_to_col_idx_dict: &HashMap<String, usize>,
250        symbolic_structure: &SymbolicStructure,
251    ) -> (faer::Mat<f64>, SparseColMat<usize, f64>) {
252        // multi
253        let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
254            self.total_residual_dimension,
255        )));
256
257        let jacobian_lists: Vec<JacobianValue> = self
258            .residual_blocks
259            .par_iter()
260            .map(|(_, residual_block)| {
261                self.compute_residual_and_jacobian_impl(
262                    residual_block,
263                    parameter_blocks,
264                    variable_name_to_col_idx_dict,
265                    &total_residual,
266                )
267            })
268            .flatten()
269            .collect();
270
271        let total_residual = Arc::try_unwrap(total_residual)
272            .unwrap()
273            .into_inner()
274            .unwrap();
275
276        let residual_faer = total_residual.view_range(.., ..).into_faer().to_owned();
277        let jacobian_faer = SparseColMat::new_from_argsort(
278            symbolic_structure.pattern.clone(),
279            &symbolic_structure.order,
280            jacobian_lists.as_slice(),
281        )
282        .unwrap();
283        (residual_faer, jacobian_faer)
284    }
285
286    fn compute_residual_impl(
287        &self,
288        residual_block: &crate::ResidualBlock,
289        parameter_blocks: &HashMap<String, ParameterBlock>,
290        total_residual: &Arc<Mutex<na::DVector<f64>>>,
291        with_loss_fn: bool,
292    ) {
293        let mut params = Vec::new();
294        for var_key in &residual_block.variable_key_list {
295            if let Some(param) = parameter_blocks.get(var_key) {
296                params.push(param);
297            };
298        }
299        let res = residual_block.residual(&params, with_loss_fn);
300
301        {
302            let mut total_residual = total_residual.lock().unwrap();
303            total_residual
304                .rows_mut(
305                    residual_block.residual_row_start_idx,
306                    residual_block.dim_residual,
307                )
308                .copy_from(&res);
309        }
310    }
311
312    fn compute_residual_and_jacobian_impl(
313        &self,
314        residual_block: &crate::ResidualBlock,
315        parameter_blocks: &HashMap<String, ParameterBlock>,
316        variable_name_to_col_idx_dict: &HashMap<String, usize>,
317        total_residual: &Arc<Mutex<na::DVector<f64>>>,
318    ) -> Vec<JacobianValue> {
319        let mut params = Vec::new();
320        let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
321        let mut count_variable_local_idx: usize = 0;
322        for var_key in &residual_block.variable_key_list {
323            if let Some(param) = parameter_blocks.get(var_key) {
324                params.push(param);
325                variable_local_idx_size_list.push((count_variable_local_idx, param.tangent_size()));
326                count_variable_local_idx += param.tangent_size();
327            };
328        }
329        let (res, jac) = residual_block.residual_and_jacobian(&params);
330        {
331            let mut total_residual = total_residual.lock().unwrap();
332            total_residual
333                .rows_mut(
334                    residual_block.residual_row_start_idx,
335                    residual_block.dim_residual,
336                )
337                .copy_from(&res);
338        }
339
340        let mut local_jacobian_list = Vec::new();
341
342        for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
343            if variable_name_to_col_idx_dict.contains_key(var_key) {
344                let (variable_local_idx, var_size) = variable_local_idx_size_list[i];
345                let variable_jac = jac.view((0, variable_local_idx), (jac.shape().0, var_size));
346                let param = &params[i];
347                for row_idx in 0..jac.shape().0 {
348                    for col_idx in 0..var_size {
349                        if param.manifold.is_none() && param.fixed_variables.contains(&col_idx) {
350                            continue;
351                        }
352                        let j_value = variable_jac[(row_idx, col_idx)];
353                        if j_value.is_finite() {
354                            local_jacobian_list.push(j_value);
355                        } else {
356                            log::warn!(
357                                "Non-finite Jacobian value detected at residual block {}, variable {}, row {}, col {}. Setting to 0.0",
358                                residual_block.residual_block_id,
359                                var_key,
360                                row_idx,
361                                col_idx
362                            );
363                            local_jacobian_list.push(0.0);
364                        }
365                    }
366                }
367            } else {
368                panic!(
369                    "Missing key {} in variable-to-column-index mapping",
370                    var_key
371                );
372            }
373        }
374
375        local_jacobian_list
376    }
377}