// sophus_lie/complex.rs

1extern crate alloc;
2
3use alloc::vec::Vec;
4use core::marker::PhantomData;
5
6use sophus_autodiff::{
7    manifold::IsManifold,
8    params::{
9        HasParams,
10        IsParamsImpl,
11    },
12    points::example_points,
13};
14
15use crate::prelude::*;
16
17/// Complex number represented as `(re, im)`.
18#[derive(Clone, Debug, Copy)]
19pub struct Complex<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
20{
21    params: S::Vector<2>,
22}
23
24impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
25    Complex<S, BATCH, DM, DN>
26{
27    /// Creates a complex number from real and imaginary scalar.
28    #[inline]
29    #[must_use]
30    pub fn from_real_imag(real: S, imag: S) -> Self {
31        Self::from_params(S::Vector::<2>::from_array([real, imag]))
32    }
33
34    /// Creates a complex number from its parameter vector `(re, im)`.
35    #[inline]
36    #[must_use]
37    pub fn from_params(params: S::Vector<2>) -> Self {
38        Self { params }
39    }
40
41    /// Returns zero `(0,0)`.
42    pub fn zero() -> Self {
43        Self::from_params(S::Vector::<2>::zeros())
44    }
45
46    /// Returns one `(1,0)`.
47    pub fn one() -> Self {
48        Self::from_params(S::Vector::<2>::from_f64_array([1.0, 0.0]))
49    }
50
51    /// Access the underlying parameter vector.
52    pub fn params(&self) -> &S::Vector<2> {
53        &self.params
54    }
55
56    /// Mutable access to the parameter vector.
57    pub fn params_mut(&mut self) -> &mut S::Vector<2> {
58        &mut self.params
59    }
60
61    /// Returns the real component.
62    pub fn real(&self) -> S {
63        self.params.elem(0)
64    }
65
66    /// Returns the imaginary component.
67    pub fn imag(&self) -> S {
68        self.params.elem(1)
69    }
70
71    /// Complex multiplication.
72    pub fn mult(&self, rhs: Self) -> Self {
73        Self::from_params(ComplexImpl::<S, BATCH, DM, DN>::mult(
74            &self.params,
75            rhs.params,
76        ))
77    }
78
79    /// Complex addition.
80    pub fn add(&self, rhs: Self) -> Self {
81        Self::from_params(self.params + rhs.params)
82    }
83
84    /// Conjugated complex number.
85    pub fn conjugate(&self) -> Self {
86        Self::from_params(ComplexImpl::<S, BATCH, DM, DN>::conjugate(&self.params))
87    }
88
89    /// Inverse complex number.
90    pub fn inverse(&self) -> Self {
91        Self::from_params(ComplexImpl::<S, BATCH, DM, DN>::inverse(&self.params))
92    }
93
94    /// Complex norm.
95    pub fn norm(&self) -> S {
96        ComplexImpl::<S, BATCH, DM, DN>::norm(&self.params)
97    }
98
99    /// Complex squared norm.
100    pub fn squared_norm(&self) -> S {
101        ComplexImpl::<S, BATCH, DM, DN>::squared_norm(&self.params)
102    }
103
104    /// Scale complex number by scalar.
105    pub fn scale(&self, s: S) -> Self {
106        Self::from_params(self.params.scaled(s))
107    }
108}
109
110impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
111    core::ops::Add for Complex<S, BATCH, DM, DN>
112{
113    type Output = Self;
114
115    fn add(self, rhs: Self) -> Self::Output {
116        Complex::add(&self, rhs)
117    }
118}
119
120impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
121    core::ops::Mul for Complex<S, BATCH, DM, DN>
122{
123    type Output = Self;
124
125    fn mul(self, rhs: Self) -> Self::Output {
126        self.mult(rhs)
127    }
128}
129
130impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
131    IsParamsImpl<S, 2, BATCH, DM, DN> for Complex<S, BATCH, DM, DN>
132{
133    fn are_params_valid(_params: S::Vector<2>) -> S::Mask {
134        S::Mask::all_true()
135    }
136
137    fn params_examples() -> Vec<S::Vector<2>> {
138        example_points::<S, 2, BATCH, DM, DN>()
139    }
140
141    fn invalid_params_examples() -> Vec<S::Vector<2>> {
142        Vec::new()
143    }
144}
145
146impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
147    HasParams<S, 2, BATCH, DM, DN> for Complex<S, BATCH, DM, DN>
148{
149    fn from_params(params: S::Vector<2>) -> Self {
150        Self::from_params(params)
151    }
152
153    fn set_params(&mut self, params: S::Vector<2>) {
154        self.params = params;
155    }
156
157    fn params(&self) -> &S::Vector<2> {
158        &self.params
159    }
160}
161
162impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
163    IsManifold<S, 2, 2, BATCH, DM, DN> for Complex<S, BATCH, DM, DN>
164{
165    fn oplus(&self, tangent: &S::Vector<2>) -> Self {
166        Self::from_params(*self.params() + *tangent)
167    }
168
169    fn ominus(&self, rhs: &Self) -> S::Vector<2> {
170        *self.params() - *rhs.params()
171    }
172}
173
174/// Complex number with `f64` scalar type.
175pub type ComplexF64 = Complex<f64, 1, 0, 0>;
176
177/// Implementation utilities for [`Complex`].
178#[derive(Clone, Copy, Debug)]
179pub struct ComplexImpl<
180    S: IsScalar<BATCH, DM, DN>,
181    const BATCH: usize,
182    const DM: usize,
183    const DN: usize,
184> {
185    phantom: PhantomData<S>,
186}
187
188impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
189    ComplexImpl<S, BATCH, DM, DN>
190{
191    /// Returns the zero complex number.
192    pub fn zero() -> S::Vector<2> {
193        S::Vector::<2>::zeros()
194    }
195
196    /// Returns the identity complex number.
197    pub fn one() -> S::Vector<2> {
198        S::Vector::<2>::from_f64_array([1.0, 0.0])
199    }
200
201    /// Multiplies two complex numbers.
202    pub fn mult(lhs: &S::Vector<2>, rhs: S::Vector<2>) -> S::Vector<2> {
203        let lhs_re = lhs.elem(0);
204        let rhs_re = rhs.elem(0);
205
206        let lhs_im = lhs.elem(1);
207        let rhs_im = rhs.elem(1);
208
209        let re = lhs_re * rhs_re - lhs_im * rhs_im;
210        let im = lhs_re * rhs_im + lhs_im * rhs_re;
211
212        S::Vector::<2>::from_array([re, im])
213    }
214
215    /// Adds two complex numbers component-wise.
216    pub fn add(a: &S::Vector<2>, b: S::Vector<2>) -> S::Vector<2> {
217        *a + b
218    }
219
220    /// Conjugates a complex number.
221    pub fn conjugate(a: &S::Vector<2>) -> S::Vector<2> {
222        S::Vector::from_array([a.elem(0), -a.elem(1)])
223    }
224
225    /// Computes the inverse complex number.
226    pub fn inverse(z: &S::Vector<2>) -> S::Vector<2> {
227        Self::conjugate(z).scaled(S::from_f64(1.0) / z.squared_norm())
228    }
229
230    /// Returns the complex norm.
231    pub fn norm(z: &S::Vector<2>) -> S {
232        z.norm()
233    }
234
235    /// Returns the squared complex norm.
236    pub fn squared_norm(z: &S::Vector<2>) -> S {
237        z.squared_norm()
238    }
239}