// slamkit_rs/mapping/bundle_adjustment.rs

1use crate::odometry::CameraIntrinsics;
2use nalgebra as na;
3// use nalgebra::{Cholesky, LU};
4use nalgebra::LU;
5use std::collections::HashMap;
6use std::ops::SubAssign;
7
8/// SE(3) Lie algebra operations for proper pose optimization
9mod lie {
10    use nalgebra as na;
11
12    /// Convert angle-axis vector to rotation matrix (Rodrigues' formula)
13    pub fn exp_map(omega: &na::Vector3<f64>) -> na::Matrix3<f64> {
14        let theta = omega.norm();
15        if theta < 1e-8 {
16            return na::Matrix3::identity();
17        }
18        let w = omega / theta;
19        let w_hat = na::Matrix3::new(0.0, -w[2], w[1], w[2], 0.0, -w[0], -w[1], w[0], 0.0);
20        na::Matrix3::identity() + w_hat * theta.sin() + (w_hat * w_hat) * (1.0 - theta.cos())
21    }
22}
23
24#[derive(Debug, Clone)]
25pub struct Observation {
26    pub keyframe_idx: usize,
27    pub point_idx: usize,
28    pub pixel: na::Point2<f64>,
29}
30
31impl Observation {
32    pub fn new(keyframe_idx: usize, point_idx: usize, pixel: na::Point2<f64>) -> Self {
33        Self {
34            keyframe_idx,
35            point_idx,
36            pixel,
37        }
38    }
39}
40
/// Huber robust cost of a residual magnitude `residual` with threshold `delta`.
///
/// The cost is quadratic (`r²`) while `r² <= delta²` and linear (`2·delta·|r| − delta²`)
/// beyond that. At `|r| = delta` the two pieces agree in both value and slope, so large
/// outliers grow only linearly.
fn huber_loss(residual: f64, delta: f64) -> f64 {
    let delta_sq = delta * delta;
    let squared = residual * residual;
    if squared <= delta_sq {
        return squared;
    }
    2.0 * delta * residual.abs() - delta_sq
}
50
51pub struct BundleAdjuster {
52    intrinsics: CameraIntrinsics,
53    max_iterations: usize,
54    lambda: f64,
55    min_error_change: f64,
56    huber_delta: f64,
57}
58
59impl BundleAdjuster {
60    pub fn new(intrinsics: CameraIntrinsics) -> Self {
61        Self {
62            intrinsics,
63            max_iterations: 10,
64            lambda: 1e-3,
65            min_error_change: 1e-6,
66            huber_delta: 2.0, // pixels
67        }
68    }
69
70    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
71        self.max_iterations = max_iterations;
72        self
73    }
74
75    pub fn with_lambda(mut self, lambda: f64) -> Self {
76        self.lambda = lambda;
77        self
78    }
79
80    pub fn with_huber_delta(mut self, delta: f64) -> Self {
81        self.huber_delta = delta;
82        self
83    }
84
85    fn project(
86        &self,
87        point: &na::Point3<f64>,
88        r: &na::Matrix3<f64>,
89        t: &na::Vector3<f64>,
90    ) -> Option<na::Point2<f64>> {
91        let p_cam = r * point + t;
92
93        if p_cam.z <= 1e-6 {
94            return None;
95        }
96
97        let x = self.intrinsics.fx * (p_cam.x / p_cam.z) + self.intrinsics.cx;
98        let y = self.intrinsics.fy * (p_cam.y / p_cam.z) + self.intrinsics.cy;
99        Some(na::Point2::new(x, y))
100    }
101
102    /// Compute Jacobians for pose (6 DOF) and point (3 DOF)
103    fn compute_jacobians(
104        &self,
105        point: &na::Point3<f64>,
106        r: &na::Matrix3<f64>,
107        t: &na::Vector3<f64>,
108    ) -> Option<(na::Matrix2x6<f64>, na::Matrix2x3<f64>)> {
109        let p_cam = r * point + t;
110
111        if p_cam.z <= 1e-6 {
112            return None;
113        }
114
115        let z = p_cam.z;
116        let z2 = z * z;
117        let fx = self.intrinsics.fx;
118        let fy = self.intrinsics.fy;
119
120        // ∂pixel/∂p_cam (2×3)
121        let j_proj = na::Matrix2x3::new(
122            fx / z,
123            0.0,
124            -fx * p_cam.x / z2,
125            0.0,
126            fy / z,
127            -fy * p_cam.y / z2,
128        );
129
130        // ∂pixel/∂point = ∂pixel/∂p_cam * ∂p_cam/∂point (2×3)
131        let j_point = j_proj * r;
132
133        // ∂p_cam/∂pose = [∂p_cam/∂ω, ∂p_cam/∂t] (3×6)
134        let point_cam = r * point;
135        let point_cam_cross = na::Matrix3::new(
136            0.0,
137            -point_cam[2],
138            point_cam[1],
139            point_cam[2],
140            0.0,
141            -point_cam[0],
142            -point_cam[1],
143            point_cam[0],
144            0.0,
145        );
146
147        let mut j_pose = na::Matrix2x6::zeros();
148        // Rotation part: del pixel/del w = del pixel/del p_cam * del p_cam/del w
149        j_pose
150            .fixed_view_mut::<2, 3>(0, 0)
151            .copy_from(&(j_proj * (-point_cam_cross)));
152        // Translation part: del pixel/del t = del pixel/del p_cam * del p_cam/del t
153        j_pose
154            .fixed_view_mut::<2, 3>(0, 3)
155            .copy_from(&(j_proj * (-point_cam_cross)));
156
157        Some((j_pose, j_point))
158    }
159
160    pub fn compute_total_error(
161        &self,
162        poses: &[(na::Matrix3<f64>, na::Vector3<f64>)],
163        points: &[na::Point3<f64>],
164        observations: &[Observation],
165    ) -> f64 {
166        let mut total_error = 0.0;
167        let mut count = 0;
168
169        for obs in observations {
170            if obs.keyframe_idx >= poses.len() || obs.point_idx >= points.len() {
171                continue;
172            }
173
174            let (r, t) = &poses[obs.keyframe_idx];
175            let point = &points[obs.point_idx];
176
177            if let Some(proj) = self.project(point, r, t) {
178                let dx = proj.x - obs.pixel.x;
179                let dy = proj.y - obs.pixel.y;
180                let residual = (dx * dx + dy * dy).sqrt();
181                total_error += huber_loss(residual, self.huber_delta);
182                count += 1;
183            }
184        }
185
186        if count > 0 { total_error } else { 0.0 }
187    }
188
189    /// Full sparse BA using Schur complement for pose-point marginalization
190    pub fn optimize(
191        &self,
192        poses: &mut [(na::Matrix3<f64>, na::Vector3<f64>)],
193        points: &mut [na::Point3<f64>],
194        observations: &[Observation],
195        fix_first_pose: bool,
196    ) -> Result<f64, Box<dyn std::error::Error>> {
197        if observations.is_empty() {
198            return Ok(0.0);
199        }
200
201        let mut prev_error = self.compute_total_error(poses, points, observations);
202        let n_poses = poses.len();
203        let n_points = points.len();
204
205        for _iter in 0..self.max_iterations {
206            // Build sparse blocks
207            let mut h_pp: HashMap<usize, na::Matrix6<f64>> = HashMap::new();
208            let mut h_ll: HashMap<usize, na::Matrix3<f64>> = HashMap::new();
209            let mut h_pl: HashMap<(usize, usize), na::Matrix6x3<f64>> = HashMap::new();
210            let mut b_p: HashMap<usize, na::Vector6<f64>> = HashMap::new();
211            let mut b_l: HashMap<usize, na::Vector3<f64>> = HashMap::new();
212
213            // Build blocks
214            for obs in observations {
215                if obs.keyframe_idx >= n_poses || obs.point_idx >= n_points {
216                    continue;
217                }
218                let (r, t) = &poses[obs.keyframe_idx];
219                let point = &points[obs.point_idx];
220                if let Some(proj) = self.project(point, r, t) {
221                    let residual = na::Vector2::new(proj.x - obs.pixel.x, proj.y - obs.pixel.y);
222                    let r_norm = residual.norm();
223                    // Huber weight calculation
224                    let weight = if r_norm > 1e-8 {
225                        let huber_w = huber_loss(r_norm, self.huber_delta) / (r_norm * r_norm);
226                        huber_w.sqrt()
227                    } else {
228                        1.0
229                    };
230                    let weighted_residual = residual * weight;
231
232                    if let Some((j_pose, j_point)) = self.compute_jacobians(point, r, t) {
233                        let j_pose_w = j_pose * weight;
234                        let j_point_w = j_point * weight;
235
236                        *h_pp
237                            .entry(obs.keyframe_idx)
238                            .or_insert_with(na::Matrix6::<f64>::zeros) +=
239                            j_pose_w.transpose() * j_pose;
240                        *h_ll
241                            .entry(obs.point_idx)
242                            .or_insert_with(na::Matrix3::<f64>::zeros) +=
243                            j_point_w.transpose() * j_point;
244                        *h_pl
245                            .entry((obs.keyframe_idx, obs.point_idx))
246                            .or_insert_with(na::Matrix6x3::<f64>::zeros) +=
247                            j_pose_w.transpose() * j_point;
248
249                        *b_p.entry(obs.keyframe_idx)
250                            .or_insert_with(na::Vector6::<f64>::zeros) -=
251                            j_pose_w.transpose() * weighted_residual;
252                        *b_l.entry(obs.point_idx)
253                            .or_insert_with(na::Vector3::<f64>::zeros) -=
254                            j_point_w.transpose() * weighted_residual;
255                    }
256                }
257            }
258
259            // Build reduced pose system: h_reduced = h_pp - summation of H_pl * h_ll^-1 * H_pl^T
260            let mut h_reduced = na::DMatrix::<f64>::zeros(n_poses * 6, n_poses * 6);
261            let mut b_reduced = na::DVector::<f64>::zeros(n_poses * 6);
262
263            // Initialize with h_pp and b_p
264            for (i, h_pp_i) in &h_pp {
265                let start = *i * 6;
266                h_reduced.view_mut((start, start), (6, 6)).copy_from(h_pp_i);
267            }
268            for (i, b_p_i) in &b_p {
269                let start = *i * 6;
270                b_reduced.rows_mut(start, 6).copy_from(b_p_i);
271            }
272
273            // Fix first pose if requested
274            if fix_first_pose && n_poses > 0 {
275                h_reduced.view_mut((0, 0), (6, 6)).fill(0.0);
276                h_reduced.view_mut((0, 0), (6, 6)).fill_with_identity();
277                b_reduced.rows_mut(0, 6).fill(0.0);
278            }
279
280            // For each point, subtract its contribution from reduced system
281            for j in 0..n_points {
282                if let Some(h_ll_j) = h_ll.get(&j) {
283                    let h_ll_inv = h_ll_j
284                        .try_inverse()
285                        .unwrap_or_else(|| na::Matrix3::<f64>::identity() * 1e6);
286
287                    // Subtract h_pl[i,j] * h_ll^-1 * h_pl[i,j]^T from each pose block
288                    for ((i, pj), h_pl_ij) in h_pl.iter() {
289                        if *pj == j {
290                            let contrib = h_pl_ij * h_ll_inv * h_pl_ij.transpose();
291                            let start = *i * 6;
292                            h_reduced
293                                .view_mut((start, start), (6, 6))
294                                .sub_assign(&contrib);
295                        }
296                    }
297
298                    // Subtract H_pl[i,j] * H_ll^-1 * b_l[j] from each pose's b vector
299                    if let Some(b_l_j) = b_l.get(&j) {
300                        for ((i, pj), h_pl_ij) in h_pl.iter() {
301                            if *pj == j {
302                                let update = h_pl_ij * h_ll_inv * b_l_j;
303                                let start = *i * 6;
304                                b_reduced.rows_mut(start, 6).sub_assign(&update);
305                            }
306                        }
307                    }
308                }
309            }
310
311            // Apply damping directly to diagonal blocks for stability
312            for i in 0..n_poses {
313                for j in 0..6 {
314                    h_reduced[(i * 6 + j, i * 6 + j)] += self.lambda * 10.0; // this is very strong damping, this needs testing
315                }
316            }
317
318            // Solve reduced system with Cholesky
319            // Didn't work lmao??
320            // let chol = H_reduced
321            //     .clone()
322            //     .cholesky()
323            //     .ok_or("Cholesky decomposition failed")?;
324            // let delta_poses = chol.solve(&b_reduced);
325            // LDLT doesn't seem to exists in Rust :(
326            // let chol = Cholesky::new(h_reduced.clone())
327            //     .ok_or("Cholesky failed - matrix not positive definite. Try increasing lambda")?;
328            // let delta_poses = chol.solve(&b_reduced);
329            // Solve reduced system with LU
330            let lu = LU::new(h_reduced.clone());
331            let delta_poses = lu.solve(&b_reduced).ok_or("LU solve failed")?;
332
333            // Update poses
334            for i in 0..n_poses {
335                let delta = delta_poses.rows(i * 6, 6);
336                let delta_rot = lie::exp_map(&na::Vector3::new(delta[0], delta[1], delta[2]));
337                let delta_trans = na::Vector3::new(delta[3], delta[4], delta[5]);
338
339                poses[i].0 = delta_rot * poses[i].0;
340                poses[i].1 += delta_trans;
341            }
342
343            // Back-substitute to get point updates: del x_l = h_ll^-1 * (b_l - h_pl^T * del x_p)
344            for j in 0..n_points {
345                if let Some(h_ll_j) = h_ll.get(&j) {
346                    let h_ll_inv = h_ll_j
347                        .try_inverse()
348                        .unwrap_or_else(|| na::Matrix3::<f64>::identity() * 1e6);
349
350                    let mut sum = na::Vector3::<f64>::zeros();
351                    for ((i, pj), h_pl_ij) in h_pl.iter() {
352                        if *pj == j {
353                            let delta_p = delta_poses.rows(*i * 6, 6);
354                            sum += h_pl_ij.transpose() * delta_p;
355                        }
356                    }
357
358                    if let Some(b_l_j) = b_l.get(&j) {
359                        let delta_l = h_ll_inv * (b_l_j - sum);
360                        points[j].coords += delta_l;
361                    }
362                }
363            }
364
365            // Check for divergence and abort if error increases
366            let current_error = self.compute_total_error(poses, points, observations);
367            if current_error > prev_error * 1.5 {
368                // Optimization diverged, return previous error
369                return Ok(prev_error);
370            }
371            let error_change = (prev_error - current_error).abs();
372
373            if error_change < self.min_error_change {
374                break;
375            }
376
377            prev_error = current_error;
378        }
379
380        Ok(prev_error)
381    }
382
383    pub fn local_bundle_adjustment(
384        &self,
385        poses: &mut [(na::Matrix3<f64>, na::Vector3<f64>)],
386        points: &mut [na::Point3<f64>],
387        observations: &[Observation],
388        window_size: usize,
389    ) -> Result<f64, Box<dyn std::error::Error>> {
390        if poses.is_empty() {
391            return Ok(0.0);
392        }
393
394        let start_idx = poses.len().saturating_sub(window_size);
395        let local_observations: Vec<Observation> = observations
396            .iter()
397            .filter(|obs| obs.keyframe_idx >= start_idx)
398            .cloned()
399            .collect();
400
401        self.optimize(poses, points, &local_observations, start_idx == 0)
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Intrinsics shared by the tests: 500 px focal length, principal point (320, 240).
    fn test_intrinsics() -> CameraIntrinsics {
        CameraIntrinsics::new(500.0, 500.0, 320.0, 240.0)
    }

    /// Checks the rotational block of the pose Jacobian against a finite difference.
    #[test]
    fn test_se3_jacobian_numerical() {
        let ba = BundleAdjuster::new(test_intrinsics());

        let r = na::Matrix3::identity();
        let t = na::Vector3::new(0.1, -0.2, 0.05);
        let point = na::Point3::new(1.0, 2.0, 10.0);

        let (j_pose, _) = ba.compute_jacobians(&point, &r, &t).unwrap();

        // Small rotation about the x axis, applied on the left of the current rotation.
        let eps = 1e-6;
        let delta_omega = na::Vector3::new(eps, 0.0, 0.0);
        let r_perturbed = lie::exp_map(&delta_omega) * r;

        let baseline = ba.project(&point, &r, &t).unwrap();
        let perturbed = ba.project(&point, &r_perturbed, &t).unwrap();
        let actual_change = perturbed - baseline;
        let expected_change = j_pose.fixed_view::<2, 3>(0, 0) * delta_omega;

        assert_relative_eq!(expected_change, actual_change, epsilon = 1e-5);
    }

    /// A single observation, starting from an unrefined pose and a perturbed point, should be
    /// driven to (near) zero reprojection error.
    #[test]
    fn test_rotation_convergence() {
        let ba = BundleAdjuster::new(test_intrinsics()).with_max_iterations(30);

        let true_r: na::Matrix3<f64> = na::UnitQuaternion::from_euler_angles(0.1, 0.2, 0.05)
            .to_rotation_matrix()
            .into();
        let true_t = na::Vector3::new(0.1, -0.1, 0.0);
        let true_point = na::Point3::new(1.0, 0.5, 5.0);
        let pixel = ba.project(&true_point, &true_r, &true_t).unwrap();

        let mut poses = vec![(na::Matrix3::identity(), na::Vector3::zeros())];
        let mut points = vec![na::Point3::new(1.5, 0.8, 6.0)];
        let observations = vec![Observation::new(0, 0, pixel)];

        let result = ba.optimize(&mut poses, &mut points, &observations, false);
        assert!(result.is_ok());

        let final_error = result.unwrap();
        assert!(final_error < 1e-6, "Should reach near-zero error");
    }
}