// viam_rust_utils/spatialmath/utils.rs

1use float_cmp::{ApproxEq, F64Margin};
2use nalgebra::{Quaternion, UnitQuaternion, UnitVector3, Vector3};
3
/// Tolerance shared by the orientation-vector conversions in this file.
///
/// It is compared against how far a z component is from 1 (to decide whether a
/// direction is aligned with the z axis) and against computed angles (to decide
/// whether they should be treated as zero).
const ANGLE_ACCEPTANCE: f64 = 0.0001;
5
/// A rotation described by roll, pitch and yaw angles.
///
/// The conversions from quaternions in this file produce the angles in radians,
/// with yaw about Z, pitch about Y and roll about X.
///
/// `#[repr(C)]` pins the field layout to C rules, so the field order below is
/// part of the type's contract.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct EulerAngles {
    /// Rotation about the X axis.
    pub roll: f64,
    /// Rotation about the Y axis.
    pub pitch: f64,
    /// Rotation about the Z axis.
    pub yaw: f64,
}
13
14impl EulerAngles {
15    pub fn new(roll: f64, pitch: f64, yaw: f64) -> Self {
16        EulerAngles { roll, pitch, yaw }
17    }
18
19    /// Converts a quaternion into euler angles (in radians). The euler angles are
20    /// represented according to the Tait-Bryan formalism and applied
21    /// in the Z-Y'-X" order (where Z -> yaw, Y -> pitch, X -> roll).
22    pub fn from_quaternion(quat: &Quaternion<f64>) -> Self {
23        // get a normalized version of the quaternion
24        let norm_quat = quat.normalize();
25
26        // calculate yaw
27        let yaw_sin_pitch_cos: f64 =
28            2.0 * ((norm_quat.w * norm_quat.k) + (norm_quat.i * norm_quat.j));
29        let yaw_cos_pitch_cos: f64 =
30            1.0 - 2.0 * ((norm_quat.j * norm_quat.j) + (norm_quat.k * norm_quat.k));
31        let yaw = yaw_sin_pitch_cos.atan2(yaw_cos_pitch_cos);
32
33        // calculate pitch and roll
34        let pitch_sin: f64 = 2.0 * ((norm_quat.w * norm_quat.j) - (norm_quat.k * norm_quat.i));
35        let pitch: f64;
36        let roll: f64;
37        // for a pitch that is π / 2, we experience gimbal lock
38        // and must calculate roll based on the real rotation and yaw
39        if pitch_sin.abs() >= 1.0 {
40            pitch = (std::f64::consts::PI / 2.0).copysign(pitch_sin);
41            roll = (2.0 * norm_quat.i.atan2(norm_quat.w)) + yaw.copysign(pitch_sin);
42        } else {
43            pitch = pitch_sin.asin();
44            let roll_sin_pitch_cos =
45                2.0 * ((norm_quat.w * norm_quat.i) + (norm_quat.j * norm_quat.k));
46            let roll_cos_pitch_cos =
47                1.0 - 2.0 * ((norm_quat.i * norm_quat.i) + (norm_quat.j * norm_quat.j));
48            roll = roll_sin_pitch_cos.atan2(roll_cos_pitch_cos);
49        }
50
51        EulerAngles { roll, pitch, yaw }
52    }
53}
54
55impl From<Quaternion<f64>> for EulerAngles {
56    fn from(quat: Quaternion<f64>) -> Self {
57        // get a normalized version of the quaternion
58        let norm_quat = quat.normalize();
59
60        // calculate yaw
61        let yaw_sin_pitch_cos: f64 =
62            2.0 * ((norm_quat.w * norm_quat.k) + (norm_quat.i * norm_quat.j));
63        let yaw_cos_pitch_cos: f64 =
64            1.0 - 2.0 * ((norm_quat.j * norm_quat.j) + (norm_quat.k * norm_quat.k));
65        let yaw = yaw_sin_pitch_cos.atan2(yaw_cos_pitch_cos);
66
67        // calculate pitch and roll
68        let pitch_sin: f64 = 2.0 * ((norm_quat.w * norm_quat.j) - (norm_quat.k * norm_quat.i));
69        let pitch: f64;
70        let roll: f64;
71        // for a pitch that is π / 2, we experience gimbal lock
72        // and must calculate roll based on the real rotation and yaw
73        if pitch_sin.abs() >= 1.0 {
74            pitch = (std::f64::consts::PI / 2.0).copysign(pitch_sin);
75            roll = (2.0 * norm_quat.i.atan2(norm_quat.w)) + yaw.copysign(pitch_sin);
76        } else {
77            pitch = pitch_sin.asin();
78            let roll_sin_pitch_cos =
79                2.0 * ((norm_quat.w * norm_quat.i) + (norm_quat.j * norm_quat.k));
80            let roll_cos_pitch_cos =
81                1.0 - 2.0 * ((norm_quat.i * norm_quat.i) + (norm_quat.j * norm_quat.j));
82            roll = roll_sin_pitch_cos.atan2(roll_cos_pitch_cos);
83        }
84
85        Self { roll, pitch, yaw }
86    }
87}
88
/// A rotation described by an axis and an angle about that axis.
///
/// `#[repr(C)]` pins the field layout to C rules, so the field order below is
/// part of the type's contract.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct AxisAngle {
    /// The rotation axis. It is stored exactly as supplied; `AxisAngle::new`
    /// does not normalize it.
    pub axis: Vector3<f64>,
    /// The rotation angle about `axis`. The `TryFrom<Quaternion<f64>>`
    /// conversion fills this from `UnitQuaternion::angle`.
    pub theta: f64,
}
95
96impl AxisAngle {
97    pub fn new(x: f64, y: f64, z: f64, theta: f64) -> Self {
98        AxisAngle {
99            axis: Vector3::new(x, y, z),
100            theta,
101        }
102    }
103}
104
105impl TryFrom<Quaternion<f64>> for AxisAngle {
106    type Error = ();
107
108    fn try_from(quat: Quaternion<f64>) -> Result<Self, Self::Error> {
109        let unit_quat = UnitQuaternion::from_quaternion(quat);
110        let axis_opt = unit_quat.axis();
111        let angle = unit_quat.angle();
112        match axis_opt {
113            Some(value) => Ok(Self::new(value[0], value[1], value[2], angle)),
114            None => Err(()),
115        }
116    }
117}
118
/// A rotation described by a unit direction vector plus an angle `theta`.
///
/// `#[repr(C)]` pins the field layout to C rules, so the field order below is
/// part of the type's contract.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct OrientationVector {
    /// The unit-length direction component. `OrientationVector::new`
    /// normalizes its inputs to build this; the quaternion conversion fills it
    /// with the image of the z axis under the rotation.
    pub o_vector: UnitVector3<f64>,
    /// The angular component paired with `o_vector` (halved and fed to
    /// `sin`/`cos` by `to_quaternion`, i.e. expressed in radians).
    pub theta: f64,
}
125
126impl OrientationVector {
127    pub fn new(o_x: f64, o_y: f64, o_z: f64, theta: f64) -> Self {
128        let o_vector = UnitVector3::new_normalize(Vector3::new(o_x, o_y, o_z));
129        OrientationVector { o_vector, theta }
130    }
131
132    pub fn to_quaternion(&self) -> Quaternion<f64> {
133        let lat = self.o_vector.z.acos();
134        let lon = match self.o_vector.z {
135            val if 1.0 - val > ANGLE_ACCEPTANCE => self.o_vector.y.atan2(self.o_vector.x),
136            _ => 0.0,
137        };
138
139        // convert angles as euler angles (lon, lat, theta) to quaternion
140        // using the zyz rotational order
141        let s: [f64; 3] = [
142            (lon / 2.0).sin(),
143            (lat / 2.0).sin(),
144            (self.theta / 2.0).sin(),
145        ];
146
147        let c: [f64; 3] = [
148            (lon / 2.0).cos(),
149            (lat / 2.0).cos(),
150            (self.theta / 2.0).cos(),
151        ];
152
153        let real = c[0] * c[1] * c[2] - s[0] * c[1] * s[2];
154        let i = c[0] * s[1] * s[2] - s[0] * s[1] * c[2];
155        let j = c[0] * s[1] * c[2] + s[0] * s[1] * s[2];
156        let k = s[0] * c[1] * c[2] + c[0] * c[1] * s[2];
157
158        Quaternion::new(real, i, j, k)
159    }
160}
161
162impl ApproxEq for OrientationVector {
163    type Margin = F64Margin;
164
165    fn approx_eq<M: Into<Self::Margin>>(self, other: Self, margin: M) -> bool {
166        let margin = margin.into();
167        let vec_diff = self.o_vector.into_inner() - other.o_vector.into_inner();
168        vec_diff.norm_squared().approx_eq(0.0, margin) && self.theta.approx_eq(other.theta, margin)
169    }
170}
171
impl From<Quaternion<f64>> for OrientationVector {
    /// Converts a quaternion into an orientation vector.
    ///
    /// The direction (`o_vector`) is the normalized image of the z axis under
    /// the quaternion's sandwich product, and `theta` is derived from where the
    /// negative x axis ends up relative to it.
    ///
    /// NOTE(review): `quat` is not normalized here, and the pole test below
    /// reads `new_z.k` directly — this looks like it assumes a unit quaternion;
    /// confirm against callers.
    fn from(quat: Quaternion<f64>) -> Self {
        // Pure quaternions standing for the -x axis and the +z axis.
        let x_quat = Quaternion::new(0.0, -1.0, 0.0, 0.0);
        let z_quat = Quaternion::new(0.0, 0.0, 0.0, 1.0);

        // Sandwich products q * p * conj(q): the images of the two axes.
        let conj = quat.conjugate();
        let new_x = (quat * x_quat) * conj;
        let new_z = (quat * z_quat) * conj;

        let o_vector = UnitVector3::new_normalize(new_z.imag());
        // theta depends on whether the transformed z axis is (anti)parallel to
        // the original z axis, i.e. whether |new_z.k| is within
        // ANGLE_ACCEPTANCE of 1.
        let theta = match 1.0 - new_z.k.abs() > ANGLE_ACCEPTANCE {
            // General case: the transformed z axis is away from both poles.
            true => {
                let new_z_imag = new_z.imag();
                let new_x_imag = new_x.imag();
                let z_imag_axis = Vector3::z_axis();

                // Normals of the plane spanned by (new_z, new_x) and of the
                // plane spanned by (new_z, original z axis).
                let normal_1 = new_z_imag.cross(&new_x_imag);
                let normal_2 = new_z_imag.cross(&z_imag_axis);
                let cos_theta_cand = normal_1.dot(&normal_2) / (normal_1.norm() * normal_2.norm());
                // Clamp into [-1, 1] so rounding error cannot push acos out of
                // its domain.
                let cos_theta = match cos_theta_cand {
                    val if val < -1.0 => -1.0,
                    val if val > 1.0 => 1.0,
                    _ => cos_theta_cand,
                };

                // acos only yields the magnitude of the angle between the two
                // planes; the sign is resolved below.
                match cos_theta.acos() {
                    val if val > ANGLE_ACCEPTANCE => {
                        // Rotate the original z axis about new_z by -val and
                        // build the normal of the plane it then spans with
                        // new_z.
                        let new_z_imag_unit = UnitVector3::new_normalize(new_z_imag);
                        let rot_quat_unit =
                            UnitQuaternion::from_axis_angle(&new_z_imag_unit, -1.0 * val);
                        let rot_quat = rot_quat_unit.quaternion();
                        let z_axis_quat = Quaternion::new(0.0, 0.0, 0.0, 1.0);
                        let test_z = (rot_quat * z_axis_quat) * rot_quat.conjugate();
                        let test_z_imag = test_z.imag();

                        let normal_3 = new_z_imag.cross(&test_z_imag);
                        let cos_test =
                            normal_1.dot(&normal_3) / (normal_3.norm() * normal_1.norm());
                        // If that normal lines up with normal_1 (cosine within
                        // ANGLE_ACCEPTANCE^2 of 1) the angle is negative;
                        // otherwise it is positive.
                        match cos_test {
                            val2 if (1.0 - val2) < (ANGLE_ACCEPTANCE * ANGLE_ACCEPTANCE) => {
                                -1.0 * val
                            }
                            _ => val,
                        }
                    }
                    // Angle too small to distinguish from zero.
                    _ => 0.0,
                }
            }
            // Pole case: the transformed z axis is (anti)parallel to the
            // original z axis, so theta is read from the heading of the
            // transformed x axis in the xy plane, with the x component's sign
            // flipped when pointing along +z.
            _ => match new_z.k {
                val if val < 0.0 => -1.0 * new_x.j.atan2(new_x.i),
                _ => -1.0 * new_x.j.atan2(new_x.i * -1.0),
            },
        };
        Self { o_vector, theta }
    }
}
229
230pub fn rotate_vector_by_quaternion(quat: &Quaternion<f64>, vector: &Vector3<f64>) -> Vector3<f64> {
231    let quat_vec = Vector3::new(quat.i, quat.j, quat.k);
232    let quat_real = quat.w;
233    (2.0 * quat_vec.dot(vector) * quat_vec)
234        + ((quat_real * quat_real - quat_vec.norm_squared()) * vector)
235        + (2.0 * quat_real) * quat_vec.cross(vector)
236}
237
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use float_cmp::assert_approx_eq;
    use nalgebra::{Quaternion, Vector3};

    use super::{rotate_vector_by_quaternion, EulerAngles, OrientationVector};

    /// Squared Euclidean distance between the coefficients of two quaternions.
    fn squared_quaternion_distance(lhs: &Quaternion<f64>, rhs: &Quaternion<f64>) -> f64 {
        (lhs.coords - rhs.coords).norm_squared()
    }

    /// Squared Euclidean distance between two vectors.
    fn squared_vector_distance(from: &Vector3<f64>, to: &Vector3<f64>) -> f64 {
        (to - from).norm_squared()
    }

    /// Quaternion / orientation-vector pairs that must match under the default
    /// comparison margin, in both conversion directions.
    fn exact_pairs() -> [(Quaternion<f64>, OrientationVector); 7] {
        [
            (
                Quaternion::new(0.7071067811865476, 0.7071067811865476, 0.0, 0.0),
                OrientationVector::new(0.0, -1.0, 0.0, 1.5707963267948966),
            ),
            (
                Quaternion::new(0.7071067811865476, -0.7071067811865476, 0.0, 0.0),
                OrientationVector::new(0.0, 1.0, 0.0, -1.5707963267948966),
            ),
            (
                Quaternion::new(0.96, 0.0, -0.28, 0.0),
                OrientationVector::new(-0.5376, 0.0, 0.8432, -1.0 * PI),
            ),
            (
                Quaternion::new(0.96, 0.0, 0.0, -0.28),
                OrientationVector::new(0.0, 0.0, 1.0, -0.5675882184166557),
            ),
            (
                Quaternion::new(0.96, -0.28, 0.0, 0.0),
                OrientationVector::new(0.0, 0.5376, 0.8432, -1.5707963267948966),
            ),
            (
                Quaternion::new(0.96, 0.28, 0.0, 0.0),
                OrientationVector::new(0.0, -0.5376, 0.8432, 1.5707963267948966),
            ),
            (
                Quaternion::new(0.5, -0.5, -0.5, -0.5),
                OrientationVector::new(0.0, 1.0, 0.0, -1.0 * PI),
            ),
        ]
    }

    /// A pair whose quaternion -> orientation-vector direction is only checked
    /// with a loosened epsilon.
    fn approximate_pair() -> (Quaternion<f64>, OrientationVector) {
        (
            Quaternion::new(
                0.816632212270443,
                -0.17555966025413142,
                0.39198397193979817,
                0.3855375485164001,
            ),
            OrientationVector::new(
                0.5048437942940054,
                0.5889844266763397,
                0.631054742867507,
                0.02,
            ),
        )
    }

    #[test]
    fn quaternion_to_orientation_vector_works() {
        for (quat, expected_ov) in exact_pairs().iter() {
            let calc_ov = OrientationVector::from(*quat);
            assert_approx_eq!(OrientationVector, calc_ov, *expected_ov);
        }

        let (quat, expected_ov) = approximate_pair();
        let calc_ov = OrientationVector::from(quat);
        assert_approx_eq!(OrientationVector, calc_ov, expected_ov, epsilon = 0.0001);
    }

    #[test]
    fn orientation_vector_to_quaternion_works() {
        let loose = approximate_pair();
        for (expected_quat, ov) in exact_pairs().iter().chain(std::iter::once(&loose)) {
            let calc_quat = ov.to_quaternion();
            let diff = squared_quaternion_distance(expected_quat, &calc_quat);
            assert_approx_eq!(f64, diff, 0.0);
        }
    }

    #[test]
    fn euler_angles_from_quaternion_works() {
        // Each entry: (quaternion, expected roll, expected pitch, expected yaw).
        let cases = [
            (
                Quaternion::new(
                    0.2705980500730985,
                    -0.6532814824381882,
                    0.27059805007309856,
                    0.6532814824381883,
                ),
                PI / 4.0,
                PI / 2.0,
                PI,
            ),
            (
                Quaternion::new(
                    0.4619397662556435,
                    -0.19134171618254486,
                    0.4619397662556434,
                    0.7325378163287418,
                ),
                PI / 4.0,
                PI / 4.0,
                3.0 * PI / 4.0,
            ),
        ];

        for (quat, roll, pitch, yaw) in cases.iter() {
            let angles = EulerAngles::from(*quat);
            assert_approx_eq!(f64, angles.pitch, *pitch);
            assert_approx_eq!(f64, angles.yaw, *yaw);
            assert_approx_eq!(f64, angles.roll, *roll);
        }
    }

    #[test]
    fn rotation_by_quaternion_works() {
        // Each entry: (rotation, input vector, expected rotated vector).
        let cases = [
            // rotation of (0,0,1) by 90 degrees about (0,1,0)
            (
                Quaternion::new(0.7071068, 0.0, 0.7071068, 0.0),
                Vector3::new(0.0, 0.0, 1.0),
                Vector3::new(1.0, 0.0, 0.0),
            ),
            // rotation of (4.5, 1.3, 2.0) by 175 degrees about (2,3,4)
            (
                Quaternion::new(0.0436194, 0.3710372, 0.5565558, 0.7420744),
                Vector3::new(4.5, 1.3, 2.0),
                Vector3::new(-1.593, 3.247, 3.586),
            ),
        ];

        for (quat, input, expected) in cases.iter() {
            let rotated = rotate_vector_by_quaternion(quat, input);
            let diff = squared_vector_distance(expected, &rotated);
            assert_approx_eq!(f64, diff, 0.0, epsilon = 0.0001);
        }
    }
}