sophus_lie/
quaternion.rs

1extern crate alloc;
2
3use alloc::vec::Vec;
4use core::marker::PhantomData;
5
6use sophus_autodiff::{
7    linalg::cross,
8    manifold::IsManifold,
9    params::{
10        HasParams,
11        IsParamsImpl,
12    },
13    points::example_points,
14};
15
16use crate::prelude::*;
17
18/// Quaternion represented as `(r, x, y, z)`.
19#[derive(Clone, Debug, Copy)]
20pub struct Quaternion<
21    S: IsScalar<BATCH, DM, DN>,
22    const BATCH: usize,
23    const DM: usize,
24    const DN: usize,
25> {
26    params: S::Vector<4>,
27}
28
29impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
30    Quaternion<S, BATCH, DM, DN>
31{
32    /// Creates a complex number from real scalar and imaginary 3-vector.
33    #[inline]
34    #[must_use]
35    pub fn from_real_imag(real: S, imag: S::Vector<3>) -> Self {
36        Self::from_params(S::Vector::<4>::from_array([
37            real,
38            imag.elem(0),
39            imag.elem(1),
40            imag.elem(2),
41        ]))
42    }
43
44    /// Creates a quaternion from its parameter vector `(r, x, y, z)`.
45    #[inline]
46    #[must_use]
47    pub fn from_params(params: S::Vector<4>) -> Self {
48        Self { params }
49    }
50
51    /// Returns the zero quaternion `(0,0,0,0)`.
52    pub fn zero() -> Self {
53        Self::from_params(S::Vector::<4>::zeros())
54    }
55
56    /// Returns the identity quaternion `(1,0,0,0)`.
57    pub fn one() -> Self {
58        Self::from_params(S::Vector::<4>::from_f64_array([1.0, 0.0, 0.0, 0.0]))
59    }
60
61    /// Access the underlying parameter vector.
62    pub fn params(&self) -> &S::Vector<4> {
63        &self.params
64    }
65
66    /// Mutable access to the parameter vector.
67    pub fn params_mut(&mut self) -> &mut S::Vector<4> {
68        &mut self.params
69    }
70
71    /// Returns the real component `r`.
72    pub fn real(&self) -> S {
73        self.params.elem(0)
74    }
75
76    /// Returns the imaginary component `(x,y,z)`.
77    pub fn imag(&self) -> S::Vector<3> {
78        self.params.get_fixed_subvec::<3>(1)
79    }
80
81    /// Quaternion multiplication.
82    pub fn mult(&self, rhs: Self) -> Self {
83        Self::from_params(QuaternionImpl::<S, BATCH, DM, DN>::mult(
84            &self.params,
85            rhs.params,
86        ))
87    }
88
89    /// Quaternion addition.
90    pub fn add(&self, rhs: Self) -> Self {
91        Self::from_params(self.params + rhs.params)
92    }
93
94    /// Conjugated quaternion.
95    pub fn conjugate(&self) -> Self {
96        Self::from_params(QuaternionImpl::<S, BATCH, DM, DN>::conjugate(&self.params))
97    }
98
99    /// Inverse quaternion.
100    pub fn inverse(&self) -> Self {
101        Self::from_params(QuaternionImpl::<S, BATCH, DM, DN>::inverse(&self.params))
102    }
103
104    /// Quaternion norm.
105    pub fn norm(&self) -> S {
106        QuaternionImpl::<S, BATCH, DM, DN>::norm(&self.params)
107    }
108
109    /// Quaternion squared norm.
110    pub fn squared_norm(&self) -> S {
111        QuaternionImpl::<S, BATCH, DM, DN>::squared_norm(&self.params)
112    }
113
114    /// Scale quaternion by scalar.
115    pub fn scale(&self, s: S) -> Self {
116        Self::from_params(self.params.scaled(s))
117    }
118}
119
120impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
121    core::ops::Add for Quaternion<S, BATCH, DM, DN>
122{
123    type Output = Self;
124
125    fn add(self, rhs: Self) -> Self::Output {
126        Quaternion::add(&self, rhs)
127    }
128}
129
130impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
131    core::ops::Mul for Quaternion<S, BATCH, DM, DN>
132{
133    type Output = Self;
134
135    fn mul(self, rhs: Self) -> Self::Output {
136        self.mult(rhs)
137    }
138}
139
140impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
141    IsParamsImpl<S, 4, BATCH, DM, DN> for Quaternion<S, BATCH, DM, DN>
142{
143    fn are_params_valid(_params: S::Vector<4>) -> S::Mask {
144        S::Mask::all_true()
145    }
146
147    fn params_examples() -> Vec<S::Vector<4>> {
148        example_points::<S, 4, BATCH, DM, DN>()
149    }
150
151    fn invalid_params_examples() -> Vec<S::Vector<4>> {
152        Vec::new()
153    }
154}
155
156impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
157    HasParams<S, 4, BATCH, DM, DN> for Quaternion<S, BATCH, DM, DN>
158{
159    fn from_params(params: S::Vector<4>) -> Self {
160        Self::from_params(params)
161    }
162
163    fn set_params(&mut self, params: S::Vector<4>) {
164        self.params = params;
165    }
166
167    fn params(&self) -> &S::Vector<4> {
168        &self.params
169    }
170}
171
172impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
173    IsManifold<S, 4, 4, BATCH, DM, DN> for Quaternion<S, BATCH, DM, DN>
174{
175    fn oplus(&self, tangent: &S::Vector<4>) -> Self {
176        Self::from_params(*self.params() + *tangent)
177    }
178
179    fn ominus(&self, rhs: &Self) -> S::Vector<4> {
180        *self.params() - *rhs.params()
181    }
182}
183
184/// Quaternion with `f64` scalar type.
185pub type QuaternionF64 = Quaternion<f64, 1, 0, 0>;
186
187/// Implementation utilities for [`Quaternion`].
188#[derive(Clone, Copy, Debug)]
189pub struct QuaternionImpl<
190    S: IsScalar<BATCH, DM, DN>,
191    const BATCH: usize,
192    const DM: usize,
193    const DN: usize,
194> {
195    phantom: PhantomData<S>,
196}
197
198impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
199    QuaternionImpl<S, BATCH, DM, DN>
200{
201    /// Returns the zero quaternion.
202    pub fn zero() -> S::Vector<4> {
203        S::Vector::<4>::zeros()
204    }
205
206    /// Returns the identity quaternion.
207    pub fn one() -> S::Vector<4> {
208        S::Vector::<4>::from_f64_array([1.0, 0.0, 0.0, 0.0])
209    }
210
211    /// Multiplies two quaternions.
212    pub fn mult(lhs: &S::Vector<4>, rhs: S::Vector<4>) -> S::Vector<4> {
213        let lhs_re = lhs.elem(0);
214        let rhs_re = rhs.elem(0);
215
216        let lhs_ivec = lhs.get_fixed_subvec::<3>(1);
217        let rhs_ivec = rhs.get_fixed_subvec::<3>(1);
218
219        let re = lhs_re * rhs_re - lhs_ivec.dot(rhs_ivec);
220        let ivec = rhs_ivec.scaled(lhs_re)
221            + lhs_ivec.scaled(rhs_re)
222            + cross::<S, BATCH, DM, DN>(lhs_ivec, rhs_ivec);
223
224        S::Vector::block_vec2(re.to_vec(), ivec)
225    }
226
227    /// Adds two quaternions component-wise.
228    pub fn add(a: &S::Vector<4>, b: S::Vector<4>) -> S::Vector<4> {
229        *a + b
230    }
231
232    /// Conjugates a quaternion.
233    pub fn conjugate(a: &S::Vector<4>) -> S::Vector<4> {
234        S::Vector::from_array([a.elem(0), -a.elem(1), -a.elem(2), -a.elem(3)])
235    }
236
237    /// Computes the inverse quaternion.
238    pub fn inverse(q: &S::Vector<4>) -> S::Vector<4> {
239        Self::conjugate(q).scaled(S::from_f64(1.0) / q.squared_norm())
240    }
241
242    /// Returns the quaternion norm.
243    pub fn norm(q: &S::Vector<4>) -> S {
244        q.norm()
245    }
246
247    /// Returns the squared quaternion norm.
248    pub fn squared_norm(q: &S::Vector<4>) -> S {
249        q.squared_norm()
250    }
251}