sophus_lie/
sl2c.rs

1extern crate alloc;
2
3use alloc::vec::Vec;
4use core::marker::PhantomData;
5
6use sophus_autodiff::{
7    manifold::IsManifold,
8    params::{
9        HasParams,
10        IsParamsImpl,
11    },
12    points::example_points,
13};
14
15use crate::prelude::*;
16
17/// SL(2,ℂ) - Complex linear group of rank 2.
18///
19/// This is the group of complex 2x2 matrices with determinant 1. Here it is simply
20/// stored as a list of 8 scalars: `(re00, im00, re01, im01, re10, im10, re11, im11)`.
21#[derive(Clone, Debug)]
22pub struct Sl2c<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize> {
23    params: S::Vector<8>,
24}
25
26impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
27    Sl2c<S, BATCH, DM, DN>
28{
29    /// Creates a matrix from its parameter vector.
30    #[inline]
31    #[must_use]
32    pub fn from_params(params: S::Vector<8>) -> Self {
33        Self { params }
34    }
35
36    /// Returns the zero matrix.
37    pub fn zero() -> Self {
38        Self::from_params(S::Vector::<8>::zeros())
39    }
40
41    /// Returns the identity matrix.
42    pub fn one() -> Self {
43        Self::from_params(S::Vector::<8>::from_f64_array([
44            1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0,
45        ]))
46    }
47
48    /// Access the underlying parameter vector.
49    pub fn params(&self) -> &S::Vector<8> {
50        &self.params
51    }
52
53    /// Mutable access to the parameter vector.
54    pub fn params_mut(&mut self) -> &mut S::Vector<8> {
55        &mut self.params
56    }
57
58    /// Matrix multiplication.
59    pub fn mult(&self, rhs: Self) -> Self {
60        Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::mult(&self.params, rhs.params))
61    }
62
63    /// Component-wise addition.
64    pub fn add(&self, rhs: Self) -> Self {
65        Self::from_params(self.params + rhs.params)
66    }
67
68    /// Conjugated matrix.
69    pub fn conjugate(&self) -> Self {
70        Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::conjugate(&self.params))
71    }
72
73    /// Transposed matrix.
74    pub fn transpose(&self) -> Self {
75        Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::transpose(&self.params))
76    }
77
78    /// Conjugate transpose.
79    pub fn conjugate_transpose(&self) -> Self {
80        Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::conjugate_transpose(
81            &self.params,
82        ))
83    }
84
85    /// Inverse matrix.
86    pub fn inverse(&self) -> Self {
87        Self::from_params(Sl2cImpl::<S, BATCH, DM, DN>::inverse(&self.params))
88    }
89
90    /// Frobenius norm of the matrix.
91    pub fn norm(&self) -> S {
92        Sl2cImpl::<S, BATCH, DM, DN>::norm(&self.params)
93    }
94
95    /// Squared Frobenius norm of the matrix.
96    pub fn squared_norm(&self) -> S {
97        Sl2cImpl::<S, BATCH, DM, DN>::squared_norm(&self.params)
98    }
99
100    /// Scale the matrix by a scalar.
101    pub fn scale(&self, s: S) -> Self {
102        Self::from_params(self.params.scaled(s))
103    }
104}
105
106impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
107    core::ops::Add for Sl2c<S, BATCH, DM, DN>
108{
109    type Output = Self;
110
111    fn add(self, rhs: Self) -> Self::Output {
112        Sl2c::add(&self, rhs)
113    }
114}
115
116impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
117    core::ops::Mul for Sl2c<S, BATCH, DM, DN>
118{
119    type Output = Self;
120
121    fn mul(self, rhs: Self) -> Self::Output {
122        self.mult(rhs)
123    }
124}
125
126impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
127    IsParamsImpl<S, 8, BATCH, DM, DN> for Sl2c<S, BATCH, DM, DN>
128{
129    fn are_params_valid(_params: S::Vector<8>) -> S::Mask {
130        S::Mask::all_true()
131    }
132
133    fn params_examples() -> Vec<S::Vector<8>> {
134        example_points::<S, 8, BATCH, DM, DN>()
135    }
136
137    fn invalid_params_examples() -> Vec<S::Vector<8>> {
138        Vec::new()
139    }
140}
141
142impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
143    HasParams<S, 8, BATCH, DM, DN> for Sl2c<S, BATCH, DM, DN>
144{
145    fn from_params(params: S::Vector<8>) -> Self {
146        Self::from_params(params)
147    }
148
149    fn set_params(&mut self, params: S::Vector<8>) {
150        self.params = params;
151    }
152
153    fn params(&self) -> &S::Vector<8> {
154        &self.params
155    }
156}
157
158impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
159    IsManifold<S, 8, 8, BATCH, DM, DN> for Sl2c<S, BATCH, DM, DN>
160{
161    fn oplus(&self, tangent: &S::Vector<8>) -> Self {
162        Self::from_params(*self.params() + *tangent)
163    }
164
165    fn ominus(&self, rhs: &Self) -> S::Vector<8> {
166        *self.params() - *rhs.params()
167    }
168}
169
170/// Matrix with `f64` scalar type.
171pub type Sl2cF64 = Sl2c<f64, 1, 0, 0>;
172
173/// Implementation utilities for [`Sl2c`].
174#[derive(Clone, Copy, Debug)]
175pub struct Sl2cImpl<
176    S: IsScalar<BATCH, DM, DN>,
177    const BATCH: usize,
178    const DM: usize,
179    const DN: usize,
180> {
181    phantom: PhantomData<S>,
182}
183
184impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
185    Sl2cImpl<S, BATCH, DM, DN>
186{
187    /// Returns the zero matrix.
188    pub fn zero() -> S::Vector<8> {
189        S::Vector::<8>::zeros()
190    }
191
192    /// Returns the identity matrix.
193    pub fn one() -> S::Vector<8> {
194        S::Vector::<8>::from_f64_array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0])
195    }
196
197    /// Multiplies two matrices.
198    pub fn mult(lhs: &S::Vector<8>, rhs: S::Vector<8>) -> S::Vector<8> {
199        let lhs_re = S::Matrix::<2, 2>::from_array2([
200            [lhs.elem(0), lhs.elem(2)],
201            [lhs.elem(4), lhs.elem(6)],
202        ]);
203        let lhs_im = S::Matrix::<2, 2>::from_array2([
204            [lhs.elem(1), lhs.elem(3)],
205            [lhs.elem(5), lhs.elem(7)],
206        ]);
207
208        let rhs_re = S::Matrix::<2, 2>::from_array2([
209            [rhs.elem(0), rhs.elem(2)],
210            [rhs.elem(4), rhs.elem(6)],
211        ]);
212        let rhs_im = S::Matrix::<2, 2>::from_array2([
213            [rhs.elem(1), rhs.elem(3)],
214            [rhs.elem(5), rhs.elem(7)],
215        ]);
216
217        let re = lhs_re.mat_mul(rhs_re) - lhs_im.mat_mul(rhs_im);
218        let im = lhs_re.mat_mul(rhs_im) + lhs_im.mat_mul(rhs_re);
219
220        S::Vector::<8>::from_array([
221            re.elem([0, 0]),
222            im.elem([0, 0]),
223            re.elem([0, 1]),
224            im.elem([0, 1]),
225            re.elem([1, 0]),
226            im.elem([1, 0]),
227            re.elem([1, 1]),
228            im.elem([1, 1]),
229        ])
230    }
231
232    /// Adds two matrices component-wise.
233    pub fn add(a: &S::Vector<8>, b: S::Vector<8>) -> S::Vector<8> {
234        *a + b
235    }
236
237    /// Conjugates a matrix.
238    pub fn conjugate(m: &S::Vector<8>) -> S::Vector<8> {
239        S::Vector::<8>::from_array([
240            m.elem(0),
241            -m.elem(1),
242            m.elem(2),
243            -m.elem(3),
244            m.elem(4),
245            -m.elem(5),
246            m.elem(6),
247            -m.elem(7),
248        ])
249    }
250
251    /// Transposes a matrix.
252    pub fn transpose(m: &S::Vector<8>) -> S::Vector<8> {
253        S::Vector::<8>::from_array([
254            m.elem(0),
255            m.elem(1),
256            m.elem(4),
257            m.elem(5),
258            m.elem(2),
259            m.elem(3),
260            m.elem(6),
261            m.elem(7),
262        ])
263    }
264
265    /// Conjugate transpose of a matrix.
266    pub fn conjugate_transpose(m: &S::Vector<8>) -> S::Vector<8> {
267        Self::conjugate(&Self::transpose(m))
268    }
269
270    fn complex_mult(a_re: S, a_im: S, b_re: S, b_im: S) -> (S, S) {
271        (a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re)
272    }
273
274    fn determinant(m: &S::Vector<8>) -> (S, S) {
275        let a_re = m.elem(0);
276        let a_im = m.elem(1);
277        let b_re = m.elem(2);
278        let b_im = m.elem(3);
279        let c_re = m.elem(4);
280        let c_im = m.elem(5);
281        let d_re = m.elem(6);
282        let d_im = m.elem(7);
283
284        let (ad_re, ad_im) = Self::complex_mult(a_re, a_im, d_re, d_im);
285        let (bc_re, bc_im) = Self::complex_mult(b_re, b_im, c_re, c_im);
286        (ad_re - bc_re, ad_im - bc_im)
287    }
288
289    /// Inverse of a matrix.
290    pub fn inverse(m: &S::Vector<8>) -> S::Vector<8> {
291        let (det_re, det_im) = Self::determinant(m);
292        let norm_sq = det_re * det_re + det_im * det_im;
293        let inv_det_re = det_re / norm_sq;
294        let inv_det_im = -det_im / norm_sq;
295
296        let a_re = m.elem(0);
297        let a_im = m.elem(1);
298        let b_re = m.elem(2);
299        let b_im = m.elem(3);
300        let c_re = m.elem(4);
301        let c_im = m.elem(5);
302        let d_re = m.elem(6);
303        let d_im = m.elem(7);
304
305        let (e0_re, e0_im) = Self::complex_mult(d_re, d_im, inv_det_re, inv_det_im);
306        let (e1_re, e1_im) = Self::complex_mult(-b_re, -b_im, inv_det_re, inv_det_im);
307        let (e2_re, e2_im) = Self::complex_mult(-c_re, -c_im, inv_det_re, inv_det_im);
308        let (e3_re, e3_im) = Self::complex_mult(a_re, a_im, inv_det_re, inv_det_im);
309
310        S::Vector::<8>::from_array([e0_re, e0_im, e1_re, e1_im, e2_re, e2_im, e3_re, e3_im])
311    }
312
313    /// Frobenius norm of the matrix.
314    pub fn norm(m: &S::Vector<8>) -> S {
315        m.norm()
316    }
317
318    /// Squared Frobenius norm of the matrix.
319    pub fn squared_norm(m: &S::Vector<8>) -> S {
320        m.squared_norm()
321    }
322}
323
#[cfg(test)]
mod tests {
    use crate::{
        ComplexF64,
        Sl2cImpl,
    };

    /// The private `complex_mult` helper must agree exactly with `ComplexF64` multiplication.
    #[test]
    fn test_complex_mult_helper() {
        let lhs = ComplexF64::from_real_imag(0.5, 0.7);
        let rhs = ComplexF64::from_real_imag(1.0, -0.2);

        let expected = lhs * rhs;
        let (helper_re, helper_im) = Sl2cImpl::<f64, 1, 0, 0>::complex_mult(
            lhs.real(),
            lhs.imag(),
            rhs.real(),
            rhs.imag(),
        );

        assert_eq!(expected.real(), helper_re);
        assert_eq!(expected.imag(), helper_im);
    }
}