1extern crate alloc;
2
3use alloc::vec::Vec;
4use core::marker::PhantomData;
5
6use sophus_autodiff::{
7 manifold::IsManifold,
8 params::{
9 HasParams,
10 IsParamsImpl,
11 },
12 points::example_points,
13};
14
15use crate::prelude::*;
16
17#[derive(Clone, Debug)]
22pub struct Sl2c<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize> {
23 params: S::Vector<8>,
24}
25
26impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
27 Sl2c<S, BATCH, DM, DN>
28{
29 #[inline]
31 #[must_use]
32 pub fn from_params(params: S::Vector<8>) -> Self {
33 Self { params }
34 }
35
36 pub fn zero() -> Self {
38 Self::from_params(S::Vector::<8>::zeros())
39 }
40
41 pub fn one() -> Self {
43 Self::from_params(S::Vector::<8>::from_f64_array([
44 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
45 ]))
46 }
47
48 pub fn params(&self) -> &S::Vector<8> {
50 &self.params
51 }
52
53 pub fn params_mut(&mut self) -> &mut S::Vector<8> {
55 &mut self.params
56 }
57
58 pub fn mult(&self, rhs: Self) -> Self {
60 Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::mult(&self.params, rhs.params))
61 }
62
63 pub fn add(&self, rhs: Self) -> Self {
65 Self::from_params(self.params + rhs.params)
66 }
67
68 pub fn conjugate(&self) -> Self {
70 Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::conjugate(&self.params))
71 }
72
73 pub fn transpose(&self) -> Self {
75 Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::transpose(&self.params))
76 }
77
78 pub fn conjugate_transpose(&self) -> Self {
80 Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::conjugate_transpose(
81 &self.params,
82 ))
83 }
84
85 pub fn inverse(&self) -> Self {
87 Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::inverse(&self.params))
88 }
89
90 pub fn norm(&self) -> S {
92 Sl2cImpl::<S, BATCH, DM, DN>::norm(&self.params)
93 }
94
95 pub fn squared_norm(&self) -> S {
97 Sl2cImpl::<S, BATCH, DM, DN>::squared_norm(&self.params)
98 }
99
100 pub fn scale(&self, s: S) -> Self {
102 Self::from_params(self.params.scaled(s))
103 }
104}
105
106impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
107 core::ops::Add for Sl2c<S, BATCH, DM, DN>
108{
109 type Output = Self;
110
111 fn add(self, rhs: Self) -> Self::Output {
112 Sl2c::add(&self, rhs)
113 }
114}
115
116impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
117 core::ops::Mul for Sl2c<S, BATCH, DM, DN>
118{
119 type Output = Self;
120
121 fn mul(self, rhs: Self) -> Self::Output {
122 self.mult(rhs)
123 }
124}
125
126impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
127 IsParamsImpl<S, 8, BATCH, DM, DN> for Sl2c<S, BATCH, DM, DN>
128{
129 fn are_params_valid(_params: S::Vector<8>) -> S::Mask {
130 S::Mask::all_true()
131 }
132
133 fn params_examples() -> Vec<S::Vector<8>> {
134 example_points::<S, 8, BATCH, DM, DN>()
135 }
136
137 fn invalid_params_examples() -> Vec<S::Vector<8>> {
138 Vec::new()
139 }
140}
141
142impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
143 HasParams<S, 8, BATCH, DM, DN> for Sl2c<S, BATCH, DM, DN>
144{
145 fn from_params(params: S::Vector<8>) -> Self {
146 Self::from_params(params)
147 }
148
149 fn set_params(&mut self, params: S::Vector<8>) {
150 self.params = params;
151 }
152
153 fn params(&self) -> &S::Vector<8> {
154 &self.params
155 }
156}
157
158impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
159 IsManifold<S, 8, 8, BATCH, DM, DN> for Sl2c<S, BATCH, DM, DN>
160{
161 fn oplus(&self, tangent: &S::Vector<8>) -> Self {
162 Self::from_params(*self.params() + *tangent)
163 }
164
165 fn ominus(&self, rhs: &Self) -> S::Vector<8> {
166 *self.params() - *rhs.params()
167 }
168}
169
170pub type Sl2cF64 = Sl2c<f64, 1, 0, 0>;
172
173#[derive(Clone, Copy, Debug)]
175pub struct Sl2cImpl<
176 S: IsScalar<BATCH, DM, DN>,
177 const BATCH: usize,
178 const DM: usize,
179 const DN: usize,
180> {
181 phantom: PhantomData<S>,
182}
183
184impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
185 Sl2cImpl<S, BATCH, DM, DN>
186{
187 pub fn zero() -> S::Vector<8> {
189 S::Vector::<8>::zeros()
190 }
191
192 pub fn one() -> S::Vector<8> {
194 S::Vector::<8>::from_f64_array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0])
195 }
196
197 pub fn mult(lhs: &S::Vector<8>, rhs: S::Vector<8>) -> S::Vector<8> {
199 let lhs_re = S::Matrix::<2, 2>::from_array2([
200 [lhs.elem(0), lhs.elem(2)],
201 [lhs.elem(4), lhs.elem(6)],
202 ]);
203 let lhs_im = S::Matrix::<2, 2>::from_array2([
204 [lhs.elem(1), lhs.elem(3)],
205 [lhs.elem(5), lhs.elem(7)],
206 ]);
207
208 let rhs_re = S::Matrix::<2, 2>::from_array2([
209 [rhs.elem(0), rhs.elem(2)],
210 [rhs.elem(4), rhs.elem(6)],
211 ]);
212 let rhs_im = S::Matrix::<2, 2>::from_array2([
213 [rhs.elem(1), rhs.elem(3)],
214 [rhs.elem(5), rhs.elem(7)],
215 ]);
216
217 let re = lhs_re.mat_mul(rhs_re) - lhs_im.mat_mul(rhs_im);
218 let im = lhs_re.mat_mul(rhs_im) + lhs_im.mat_mul(rhs_re);
219
220 S::Vector::<8>::from_array([
221 re.elem([0, 0]),
222 im.elem([0, 0]),
223 re.elem([0, 1]),
224 im.elem([0, 1]),
225 re.elem([1, 0]),
226 im.elem([1, 0]),
227 re.elem([1, 1]),
228 im.elem([1, 1]),
229 ])
230 }
231
232 pub fn add(a: &S::Vector<8>, b: S::Vector<8>) -> S::Vector<8> {
234 *a + b
235 }
236
237 pub fn conjugate(m: &S::Vector<8>) -> S::Vector<8> {
239 S::Vector::<8>::from_array([
240 m.elem(0),
241 -m.elem(1),
242 m.elem(2),
243 -m.elem(3),
244 m.elem(4),
245 -m.elem(5),
246 m.elem(6),
247 -m.elem(7),
248 ])
249 }
250
251 pub fn transpose(m: &S::Vector<8>) -> S::Vector<8> {
253 S::Vector::<8>::from_array([
254 m.elem(0),
255 m.elem(1),
256 m.elem(4),
257 m.elem(5),
258 m.elem(2),
259 m.elem(3),
260 m.elem(6),
261 m.elem(7),
262 ])
263 }
264
265 pub fn conjugate_transpose(m: &S::Vector<8>) -> S::Vector<8> {
267 Self::conjugate(&Self::transpose(m))
268 }
269
270 fn complex_mult(a_re: S, a_im: S, b_re: S, b_im: S) -> (S, S) {
271 (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re)
272 }
273
274 fn determinant(m: &S::Vector<8>) -> (S, S) {
275 let a_re = m.elem(0);
276 let a_im = m.elem(1);
277 let b_re = m.elem(2);
278 let b_im = m.elem(3);
279 let c_re = m.elem(4);
280 let c_im = m.elem(5);
281 let d_re = m.elem(6);
282 let d_im = m.elem(7);
283
284 let (ad_re, ad_im) = Self::complex_mult(a_re, a_im, d_re, d_im);
285 let (bc_re, bc_im) = Self::complex_mult(b_re, b_im, c_re, c_im);
286 (ad_re - bc_re, ad_im - bc_im)
287 }
288
289 pub fn inverse(m: &S::Vector<8>) -> S::Vector<8> {
291 let (det_re, det_im) = Self::determinant(m);
292 let norm_sq = det_re * det_re + det_im * det_im;
293 let inv_det_re = det_re / norm_sq;
294 let inv_det_im = -det_im / norm_sq;
295
296 let a_re = m.elem(0);
297 let a_im = m.elem(1);
298 let b_re = m.elem(2);
299 let b_im = m.elem(3);
300 let c_re = m.elem(4);
301 let c_im = m.elem(5);
302 let d_re = m.elem(6);
303 let d_im = m.elem(7);
304
305 let (e0_re, e0_im) = Self::complex_mult(d_re, d_im, inv_det_re, inv_det_im);
306 let (e1_re, e1_im) = Self::complex_mult(-b_re, -b_im, inv_det_re, inv_det_im);
307 let (e2_re, e2_im) = Self::complex_mult(-c_re, -c_im, inv_det_re, inv_det_im);
308 let (e3_re, e3_im) = Self::complex_mult(a_re, a_im, inv_det_re, inv_det_im);
309
310 S::Vector::<8>::from_array([e0_re, e0_im, e1_re, e1_im, e2_re, e2_im, e3_re, e3_im])
311 }
312
313 pub fn norm(m: &S::Vector<8>) -> S {
315 m.norm()
316 }
317
318 pub fn squared_norm(m: &S::Vector<8>) -> S {
320 m.squared_norm()
321 }
322}
323
#[cfg(test)]
mod tests {
    use crate::{
        ComplexF64,
        Sl2cImpl,
    };

    /// `Sl2cImpl::complex_mult` must agree exactly with `ComplexF64` multiplication.
    #[test]
    fn test_complex_mult_helper() {
        let lhs = ComplexF64::from_real_imag(0.5, 0.7);
        let rhs = ComplexF64::from_real_imag(1.0, -0.2);

        let expected = lhs * rhs;
        let (re, im) = Sl2cImpl::<f64, 1, 0, 0>::complex_mult(
            lhs.real(),
            lhs.imag(),
            rhs.real(),
            rhs.imag(),
        );

        assert_eq!(expected.real(), re);
        assert_eq!(expected.imag(), im);
    }
}
343}