1extern crate alloc;
2
3use alloc::vec::Vec;
4use core::marker::PhantomData;
5
6use sophus_autodiff::{
7 manifold::IsManifold,
8 params::{
9 HasParams,
10 IsParamsImpl,
11 },
12 points::example_points,
13};
14
15use crate::prelude::*;
16
17#[derive(Clone, Debug, Copy)]
19pub struct Complex<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
20{
21 params: S::Vector<2>,
22}
23
24impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
25 Complex<S, BATCH, DM, DN>
26{
27 #[inline]
29 #[must_use]
30 pub fn from_real_imag(real: S, imag: S) -> Self {
31 Self::from_params(S::Vector::<2>::from_array([real, imag]))
32 }
33
34 #[inline]
36 #[must_use]
37 pub fn from_params(params: S::Vector<2>) -> Self {
38 Self { params }
39 }
40
41 pub fn zero() -> Self {
43 Self::from_params(S::Vector::<2>::zeros())
44 }
45
46 pub fn one() -> Self {
48 Self::from_params(S::Vector::<2>::from_f64_array([1.0, 0.0]))
49 }
50
51 pub fn params(&self) -> &S::Vector<2> {
53 &self.params
54 }
55
56 pub fn params_mut(&mut self) -> &mut S::Vector<2> {
58 &mut self.params
59 }
60
61 pub fn real(&self) -> S {
63 self.params.elem(0)
64 }
65
66 pub fn imag(&self) -> S {
68 self.params.elem(1)
69 }
70
71 pub fn mult(&self, rhs: Self) -> Self {
73 Self::from_params(ComplexImpl::<S, BATCH, DM, DN>::mult(
74 &self.params,
75 rhs.params,
76 ))
77 }
78
79 pub fn add(&self, rhs: Self) -> Self {
81 Self::from_params(self.params + rhs.params)
82 }
83
84 pub fn conjugate(&self) -> Self {
86 Self::from_params(ComplexImpl::<S, BATCH, DM, DN>::conjugate(&self.params))
87 }
88
89 pub fn inverse(&self) -> Self {
91 Self::from_params(ComplexImpl::<S, BATCH, DM, DN>::inverse(&self.params))
92 }
93
94 pub fn norm(&self) -> S {
96 ComplexImpl::<S, BATCH, DM, DN>::norm(&self.params)
97 }
98
99 pub fn squared_norm(&self) -> S {
101 ComplexImpl::<S, BATCH, DM, DN>::squared_norm(&self.params)
102 }
103
104 pub fn scale(&self, s: S) -> Self {
106 Self::from_params(self.params.scaled(s))
107 }
108}
109
110impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
111 core::ops::Add for Complex<S, BATCH, DM, DN>
112{
113 type Output = Self;
114
115 fn add(self, rhs: Self) -> Self::Output {
116 Complex::add(&self, rhs)
117 }
118}
119
120impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
121 core::ops::Mul for Complex<S, BATCH, DM, DN>
122{
123 type Output = Self;
124
125 fn mul(self, rhs: Self) -> Self::Output {
126 self.mult(rhs)
127 }
128}
129
130impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
131 IsParamsImpl<S, 2, BATCH, DM, DN> for Complex<S, BATCH, DM, DN>
132{
133 fn are_params_valid(_params: S::Vector<2>) -> S::Mask {
134 S::Mask::all_true()
135 }
136
137 fn params_examples() -> Vec<S::Vector<2>> {
138 example_points::<S, 2, BATCH, DM, DN>()
139 }
140
141 fn invalid_params_examples() -> Vec<S::Vector<2>> {
142 Vec::new()
143 }
144}
145
146impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
147 HasParams<S, 2, BATCH, DM, DN> for Complex<S, BATCH, DM, DN>
148{
149 fn from_params(params: S::Vector<2>) -> Self {
150 Self::from_params(params)
151 }
152
153 fn set_params(&mut self, params: S::Vector<2>) {
154 self.params = params;
155 }
156
157 fn params(&self) -> &S::Vector<2> {
158 &self.params
159 }
160}
161
162impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
163 IsManifold<S, 2, 2, BATCH, DM, DN> for Complex<S, BATCH, DM, DN>
164{
165 fn oplus(&self, tangent: &S::Vector<2>) -> Self {
166 Self::from_params(*self.params() + *tangent)
167 }
168
169 fn ominus(&self, rhs: &Self) -> S::Vector<2> {
170 *self.params() - *rhs.params()
171 }
172}
173
174pub type ComplexF64 = Complex<f64, 1, 0, 0>;
176
177#[derive(Clone, Copy, Debug)]
179pub struct ComplexImpl<
180 S: IsScalar<BATCH, DM, DN>,
181 const BATCH: usize,
182 const DM: usize,
183 const DN: usize,
184> {
185 phantom: PhantomData<S>,
186}
187
188impl<S: IsScalar<BATCH, DM, DN>, const BATCH: usize, const DM: usize, const DN: usize>
189 ComplexImpl<S, BATCH, DM, DN>
190{
191 pub fn zero() -> S::Vector<2> {
193 S::Vector::<2>::zeros()
194 }
195
196 pub fn one() -> S::Vector<2> {
198 S::Vector::<2>::from_f64_array([1.0, 0.0])
199 }
200
201 pub fn mult(lhs: &S::Vector<2>, rhs: S::Vector<2>) -> S::Vector<2> {
203 let lhs_re = lhs.elem(0);
204 let rhs_re = rhs.elem(0);
205
206 let lhs_im = lhs.elem(1);
207 let rhs_im = rhs.elem(1);
208
209 let re = lhs_re * rhs_re - lhs_im * rhs_im;
210 let im = lhs_re * rhs_im + lhs_im * rhs_re;
211
212 S::Vector::<2>::from_array([re, im])
213 }
214
215 pub fn add(a: &S::Vector<2>, b: S::Vector<2>) -> S::Vector<2> {
217 *a + b
218 }
219
220 pub fn conjugate(a: &S::Vector<2>) -> S::Vector<2> {
222 S::Vector::from_array([a.elem(0), -a.elem(1)])
223 }
224
225 pub fn inverse(z: &S::Vector<2>) -> S::Vector<2> {
227 Self::conjugate(z).scaled(S::from_f64(1.0) / z.squared_norm())
228 }
229
230 pub fn norm(z: &S::Vector<2>) -> S {
232 z.norm()
233 }
234
235 pub fn squared_norm(z: &S::Vector<2>) -> S {
237 z.squared_norm()
238 }
239}