1extern crate alloc;
2
3use alloc::vec::Vec;
4use core::marker::PhantomData;
5
6use sophus_autodiff::{
7 linalg::cross,
8 manifold::IsManifold,
9 params::{
10 HasParams,
11 IsParamsImpl,
12 },
13 points::example_points,
14};
15
16use crate::prelude::*;
17
18#[derive(Clone, Debug, Copy)]
20pub struct Quaternion<
21 S: IsScalar<BATCH, DM, DN>,
22 const BATCH: usize,
23 const DM: usize,
24 const DN: usize,
25> {
26 params: S::Vector<4>,
27}
28
29impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
30 Quaternion<S, BATCH, DM, DN>
31{
32 #[inline]
34 #[must_use]
35 pub fn from_real_imag(real: S, imag: S::Vector<3>) -> Self {
36 Self::from_params(S::Vector::<4>::from_array([
37 real,
38 imag.elem(0),
39 imag.elem(1),
40 imag.elem(2),
41 ]))
42 }
43
44 #[inline]
46 #[must_use]
47 pub fn from_params(params: S::Vector<4>) -> Self {
48 Self { params }
49 }
50
51 pub fn zero() -> Self {
53 Self::from_params(S::Vector::<4>::zeros())
54 }
55
56 pub fn one() -> Self {
58 Self::from_params(S::Vector::<4>::from_f64_array([1.0, 0.0, 0.0, 0.0]))
59 }
60
61 pub fn params(&self) -> &S::Vector<4> {
63 &self.params
64 }
65
66 pub fn params_mut(&mut self) -> &mut S::Vector<4> {
68 &mut self.params
69 }
70
71 pub fn real(&self) -> S {
73 self.params.elem(0)
74 }
75
76 pub fn imag(&self) -> S::Vector<3> {
78 self.params.get_fixed_subvec::<3>(1)
79 }
80
81 pub fn mult(&self, rhs: Self) -> Self {
83 Self::from_params(QuaternionImpl::<S, BATCH, DM, DN>::mult(
84 &self.params,
85 rhs.params,
86 ))
87 }
88
89 pub fn add(&self, rhs: Self) -> Self {
91 Self::from_params(self.params + rhs.params)
92 }
93
94 pub fn conjugate(&self) -> Self {
96 Self::from_params(QuaternionImpl::<S, BATCH, DM, DN>::conjugate(&self.params))
97 }
98
99 pub fn inverse(&self) -> Self {
101 Self::from_params(QuaternionImpl::<S, BATCH, DM, DN>::inverse(&self.params))
102 }
103
104 pub fn norm(&self) -> S {
106 QuaternionImpl::<S, BATCH, DM, DN>::norm(&self.params)
107 }
108
109 pub fn squared_norm(&self) -> S {
111 QuaternionImpl::<S, BATCH, DM, DN>::squared_norm(&self.params)
112 }
113
114 pub fn scale(&self, s: S) -> Self {
116 Self::from_params(self.params.scaled(s))
117 }
118}
119
120impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
121 core::ops::Add for Quaternion<S, BATCH, DM, DN>
122{
123 type Output = Self;
124
125 fn add(self, rhs: Self) -> Self::Output {
126 Quaternion::add(&self, rhs)
127 }
128}
129
130impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
131 core::ops::Mul for Quaternion<S, BATCH, DM, DN>
132{
133 type Output = Self;
134
135 fn mul(self, rhs: Self) -> Self::Output {
136 self.mult(rhs)
137 }
138}
139
140impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
141 IsParamsImpl<S, 4, BATCH, DM, DN> for Quaternion<S, BATCH, DM, DN>
142{
143 fn are_params_valid(_params: S::Vector<4>) -> S::Mask {
144 S::Mask::all_true()
145 }
146
147 fn params_examples() -> Vec<S::Vector<4>> {
148 example_points::<S, 4, BATCH, DM, DN>()
149 }
150
151 fn invalid_params_examples() -> Vec<S::Vector<4>> {
152 Vec::new()
153 }
154}
155
156impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
157 HasParams<S, 4, BATCH, DM, DN> for Quaternion<S, BATCH, DM, DN>
158{
159 fn from_params(params: S::Vector<4>) -> Self {
160 Self::from_params(params)
161 }
162
163 fn set_params(&mut self, params: S::Vector<4>) {
164 self.params = params;
165 }
166
167 fn params(&self) -> &S::Vector<4> {
168 &self.params
169 }
170}
171
172impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
173 IsManifold<S, 4, 4, BATCH, DM, DN> for Quaternion<S, BATCH, DM, DN>
174{
175 fn oplus(&self, tangent: &S::Vector<4>) -> Self {
176 Self::from_params(*self.params() + *tangent)
177 }
178
179 fn ominus(&self, rhs: &Self) -> S::Vector<4> {
180 *self.params() - *rhs.params()
181 }
182}
183
184pub type QuaternionF64 = Quaternion<f64, 1, 0, 0>;
186
187#[derive(Clone, Copy, Debug)]
189pub struct QuaternionImpl<
190 S: IsScalar<BATCH, DM, DN>,
191 const BATCH: usize,
192 const DM: usize,
193 const DN: usize,
194> {
195 phantom: PhantomData<S>,
196}
197
198impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
199 QuaternionImpl<S, BATCH, DM, DN>
200{
201 pub fn zero() -> S::Vector<4> {
203 S::Vector::<4>::zeros()
204 }
205
206 pub fn one() -> S::Vector<4> {
208 S::Vector::<4>::from_f64_array([1.0, 0.0, 0.0, 0.0])
209 }
210
211 pub fn mult(lhs: &S::Vector<4>, rhs: S::Vector<4>) -> S::Vector<4> {
213 let lhs_re = lhs.elem(0);
214 let rhs_re = rhs.elem(0);
215
216 let lhs_ivec = lhs.get_fixed_subvec::<3>(1);
217 let rhs_ivec = rhs.get_fixed_subvec::<3>(1);
218
219 let re = lhs_re * rhs_re - lhs_ivec.dot(rhs_ivec);
220 let ivec = rhs_ivec.scaled(lhs_re)
221 + lhs_ivec.scaled(rhs_re)
222 + cross::<S, BATCH, DM, DN>(lhs_ivec, rhs_ivec);
223
224 S::Vector::block_vec2(re.to_vec(), ivec)
225 }
226
227 pub fn add(a: &S::Vector<4>, b: S::Vector<4>) -> S::Vector<4> {
229 *a + b
230 }
231
232 pub fn conjugate(a: &S::Vector<4>) -> S::Vector<4> {
234 S::Vector::from_array([a.elem(0), -a.elem(1), -a.elem(2), -a.elem(3)])
235 }
236
237 pub fn inverse(q: &S::Vector<4>) -> S::Vector<4> {
239 Self::conjugate(q).scaled(S::from_f64(1.0) / q.squared_norm())
240 }
241
242 pub fn norm(q: &S::Vector<4>) -> S {
244 q.norm()
245 }
246
247 pub fn squared_norm(q: &S::Vector<4>) -> S {
249 q.squared_norm()
250 }
251}