tiny_solver/
factors.rs

1pub use nalgebra as na;
2
3use crate::manifold::se3::SE3;
4
5pub trait Factor<T: na::RealField>: Send + Sync {
6    fn residual_func(&self, params: &[na::DVector<T>]) -> na::DVector<T>;
7}
8pub trait FactorImpl: Factor<num_dual::DualDVec64> + Factor<f64> {
9    fn residual_func_dual(
10        &self,
11        params: &[na::DVector<num_dual::DualDVec64>],
12    ) -> na::DVector<num_dual::DualDVec64> {
13        self.residual_func(params)
14    }
15    fn residual_func_f64(&self, params: &[na::DVector<f64>]) -> na::DVector<f64> {
16        self.residual_func(params)
17    }
18}
19
20impl<T> FactorImpl for T
21where
22    T: Factor<num_dual::DualDVec64> + Factor<f64>,
23{
24    fn residual_func_dual(
25        &self,
26        params: &[na::DVector<num_dual::DualDVec64>],
27    ) -> na::DVector<num_dual::DualDVec64> {
28        self.residual_func(params)
29    }
30
31    fn residual_func_f64(&self, params: &[na::DVector<f64>]) -> na::DVector<f64> {
32        self.residual_func(params)
33    }
34}
35
36#[derive(Debug, Clone)]
37pub struct BetweenFactorSE2 {
38    pub dx: f64,
39    pub dy: f64,
40    pub dtheta: f64,
41}
42impl<T: na::RealField> Factor<T> for BetweenFactorSE2 {
43    fn residual_func(&self, params: &[na::DVector<T>]) -> na::DVector<T> {
44        let t_origin_k0 = &params[0];
45        let t_origin_k1 = &params[1];
46        let se2_origin_k0 = na::Isometry2::new(
47            na::Vector2::new(t_origin_k0[1].clone(), t_origin_k0[2].clone()),
48            t_origin_k0[0].clone(),
49        );
50        let se2_origin_k1 = na::Isometry2::new(
51            na::Vector2::new(t_origin_k1[1].clone(), t_origin_k1[2].clone()),
52            t_origin_k1[0].clone(),
53        );
54        let se2_k0_k1 = na::Isometry2::new(
55            na::Vector2::<T>::new(T::from_f64(self.dx).unwrap(), T::from_f64(self.dy).unwrap()),
56            T::from_f64(self.dtheta).unwrap(),
57        );
58
59        let se2_diff = se2_origin_k1.inverse() * se2_origin_k0 * se2_k0_k1;
60        na::dvector![
61            se2_diff.translation.x.clone(),
62            se2_diff.translation.y.clone(),
63            se2_diff.rotation.angle()
64        ]
65    }
66}
67
68#[derive(Debug, Clone)]
69pub struct BetweenFactorSE3 {
70    pub dtx: f64,
71    pub dty: f64,
72    pub dtz: f64,
73    pub dqx: f64,
74    pub dqy: f64,
75    pub dqz: f64,
76    pub dqw: f64,
77}
78impl<T: na::RealField> Factor<T> for BetweenFactorSE3 {
79    fn residual_func(&self, params: &[na::DVector<T>]) -> na::DVector<T> {
80        let t_origin_k0 = &params[0];
81        let t_origin_k1 = &params[1];
82        let se3_origin_k0 = SE3::from_vec(t_origin_k0.as_view());
83        let se3_origin_k1 = SE3::from_vec(t_origin_k1.as_view());
84
85        let se3_k0_k1 = SE3::from_vec(
86            na::dvector![
87                self.dqx, self.dqy, self.dqz, self.dqw, self.dtx, self.dty, self.dtz,
88            ]
89            .as_view(),
90        )
91        .cast::<T>();
92
93        let se3_diff = se3_origin_k1.inverse() * se3_origin_k0 * se3_k0_k1.cast();
94
95        se3_diff.log()
96    }
97}
98
99#[derive(Debug, Clone)]
100pub struct PriorFactor {
101    pub v: na::DVector<f64>,
102}
103impl<T: na::RealField> Factor<T> for PriorFactor {
104    fn residual_func(&self, params: &[na::DVector<T>]) -> na::DVector<T> {
105        params[0].clone() - self.v.clone().cast()
106    }
107}