1use std::collections::{HashMap, HashSet};
2use std::sync::{Arc, Mutex};
3
4use faer::sparse::{Argsort, Pair, SparseColMat, SymbolicSparseColMat};
5use faer_ext::IntoFaer;
6use nalgebra as na;
7use rayon::prelude::*;
8
9use crate::manifold::Manifold;
10use crate::parameter_block::ParameterBlock;
11use crate::{factors, loss_functions, residual_block};
12
/// Handle returned by `Problem::add_residual_block` and accepted by
/// `Problem::remove_residual_block`; ids are handed out sequentially.
type ResidualBlockId = usize;
14
15pub struct Problem {
16 pub total_residual_dimension: usize,
17 residual_id_count: usize,
18 residual_blocks: HashMap<ResidualBlockId, residual_block::ResidualBlock>,
19 pub fixed_variable_indexes: HashMap<String, HashSet<usize>>,
20 pub variable_bounds: HashMap<String, HashMap<usize, (f64, f64)>>,
21 pub variable_manifold: HashMap<String, Arc<dyn Manifold + Sync + Send>>,
22}
23impl Default for Problem {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29pub struct SymbolicStructure {
30 pattern: SymbolicSparseColMat<usize>,
31 order: Argsort<usize>,
32}
33
/// Scalar type of one numeric Jacobian entry.
type JacobianValue = f64;
35
36impl Problem {
37 pub fn new() -> Problem {
38 Problem {
39 total_residual_dimension: 0,
40 residual_id_count: 0,
41 residual_blocks: HashMap::new(),
42 fixed_variable_indexes: HashMap::new(),
43 variable_bounds: HashMap::new(),
44 variable_manifold: HashMap::new(),
45 }
46 }
47
48 pub fn build_symbolic_structure(
49 &self,
50 parameter_blocks: &HashMap<String, ParameterBlock>,
51 total_variable_dimension: usize,
52 variable_name_to_col_idx_dict: &HashMap<String, usize>,
53 ) -> SymbolicStructure {
54 let mut indices = Vec::<Pair<usize, usize>>::new();
55
56 self.residual_blocks.iter().for_each(|(_, residual_block)| {
57 let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
58 let mut count_variable_local_idx: usize = 0;
59 for var_key in &residual_block.variable_key_list {
60 if let Some(param) = parameter_blocks.get(var_key) {
61 variable_local_idx_size_list
62 .push((count_variable_local_idx, param.tangent_size()));
63 count_variable_local_idx += param.tangent_size();
64 };
65 }
66 for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
67 if let Some(variable_global_idx) = variable_name_to_col_idx_dict.get(var_key) {
68 let (_, var_size) = variable_local_idx_size_list[i];
69 for row_idx in 0..residual_block.dim_residual {
70 for col_idx in 0..var_size {
71 let global_row_idx = residual_block.residual_row_start_idx + row_idx;
72 let global_col_idx = variable_global_idx + col_idx;
73 indices.push(Pair::new(global_row_idx, global_col_idx));
74 }
75 }
76 }
77 }
78 });
79 let start = std::time::Instant::now();
80 let (s, o) = SymbolicSparseColMat::try_new_from_indices(
81 self.total_residual_dimension,
82 total_variable_dimension,
83 &indices,
84 )
85 .unwrap();
86 log::trace!("Built symbolic matrix: {:?}", start.elapsed());
87 SymbolicStructure {
88 pattern: s,
89 order: o,
90 }
91 }
92
93 pub fn get_variable_name_to_col_idx_dict(
94 &self,
95 parameter_blocks: &HashMap<String, ParameterBlock>,
96 ) -> HashMap<String, usize> {
97 let mut count_col_idx = 0;
98 let mut variable_name_to_col_idx_dict = HashMap::new();
99 parameter_blocks
100 .iter()
101 .for_each(|(param_name, param_block)| {
102 variable_name_to_col_idx_dict.insert(param_name.to_owned(), count_col_idx);
103 count_col_idx += param_block.tangent_size();
104 });
105 variable_name_to_col_idx_dict
106 }
107 pub fn add_residual_block(
108 &mut self,
109 dim_residual: usize,
110 variable_key_size_list: &[&str],
111 factor: Box<dyn factors::FactorImpl + Send>,
112 loss_func: Option<Box<dyn loss_functions::Loss + Send>>,
113 ) -> ResidualBlockId {
114 self.residual_blocks.insert(
115 self.residual_id_count,
116 residual_block::ResidualBlock::new(
117 self.residual_id_count,
118 dim_residual,
119 self.total_residual_dimension,
120 variable_key_size_list,
121 factor,
122 loss_func,
123 ),
124 );
125 let block_id = self.residual_id_count;
126 self.residual_id_count += 1;
127
128 self.total_residual_dimension += dim_residual;
129
130 block_id
131 }
132 pub fn remove_residual_block(
133 &mut self,
134 block_id: ResidualBlockId,
135 ) -> Option<residual_block::ResidualBlock> {
136 if let Some(residual_block) = self.residual_blocks.remove(&block_id) {
137 self.total_residual_dimension -= residual_block.dim_residual;
138 Some(residual_block)
139 } else {
140 None
141 }
142 }
143 pub fn fix_variable(&mut self, var_to_fix: &str, idx: usize) {
144 if let Some(var_mut) = self.fixed_variable_indexes.get_mut(var_to_fix) {
145 var_mut.insert(idx);
146 } else {
147 self.fixed_variable_indexes
148 .insert(var_to_fix.to_owned(), HashSet::from([idx]));
149 }
150 }
151 pub fn unfix_variable(&mut self, var_to_unfix: &str) {
152 self.fixed_variable_indexes.remove(var_to_unfix);
153 }
154 pub fn set_variable_bounds(
155 &mut self,
156 var_to_bound: &str,
157 idx: usize,
158 lower_bound: f64,
159 upper_bound: f64,
160 ) {
161 if lower_bound > upper_bound {
162 log::error!("lower bound is larger than upper bound");
163 } else if let Some(var_mut) = self.variable_bounds.get_mut(var_to_bound) {
164 var_mut.insert(idx, (lower_bound, upper_bound));
165 } else {
166 self.variable_bounds.insert(
167 var_to_bound.to_owned(),
168 HashMap::from([(idx, (lower_bound, upper_bound))]),
169 );
170 }
171 }
172 pub fn set_variable_manifold(
173 &mut self,
174 var_name: &str,
175 manifold: Arc<dyn Manifold + Sync + Send>,
176 ) {
177 self.variable_manifold
178 .insert(var_name.to_string(), manifold);
179 }
180 pub fn remove_variable_bounds(&mut self, var_to_unbound: &str) {
181 self.variable_bounds.remove(var_to_unbound);
182 }
183 pub fn initialize_parameter_blocks(
184 &self,
185 initial_values: &HashMap<String, na::DVector<f64>>,
186 ) -> HashMap<String, ParameterBlock> {
187 let parameter_blocks: HashMap<String, ParameterBlock> = initial_values
188 .iter()
189 .map(|(k, v)| {
190 let mut p_block = ParameterBlock::from_vec(v.clone());
191 if let Some(indexes) = self.fixed_variable_indexes.get(k) {
192 p_block.fixed_variables = indexes.clone();
193 }
194 if let Some(bounds) = self.variable_bounds.get(k) {
195 p_block.variable_bounds = bounds.clone();
196 }
197 if let Some(manifold) = self.variable_manifold.get(k) {
198 p_block.manifold = Some(manifold.clone())
199 }
200
201 (k.to_owned(), p_block)
202 })
203 .collect();
204 parameter_blocks
205 }
206
207 pub fn compute_residuals(
208 &self,
209 parameter_blocks: &HashMap<String, ParameterBlock>,
210 with_loss_fn: bool,
211 ) -> faer::Mat<f64> {
212 let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
213 self.total_residual_dimension,
214 )));
215 self.residual_blocks
216 .par_iter()
217 .for_each(|(_, residual_block)| {
218 self.compute_residual_impl(
219 residual_block,
220 parameter_blocks,
221 &total_residual,
222 with_loss_fn,
223 )
224 });
225 let total_residual = Arc::try_unwrap(total_residual)
226 .unwrap()
227 .into_inner()
228 .unwrap();
229
230 total_residual.view_range(.., ..).into_faer().to_owned()
231 }
232
233 pub fn compute_residual_and_jacobian(
234 &self,
235 parameter_blocks: &HashMap<String, ParameterBlock>,
236 variable_name_to_col_idx_dict: &HashMap<String, usize>,
237 symbolic_structure: &SymbolicStructure,
238 ) -> (faer::Mat<f64>, SparseColMat<usize, f64>) {
239 let total_residual = Arc::new(Mutex::new(na::DVector::<f64>::zeros(
241 self.total_residual_dimension,
242 )));
243
244 let jacobian_lists: Vec<JacobianValue> = self
245 .residual_blocks
246 .par_iter()
247 .map(|(_, residual_block)| {
248 self.compute_residual_and_jacobian_impl(
249 residual_block,
250 parameter_blocks,
251 variable_name_to_col_idx_dict,
252 &total_residual,
253 )
254 })
255 .flatten()
256 .collect();
257
258 let total_residual = Arc::try_unwrap(total_residual)
259 .unwrap()
260 .into_inner()
261 .unwrap();
262
263 let residual_faer = total_residual.view_range(.., ..).into_faer().to_owned();
264 let jacobian_faer = SparseColMat::new_from_argsort(
265 symbolic_structure.pattern.clone(),
266 &symbolic_structure.order,
267 jacobian_lists.as_slice(),
268 )
269 .unwrap();
270 (residual_faer, jacobian_faer)
271 }
272
273 fn compute_residual_impl(
274 &self,
275 residual_block: &crate::ResidualBlock,
276 parameter_blocks: &HashMap<String, ParameterBlock>,
277 total_residual: &Arc<Mutex<na::DVector<f64>>>,
278 with_loss_fn: bool,
279 ) {
280 let mut params = Vec::new();
281 for var_key in &residual_block.variable_key_list {
282 if let Some(param) = parameter_blocks.get(var_key) {
283 params.push(param);
284 };
285 }
286 let res = residual_block.residual(¶ms, with_loss_fn);
287
288 {
289 let mut total_residual = total_residual.lock().unwrap();
290 total_residual
291 .rows_mut(
292 residual_block.residual_row_start_idx,
293 residual_block.dim_residual,
294 )
295 .copy_from(&res);
296 }
297 }
298
299 fn compute_residual_and_jacobian_impl(
300 &self,
301 residual_block: &crate::ResidualBlock,
302 parameter_blocks: &HashMap<String, ParameterBlock>,
303 variable_name_to_col_idx_dict: &HashMap<String, usize>,
304 total_residual: &Arc<Mutex<na::DVector<f64>>>,
305 ) -> Vec<JacobianValue> {
306 let mut params = Vec::new();
307 let mut variable_local_idx_size_list = Vec::<(usize, usize)>::new();
308 let mut count_variable_local_idx: usize = 0;
309 for var_key in &residual_block.variable_key_list {
310 if let Some(param) = parameter_blocks.get(var_key) {
311 params.push(param);
312 variable_local_idx_size_list.push((count_variable_local_idx, param.tangent_size()));
313 count_variable_local_idx += param.tangent_size();
314 };
315 }
316 let (res, jac) = residual_block.residual_and_jacobian(¶ms);
317 {
318 let mut total_residual = total_residual.lock().unwrap();
319 total_residual
320 .rows_mut(
321 residual_block.residual_row_start_idx,
322 residual_block.dim_residual,
323 )
324 .copy_from(&res);
325 }
326
327 let mut local_jacobian_list = Vec::new();
328
329 for (i, var_key) in residual_block.variable_key_list.iter().enumerate() {
330 if variable_name_to_col_idx_dict.contains_key(var_key) {
331 let (variable_local_idx, var_size) = variable_local_idx_size_list[i];
332 let variable_jac = jac.view((0, variable_local_idx), (jac.shape().0, var_size));
333 for row_idx in 0..jac.shape().0 {
334 for col_idx in 0..var_size {
335 local_jacobian_list.push(variable_jac[(row_idx, col_idx)]);
336 }
337 }
338 } else {
339 panic!(
340 "Missing key {} in variable-to-column-index mapping",
341 var_key
342 );
343 }
344 }
345
346 local_jacobian_list
347 }
348}