// tiny_solver/manifold/so3.rs

1use std::{num::NonZero, ops::Mul};
2
3use nalgebra as na;
4
5use super::{AutoDiffManifold, Manifold};
6
7pub struct SO3<T: na::RealField> {
8    qx: T,
9    qy: T,
10    qz: T,
11    qw: T,
12}
13
14impl<T: na::RealField> SO3<T> {
15    /// [x, y, z, w]
16    pub fn from_vec(xyzw: na::DVectorView<T>) -> Self {
17        SO3 {
18            qx: xyzw[0].clone(),
19            qy: xyzw[1].clone(),
20            qz: xyzw[2].clone(),
21            qw: xyzw[3].clone(),
22        }
23    }
24    pub fn from_xyzw(x: T, y: T, z: T, w: T) -> Self {
25        SO3 {
26            qx: x,
27            qy: y,
28            qz: z,
29            qw: w,
30        }
31    }
32    pub fn identity() -> Self {
33        SO3 {
34            qx: T::zero(),
35            qy: T::zero(),
36            qz: T::zero(),
37            qw: T::one(),
38        }
39    }
40
41    pub fn exp(xi: na::DVectorView<T>) -> Self {
42        let mut xyzw = na::Vector4::zeros();
43
44        let theta2 = xi.norm_squared();
45
46        if theta2 < T::from_f64(1e-6).unwrap() {
47            // cos(theta / 2) \approx 1 - theta^2 / 8
48            xyzw.w = T::one() - theta2 / T::from_f64(8.0).unwrap();
49            // Complete the square so that norm is one
50            let tmp = T::from_f64(0.5).unwrap();
51            xyzw.x = xi[0].clone() * tmp.clone();
52            xyzw.y = xi[1].clone() * tmp.clone();
53            xyzw.z = xi[2].clone() * tmp;
54        } else {
55            let theta = theta2.sqrt();
56            xyzw.w = (theta.clone() * T::from_f64(0.5).unwrap()).cos();
57
58            let omega = xi / theta;
59            let sin_theta_half = (T::one() - xyzw.w.clone() * xyzw.w.clone()).sqrt();
60            xyzw.x = omega[0].clone() * sin_theta_half.clone();
61            xyzw.y = omega[1].clone() * sin_theta_half.clone();
62            xyzw.z = omega[2].clone() * sin_theta_half;
63        }
64
65        SO3::from_vec(xyzw.as_view())
66    }
67
68    pub fn log(&self) -> na::DVector<T> {
69        const EPS: f64 = 1e-6;
70        let ivec = na::dvector![self.qx.clone(), self.qy.clone(), self.qz.clone()];
71
72        let squared_n = ivec.norm_squared();
73        let w = self.qw.clone();
74
75        let near_zero = squared_n.le(&T::from_f64(EPS * EPS).unwrap());
76
77        let w_sq = w.clone() * w.clone();
78        let t0 = T::from_f64(2.0).unwrap() / w.clone()
79            - T::from_f64(2.0 / 3.0).unwrap() * squared_n.clone() / (w_sq * w.clone());
80
81        let n = squared_n.sqrt();
82
83        let sign = T::from_f64(-1.0)
84            .unwrap()
85            .select(w.le(&T::zero()), T::one());
86        let atan_nbyw = sign.clone() * n.clone().atan2(sign * w);
87
88        let t = T::from_f64(2.0).unwrap() * atan_nbyw / n;
89
90        let two_atan_nbyd_by_n = t0.select(near_zero, t);
91
92        ivec * two_atan_nbyd_by_n
93    }
94
95    pub fn hat(xi: na::VectorView3<T>) -> na::Matrix3<T> {
96        let mut xi_hat = na::Matrix3::zeros();
97        xi_hat[(0, 1)] = -xi[2].clone();
98        xi_hat[(0, 2)] = xi[1].clone();
99        xi_hat[(1, 0)] = xi[2].clone();
100        xi_hat[(1, 2)] = -xi[0].clone();
101        xi_hat[(2, 0)] = -xi[1].clone();
102        xi_hat[(2, 1)] = xi[0].clone();
103
104        xi_hat
105    }
106
107    pub fn to_vec(&self) -> na::Vector4<T> {
108        na::Vector4::new(
109            self.qx.clone(),
110            self.qy.clone(),
111            self.qz.clone(),
112            self.qw.clone(),
113        )
114    }
115    pub fn to_dvec(&self) -> na::DVector<T> {
116        na::dvector![
117            self.qx.clone(),
118            self.qy.clone(),
119            self.qz.clone(),
120            self.qw.clone(),
121        ]
122    }
123    pub fn cast<U: na::RealField + simba::scalar::SupersetOf<T>>(&self) -> SO3<U> {
124        SO3::from_vec(self.to_vec().cast().as_view())
125    }
126    pub fn inverse(&self) -> Self {
127        SO3 {
128            qx: -self.qx.clone(),
129            qy: -self.qy.clone(),
130            qz: -self.qz.clone(),
131            qw: self.qw.clone(),
132        }
133    }
134    pub fn compose(&self, rhs: &Self) -> Self {
135        let x0 = self.qx.clone();
136        let y0 = self.qy.clone();
137        let z0 = self.qz.clone();
138        let w0 = self.qw.clone();
139
140        let x1 = rhs.qx.clone();
141        let y1 = rhs.qy.clone();
142        let z1 = rhs.qz.clone();
143        let w1 = rhs.qw.clone();
144
145        // Compute the product of the two quaternions, term by term
146        let qx = w0.clone() * x1.clone() + x0.clone() * w1.clone() + y0.clone() * z1.clone()
147            - z0.clone() * y1.clone();
148        let qy = w0.clone() * y1.clone() - x0.clone() * z1.clone()
149            + y0.clone() * w1.clone()
150            + z0.clone() * x1.clone();
151        let qz = w0.clone() * z1.clone() + x0.clone() * y1.clone() - y0.clone() * x1.clone()
152            + z0.clone() * w1.clone();
153        let qw = w0 * w1 - x0 * x1 - y0 * y1 - z0 * z1;
154        // println!("q {}", self.to_vec());
155        SO3 { qx, qy, qz, qw }
156    }
157}
158
159impl<T: na::RealField> Mul for SO3<T> {
160    type Output = Self;
161
162    fn mul(self, rhs: Self) -> Self::Output {
163        self.compose(&rhs)
164    }
165}
166impl<T: na::RealField> Mul for &SO3<T> {
167    type Output = SO3<T>;
168
169    fn mul(self, rhs: Self) -> Self::Output {
170        self.compose(rhs)
171    }
172}
173
174impl<T: na::RealField> Mul<na::VectorView3<'_, T>> for &SO3<T> {
175    type Output = na::Vector3<T>;
176
177    fn mul(self, rhs: na::VectorView3<'_, T>) -> Self::Output {
178        let qv = SO3::from_xyzw(rhs[0].clone(), rhs[1].clone(), rhs[2].clone(), T::zero());
179        let inv = self.inverse();
180        let v_rot: SO3<T> = (self * &qv) * inv;
181        na::Vector3::new(v_rot.qx, v_rot.qy, v_rot.qz)
182    }
183}
184
/// Manifold descriptor for rotations stored as `[x, y, z, w]` quaternions, whose local
/// (tangent) parameterization is a 3-element rotation vector.
#[derive(Clone, Debug)]
pub struct QuaternionManifold;
187impl<T: na::RealField> AutoDiffManifold<T> for QuaternionManifold {
188    fn plus(
189        &self,
190        x: nalgebra::DVectorView<T>,
191        delta: nalgebra::DVectorView<T>,
192    ) -> nalgebra::DVector<T> {
193        let d: SO3<T> = SO3::exp(delta);
194        let x_s03: SO3<T> = SO3::from_xyzw(x[0].clone(), x[1].clone(), x[2].clone(), x[3].clone());
195        let x_plus = x_s03 * d;
196        na::dvector![
197            x_plus.qx.clone(),
198            x_plus.qy.clone(),
199            x_plus.qz.clone(),
200            x_plus.qw.clone(),
201        ]
202    }
203
204    fn minus(
205        &self,
206        y: nalgebra::DVectorView<T>,
207        x: nalgebra::DVectorView<T>,
208    ) -> nalgebra::DVector<T> {
209        let y_so3 = SO3::from_vec(y);
210        let x_so3_inv = SO3::from_vec(x).inverse();
211        let x_inv_y_log = (x_so3_inv * y_so3).log();
212        na::dvector![
213            x_inv_y_log[0].clone(),
214            x_inv_y_log[1].clone(),
215            x_inv_y_log[2].clone()
216        ]
217    }
218}
219impl Manifold for QuaternionManifold {
220    fn tangent_size(&self) -> NonZero<usize> {
221        NonZero::new(3).unwrap()
222    }
223}