1use crate::odometry::CameraIntrinsics;
2use nalgebra as na;
3use nalgebra::LU;
5use std::collections::HashMap;
6use std::ops::SubAssign;
7
8mod lie {
10 use nalgebra as na;
11
12 pub fn exp_map(omega: &na::Vector3<f64>) -> na::Matrix3<f64> {
14 let theta = omega.norm();
15 if theta < 1e-8 {
16 return na::Matrix3::identity();
17 }
18 let w = omega / theta;
19 let w_hat = na::Matrix3::new(0.0, -w[2], w[1], w[2], 0.0, -w[0], -w[1], w[0], 0.0);
20 na::Matrix3::identity() + w_hat * theta.sin() + (w_hat * w_hat) * (1.0 - theta.cos())
21 }
22}
23
24#[derive(Debug, Clone)]
25pub struct Observation {
26 pub keyframe_idx: usize,
27 pub point_idx: usize,
28 pub pixel: na::Point2<f64>,
29}
30
31impl Observation {
32 pub fn new(keyframe_idx: usize, point_idx: usize, pixel: na::Point2<f64>) -> Self {
33 Self {
34 keyframe_idx,
35 point_idx,
36 pixel,
37 }
38 }
39}
40
/// Huber loss of `residual` with threshold `delta`.
///
/// Quadratic (`residual²`) while `residual² <= delta²`, and linear beyond that
/// (`2·delta·|residual| - delta²`). The two pieces meet continuously at `|residual| = delta`.
fn huber_loss(residual: f64, delta: f64) -> f64 {
    let squared = residual * residual;
    let delta_sq = delta * delta;
    if squared <= delta_sq {
        squared
    } else {
        // Linear tail: outliers are penalized proportionally rather than quadratically.
        2.0 * delta * residual.abs() - delta_sq
    }
}
50
51pub struct BundleAdjuster {
52 intrinsics: CameraIntrinsics,
53 max_iterations: usize,
54 lambda: f64,
55 min_error_change: f64,
56 huber_delta: f64,
57}
58
59impl BundleAdjuster {
60 pub fn new(intrinsics: CameraIntrinsics) -> Self {
61 Self {
62 intrinsics,
63 max_iterations: 10,
64 lambda: 1e-3,
65 min_error_change: 1e-6,
66 huber_delta: 2.0, }
68 }
69
70 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
71 self.max_iterations = max_iterations;
72 self
73 }
74
75 pub fn with_lambda(mut self, lambda: f64) -> Self {
76 self.lambda = lambda;
77 self
78 }
79
80 pub fn with_huber_delta(mut self, delta: f64) -> Self {
81 self.huber_delta = delta;
82 self
83 }
84
85 fn project(
86 &self,
87 point: &na::Point3<f64>,
88 r: &na::Matrix3<f64>,
89 t: &na::Vector3<f64>,
90 ) -> Option<na::Point2<f64>> {
91 let p_cam = r * point + t;
92
93 if p_cam.z <= 1e-6 {
94 return None;
95 }
96
97 let x = self.intrinsics.fx * (p_cam.x / p_cam.z) + self.intrinsics.cx;
98 let y = self.intrinsics.fy * (p_cam.y / p_cam.z) + self.intrinsics.cy;
99 Some(na::Point2::new(x, y))
100 }
101
102 fn compute_jacobians(
104 &self,
105 point: &na::Point3<f64>,
106 r: &na::Matrix3<f64>,
107 t: &na::Vector3<f64>,
108 ) -> Option<(na::Matrix2x6<f64>, na::Matrix2x3<f64>)> {
109 let p_cam = r * point + t;
110
111 if p_cam.z <= 1e-6 {
112 return None;
113 }
114
115 let z = p_cam.z;
116 let z2 = z * z;
117 let fx = self.intrinsics.fx;
118 let fy = self.intrinsics.fy;
119
120 let j_proj = na::Matrix2x3::new(
122 fx / z,
123 0.0,
124 -fx * p_cam.x / z2,
125 0.0,
126 fy / z,
127 -fy * p_cam.y / z2,
128 );
129
130 let j_point = j_proj * r;
132
133 let point_cam = r * point;
135 let point_cam_cross = na::Matrix3::new(
136 0.0,
137 -point_cam[2],
138 point_cam[1],
139 point_cam[2],
140 0.0,
141 -point_cam[0],
142 -point_cam[1],
143 point_cam[0],
144 0.0,
145 );
146
147 let mut j_pose = na::Matrix2x6::zeros();
148 j_pose
150 .fixed_view_mut::<2, 3>(0, 0)
151 .copy_from(&(j_proj * (-point_cam_cross)));
152 j_pose
154 .fixed_view_mut::<2, 3>(0, 3)
155 .copy_from(&(j_proj * (-point_cam_cross)));
156
157 Some((j_pose, j_point))
158 }
159
160 pub fn compute_total_error(
161 &self,
162 poses: &[(na::Matrix3<f64>, na::Vector3<f64>)],
163 points: &[na::Point3<f64>],
164 observations: &[Observation],
165 ) -> f64 {
166 let mut total_error = 0.0;
167 let mut count = 0;
168
169 for obs in observations {
170 if obs.keyframe_idx >= poses.len() || obs.point_idx >= points.len() {
171 continue;
172 }
173
174 let (r, t) = &poses[obs.keyframe_idx];
175 let point = &points[obs.point_idx];
176
177 if let Some(proj) = self.project(point, r, t) {
178 let dx = proj.x - obs.pixel.x;
179 let dy = proj.y - obs.pixel.y;
180 let residual = (dx * dx + dy * dy).sqrt();
181 total_error += huber_loss(residual, self.huber_delta);
182 count += 1;
183 }
184 }
185
186 if count > 0 { total_error } else { 0.0 }
187 }
188
189 pub fn optimize(
191 &self,
192 poses: &mut [(na::Matrix3<f64>, na::Vector3<f64>)],
193 points: &mut [na::Point3<f64>],
194 observations: &[Observation],
195 fix_first_pose: bool,
196 ) -> Result<f64, Box<dyn std::error::Error>> {
197 if observations.is_empty() {
198 return Ok(0.0);
199 }
200
201 let mut prev_error = self.compute_total_error(poses, points, observations);
202 let n_poses = poses.len();
203 let n_points = points.len();
204
205 for _iter in 0..self.max_iterations {
206 let mut h_pp: HashMap<usize, na::Matrix6<f64>> = HashMap::new();
208 let mut h_ll: HashMap<usize, na::Matrix3<f64>> = HashMap::new();
209 let mut h_pl: HashMap<(usize, usize), na::Matrix6x3<f64>> = HashMap::new();
210 let mut b_p: HashMap<usize, na::Vector6<f64>> = HashMap::new();
211 let mut b_l: HashMap<usize, na::Vector3<f64>> = HashMap::new();
212
213 for obs in observations {
215 if obs.keyframe_idx >= n_poses || obs.point_idx >= n_points {
216 continue;
217 }
218 let (r, t) = &poses[obs.keyframe_idx];
219 let point = &points[obs.point_idx];
220 if let Some(proj) = self.project(point, r, t) {
221 let residual = na::Vector2::new(proj.x - obs.pixel.x, proj.y - obs.pixel.y);
222 let r_norm = residual.norm();
223 let weight = if r_norm > 1e-8 {
225 let huber_w = huber_loss(r_norm, self.huber_delta) / (r_norm * r_norm);
226 huber_w.sqrt()
227 } else {
228 1.0
229 };
230 let weighted_residual = residual * weight;
231
232 if let Some((j_pose, j_point)) = self.compute_jacobians(point, r, t) {
233 let j_pose_w = j_pose * weight;
234 let j_point_w = j_point * weight;
235
236 *h_pp
237 .entry(obs.keyframe_idx)
238 .or_insert_with(na::Matrix6::<f64>::zeros) +=
239 j_pose_w.transpose() * j_pose;
240 *h_ll
241 .entry(obs.point_idx)
242 .or_insert_with(na::Matrix3::<f64>::zeros) +=
243 j_point_w.transpose() * j_point;
244 *h_pl
245 .entry((obs.keyframe_idx, obs.point_idx))
246 .or_insert_with(na::Matrix6x3::<f64>::zeros) +=
247 j_pose_w.transpose() * j_point;
248
249 *b_p.entry(obs.keyframe_idx)
250 .or_insert_with(na::Vector6::<f64>::zeros) -=
251 j_pose_w.transpose() * weighted_residual;
252 *b_l.entry(obs.point_idx)
253 .or_insert_with(na::Vector3::<f64>::zeros) -=
254 j_point_w.transpose() * weighted_residual;
255 }
256 }
257 }
258
259 let mut h_reduced = na::DMatrix::<f64>::zeros(n_poses * 6, n_poses * 6);
261 let mut b_reduced = na::DVector::<f64>::zeros(n_poses * 6);
262
263 for (i, h_pp_i) in &h_pp {
265 let start = *i * 6;
266 h_reduced.view_mut((start, start), (6, 6)).copy_from(h_pp_i);
267 }
268 for (i, b_p_i) in &b_p {
269 let start = *i * 6;
270 b_reduced.rows_mut(start, 6).copy_from(b_p_i);
271 }
272
273 if fix_first_pose && n_poses > 0 {
275 h_reduced.view_mut((0, 0), (6, 6)).fill(0.0);
276 h_reduced.view_mut((0, 0), (6, 6)).fill_with_identity();
277 b_reduced.rows_mut(0, 6).fill(0.0);
278 }
279
280 for j in 0..n_points {
282 if let Some(h_ll_j) = h_ll.get(&j) {
283 let h_ll_inv = h_ll_j
284 .try_inverse()
285 .unwrap_or_else(|| na::Matrix3::<f64>::identity() * 1e6);
286
287 for ((i, pj), h_pl_ij) in h_pl.iter() {
289 if *pj == j {
290 let contrib = h_pl_ij * h_ll_inv * h_pl_ij.transpose();
291 let start = *i * 6;
292 h_reduced
293 .view_mut((start, start), (6, 6))
294 .sub_assign(&contrib);
295 }
296 }
297
298 if let Some(b_l_j) = b_l.get(&j) {
300 for ((i, pj), h_pl_ij) in h_pl.iter() {
301 if *pj == j {
302 let update = h_pl_ij * h_ll_inv * b_l_j;
303 let start = *i * 6;
304 b_reduced.rows_mut(start, 6).sub_assign(&update);
305 }
306 }
307 }
308 }
309 }
310
311 for i in 0..n_poses {
313 for j in 0..6 {
314 h_reduced[(i * 6 + j, i * 6 + j)] += self.lambda * 10.0; }
316 }
317
318 let lu = LU::new(h_reduced.clone());
331 let delta_poses = lu.solve(&b_reduced).ok_or("LU solve failed")?;
332
333 for i in 0..n_poses {
335 let delta = delta_poses.rows(i * 6, 6);
336 let delta_rot = lie::exp_map(&na::Vector3::new(delta[0], delta[1], delta[2]));
337 let delta_trans = na::Vector3::new(delta[3], delta[4], delta[5]);
338
339 poses[i].0 = delta_rot * poses[i].0;
340 poses[i].1 += delta_trans;
341 }
342
343 for j in 0..n_points {
345 if let Some(h_ll_j) = h_ll.get(&j) {
346 let h_ll_inv = h_ll_j
347 .try_inverse()
348 .unwrap_or_else(|| na::Matrix3::<f64>::identity() * 1e6);
349
350 let mut sum = na::Vector3::<f64>::zeros();
351 for ((i, pj), h_pl_ij) in h_pl.iter() {
352 if *pj == j {
353 let delta_p = delta_poses.rows(*i * 6, 6);
354 sum += h_pl_ij.transpose() * delta_p;
355 }
356 }
357
358 if let Some(b_l_j) = b_l.get(&j) {
359 let delta_l = h_ll_inv * (b_l_j - sum);
360 points[j].coords += delta_l;
361 }
362 }
363 }
364
365 let current_error = self.compute_total_error(poses, points, observations);
367 if current_error > prev_error * 1.5 {
368 return Ok(prev_error);
370 }
371 let error_change = (prev_error - current_error).abs();
372
373 if error_change < self.min_error_change {
374 break;
375 }
376
377 prev_error = current_error;
378 }
379
380 Ok(prev_error)
381 }
382
383 pub fn local_bundle_adjustment(
384 &self,
385 poses: &mut [(na::Matrix3<f64>, na::Vector3<f64>)],
386 points: &mut [na::Point3<f64>],
387 observations: &[Observation],
388 window_size: usize,
389 ) -> Result<f64, Box<dyn std::error::Error>> {
390 if poses.is_empty() {
391 return Ok(0.0);
392 }
393
394 let start_idx = poses.len().saturating_sub(window_size);
395 let local_observations: Vec<Observation> = observations
396 .iter()
397 .filter(|obs| obs.keyframe_idx >= start_idx)
398 .cloned()
399 .collect();
400
401 self.optimize(poses, points, &local_observations, start_idx == 0)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Pinhole camera shared by the tests: fx = fy = 500, principal point (320, 240).
    fn camera() -> CameraIntrinsics {
        CameraIntrinsics::new(500.0, 500.0, 320.0, 240.0)
    }

    /// The rotational block of the analytic pose Jacobian must predict the pixel shift caused
    /// by a tiny left-multiplied rotation about the x axis.
    #[test]
    fn test_se3_jacobian_numerical() {
        let adjuster = BundleAdjuster::new(camera());

        let landmark = na::Point3::new(1.0, 2.0, 10.0);
        let rotation = na::Matrix3::identity();
        let translation = na::Vector3::new(0.1, -0.2, 0.05);

        let (pose_jacobian, _) = adjuster
            .compute_jacobians(&landmark, &rotation, &translation)
            .unwrap();

        let step = na::Vector3::new(1e-6, 0.0, 0.0);
        let perturbed_rotation = lie::exp_map(&step) * rotation;

        let baseline = adjuster.project(&landmark, &rotation, &translation).unwrap();
        let shifted = adjuster
            .project(&landmark, &perturbed_rotation, &translation)
            .unwrap();

        let predicted = pose_jacobian.fixed_view::<2, 3>(0, 0) * step;
        let observed = shifted - baseline;

        assert_relative_eq!(predicted, observed, epsilon = 1e-5);
    }

    /// A single observation of a single point should be driven to (near) zero error.
    #[test]
    fn test_rotation_convergence() {
        let adjuster = BundleAdjuster::new(camera()).with_max_iterations(30);

        let true_point = na::Point3::new(1.0, 0.5, 5.0);
        let true_rotation: na::Matrix3<f64> = na::UnitQuaternion::from_euler_angles(0.1, 0.2, 0.05)
            .to_rotation_matrix()
            .into();
        let true_translation = na::Vector3::new(0.1, -0.1, 0.0);
        let measured = adjuster
            .project(&true_point, &true_rotation, &true_translation)
            .unwrap();

        let mut poses = vec![(na::Matrix3::identity(), na::Vector3::zeros())];
        let mut points = vec![na::Point3::new(1.5, 0.8, 6.0)];
        let observations = vec![Observation::new(0, 0, measured)];

        let outcome = adjuster.optimize(&mut poses, &mut points, &observations, false);
        assert!(outcome.is_ok());

        let final_error = outcome.unwrap();
        assert!(final_error < 1e-6, "Should reach near-zero error");
    }
}