Skip to main content

scivex_core/
complex.rs

1//! Native complex number type for generic tensor operations.
2//!
3//! Provides [`Complex<T>`] which implements [`Scalar`] so it can be used
4//! directly as a tensor element type: `Tensor<Complex<f64>>`.
5
6use crate::dtype::{Float, Scalar};
7use core::fmt;
8use core::iter::Sum;
9use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
10
11// ---------------------------------------------------------------------------
12// Complex<T> struct
13// ---------------------------------------------------------------------------
14
/// A complex number `re + im·i` whose two components share the type `T`.
///
/// Both fields are public, so a value can be written as a struct literal or
/// built through [`Complex::new`].
///
/// # Examples
///
/// ```
/// # use scivex_core::complex::Complex;
/// let z = Complex::new(3.0_f64, 4.0);
/// assert_eq!(z.norm(), 5.0);
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Complex<T> {
    /// Real component.
    pub re: T,
    /// Imaginary component.
    pub im: T,
}
31
32// ---------------------------------------------------------------------------
33// Constructors and methods
34// ---------------------------------------------------------------------------
35
36impl<T: Float> Complex<T> {
37    /// Create a new complex number from real and imaginary parts.
38    ///
39    /// # Examples
40    ///
41    /// ```
42    /// # use scivex_core::complex::Complex;
43    /// let z = Complex::new(1.0_f64, 2.0);
44    /// assert_eq!(z.re, 1.0);
45    /// assert_eq!(z.im, 2.0);
46    /// ```
47    #[inline]
48    pub fn new(re: T, im: T) -> Self {
49        Self { re, im }
50    }
51
52    /// Create a complex number from a real value (imaginary part is zero).
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// # use scivex_core::complex::Complex;
58    /// let z = Complex::from_real(3.0_f64);
59    /// assert_eq!(z.im, 0.0);
60    /// ```
61    #[inline]
62    pub fn from_real(re: T) -> Self {
63        Self { re, im: T::zero() }
64    }
65
66    /// Create a complex number from polar form: `r * (cos(theta) + i*sin(theta))`.
67    ///
68    /// # Examples
69    ///
70    /// ```
71    /// # use scivex_core::complex::Complex;
72    /// let z = Complex::from_polar(1.0_f64, 0.0_f64);
73    /// assert!((z.re - 1.0).abs() < 1e-15);
74    /// assert!(z.im.abs() < 1e-15);
75    /// ```
76    #[inline]
77    pub fn from_polar(r: T, theta: T) -> Self {
78        Self {
79            re: r * theta.cos(),
80            im: r * theta.sin(),
81        }
82    }
83
84    /// Complex conjugate: flips the sign of the imaginary part.
85    ///
86    /// # Examples
87    ///
88    /// ```
89    /// # use scivex_core::complex::Complex;
90    /// let z = Complex::new(3.0_f64, 4.0);
91    /// let c = z.conj();
92    /// assert_eq!(c.re, 3.0);
93    /// assert_eq!(c.im, -4.0);
94    /// ```
95    #[inline]
96    pub fn conj(self) -> Self {
97        Self {
98            re: self.re,
99            im: -self.im,
100        }
101    }
102
103    /// Squared modulus: `re² + im²`. Avoids the `sqrt` in [`norm`](Self::norm).
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// # use scivex_core::complex::Complex;
109    /// let z = Complex::new(3.0_f64, 4.0);
110    /// assert_eq!(z.norm_sqr(), 25.0);
111    /// ```
112    #[inline]
113    pub fn norm_sqr(self) -> T {
114        self.re * self.re + self.im * self.im
115    }
116
117    /// Modulus (absolute value): `sqrt(re² + im²)`.
118    #[inline]
119    pub fn norm(self) -> T {
120        self.norm_sqr().sqrt()
121    }
122
123    /// Phase angle (argument): `atan2(im, re)`.
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// # use scivex_core::complex::Complex;
129    /// let z = Complex::new(0.0_f64, 1.0); // i
130    /// let angle = z.arg();
131    /// assert!((angle - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
132    /// ```
133    #[inline]
134    pub fn arg(self) -> T {
135        // atan2 is not on Float trait; compute via the identity:
136        // atan2(y, x) = 2 * atan(y / (sqrt(x²+y²) + x))  for x > -|z|
137        // Fallback: use to_f64 / from_f64 for atan2.
138        let angle = self.im.to_f64().atan2(self.re.to_f64());
139        T::from_f64(angle)
140    }
141
142    /// Complex exponential: `e^(a+bi) = e^a * (cos b + i sin b)`.
143    ///
144    /// # Examples
145    ///
146    /// ```
147    /// # use scivex_core::complex::Complex;
148    /// // e^0 = 1
149    /// let z = Complex::new(0.0_f64, 0.0);
150    /// let e = z.exp();
151    /// assert!((e.re - 1.0).abs() < 1e-15);
152    /// assert!(e.im.abs() < 1e-15);
153    /// ```
154    #[inline]
155    pub fn exp(self) -> Self {
156        let ea = self.re.exp();
157        Self {
158            re: ea * self.im.cos(),
159            im: ea * self.im.sin(),
160        }
161    }
162
163    /// Complex natural logarithm: `ln|z| + i*arg(z)`.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// # use scivex_core::complex::Complex;
169    /// // ln(1) = 0
170    /// let z = Complex::new(1.0_f64, 0.0);
171    /// let l = z.ln();
172    /// assert!(l.re.abs() < 1e-15);
173    /// assert!(l.im.abs() < 1e-15);
174    /// ```
175    #[inline]
176    pub fn ln(self) -> Self {
177        Self {
178            re: self.norm().ln(),
179            im: self.arg(),
180        }
181    }
182
183    /// Complex square root.
184    ///
185    /// Uses the principal branch: `sqrt(r) * (cos(theta/2) + i*sin(theta/2))`.
186    #[inline]
187    pub fn sqrt(self) -> Self {
188        let r = self.norm().sqrt();
189        let half_theta = self.arg() * T::from_f64(0.5);
190        Self {
191            re: r * half_theta.cos(),
192            im: r * half_theta.sin(),
193        }
194    }
195
196    /// Complex power: `z^n = e^(n * ln(z))`.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// # use scivex_core::complex::Complex;
202    /// // (1+0i)^5 = 1
203    /// let z = Complex::new(1.0_f64, 0.0);
204    /// let p = z.pow(5.0_f64);
205    /// assert!((p.re - 1.0).abs() < 1e-10);
206    /// assert!(p.im.abs() < 1e-10);
207    /// ```
208    #[inline]
209    pub fn pow(self, n: T) -> Self {
210        let ln_z = self.ln();
211        let scaled = Self {
212            re: ln_z.re * n,
213            im: ln_z.im * n,
214        };
215        scaled.exp()
216    }
217
218    /// Returns `true` if both real and imaginary parts are finite.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// # use scivex_core::complex::Complex;
224    /// assert!(Complex::new(1.0_f64, 2.0).is_finite());
225    /// assert!(!Complex::new(f64::INFINITY, 0.0).is_finite());
226    /// ```
227    #[inline]
228    pub fn is_finite(self) -> bool {
229        self.re.is_finite() && self.im.is_finite()
230    }
231
232    /// Returns `true` if either part is NaN.
233    ///
234    /// # Examples
235    ///
236    /// ```
237    /// # use scivex_core::complex::Complex;
238    /// assert!(!Complex::new(1.0_f64, 2.0).is_nan());
239    /// assert!(Complex::new(f64::NAN, 0.0).is_nan());
240    /// ```
241    #[inline]
242    pub fn is_nan(self) -> bool {
243        self.re.is_nan() || self.im.is_nan()
244    }
245}
246
247// ---------------------------------------------------------------------------
248// The imaginary unit
249// ---------------------------------------------------------------------------
250
251impl<T: Float> Complex<T> {
252    /// The imaginary unit `i`.
253    ///
254    /// # Examples
255    ///
256    /// ```
257    /// # use scivex_core::complex::Complex;
258    /// let i = Complex::<f64>::i();
259    /// assert_eq!(i.re, 0.0);
260    /// assert_eq!(i.im, 1.0);
261    /// ```
262    #[inline]
263    pub fn i() -> Self {
264        Self {
265            re: T::zero(),
266            im: T::one(),
267        }
268    }
269}
270
271// ---------------------------------------------------------------------------
272// Display
273// ---------------------------------------------------------------------------
274
275impl<T: Float + fmt::Display> fmt::Display for Complex<T> {
276    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277        // We need to determine the sign of the imaginary part.
278        // Compare im >= zero using to_f64.
279        let im_f64 = self.im.to_f64();
280        if im_f64 < 0.0 {
281            write!(f, "{}-{}i", self.re, Float::abs(self.im))
282        } else {
283            write!(f, "{}+{}i", self.re, self.im)
284        }
285    }
286}
287
288// ---------------------------------------------------------------------------
289// PartialOrd — compare by norm (non-standard but required by Scalar)
290// ---------------------------------------------------------------------------
291
292impl<T: Float> PartialOrd for Complex<T> {
293    #[inline]
294    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
295        self.norm_sqr().partial_cmp(&other.norm_sqr())
296    }
297}
298
299// ---------------------------------------------------------------------------
300// Arithmetic: Complex + Complex
301// ---------------------------------------------------------------------------
302
303impl<T: Float> Add for Complex<T> {
304    type Output = Self;
305    #[inline]
306    fn add(self, rhs: Self) -> Self {
307        Self {
308            re: self.re + rhs.re,
309            im: self.im + rhs.im,
310        }
311    }
312}
313
314impl<T: Float> Sub for Complex<T> {
315    type Output = Self;
316    #[inline]
317    fn sub(self, rhs: Self) -> Self {
318        Self {
319            re: self.re - rhs.re,
320            im: self.im - rhs.im,
321        }
322    }
323}
324
325impl<T: Float> Mul for Complex<T> {
326    type Output = Self;
327    #[inline]
328    fn mul(self, rhs: Self) -> Self {
329        // (a+bi)(c+di) = (ac-bd) + (ad+bc)i
330        Self {
331            re: self.re * rhs.re - self.im * rhs.im,
332            im: self.re * rhs.im + self.im * rhs.re,
333        }
334    }
335}
336
337impl<T: Float> Div for Complex<T> {
338    type Output = Self;
339    #[inline]
340    fn div(self, rhs: Self) -> Self {
341        // (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²)
342        let denom = rhs.norm_sqr();
343        Self {
344            re: (self.re * rhs.re + self.im * rhs.im) / denom,
345            im: (self.im * rhs.re - self.re * rhs.im) / denom,
346        }
347    }
348}
349
350// ---------------------------------------------------------------------------
351// Arithmetic: Complex + T scalar
352// ---------------------------------------------------------------------------
353
354impl<T: Float> Add<T> for Complex<T> {
355    type Output = Self;
356    #[inline]
357    fn add(self, rhs: T) -> Self {
358        Self {
359            re: self.re + rhs,
360            im: self.im,
361        }
362    }
363}
364
365impl<T: Float> Sub<T> for Complex<T> {
366    type Output = Self;
367    #[inline]
368    fn sub(self, rhs: T) -> Self {
369        Self {
370            re: self.re - rhs,
371            im: self.im,
372        }
373    }
374}
375
376impl<T: Float> Mul<T> for Complex<T> {
377    type Output = Self;
378    #[inline]
379    fn mul(self, rhs: T) -> Self {
380        Self {
381            re: self.re * rhs,
382            im: self.im * rhs,
383        }
384    }
385}
386
387impl<T: Float> Div<T> for Complex<T> {
388    type Output = Self;
389    #[inline]
390    fn div(self, rhs: T) -> Self {
391        Self {
392            re: self.re / rhs,
393            im: self.im / rhs,
394        }
395    }
396}
397
398// ---------------------------------------------------------------------------
399// Assign ops
400// ---------------------------------------------------------------------------
401
402impl<T: Float> AddAssign for Complex<T> {
403    #[inline]
404    fn add_assign(&mut self, rhs: Self) {
405        self.re += rhs.re;
406        self.im += rhs.im;
407    }
408}
409
410impl<T: Float> SubAssign for Complex<T> {
411    #[inline]
412    fn sub_assign(&mut self, rhs: Self) {
413        self.re -= rhs.re;
414        self.im -= rhs.im;
415    }
416}
417
418impl<T: Float> MulAssign for Complex<T> {
419    #[inline]
420    fn mul_assign(&mut self, rhs: Self) {
421        *self = *self * rhs;
422    }
423}
424
425impl<T: Float> DivAssign for Complex<T> {
426    #[inline]
427    fn div_assign(&mut self, rhs: Self) {
428        *self = *self / rhs;
429    }
430}
431
432// ---------------------------------------------------------------------------
433// Neg
434// ---------------------------------------------------------------------------
435
436impl<T: Float> Neg for Complex<T> {
437    type Output = Self;
438    #[inline]
439    fn neg(self) -> Self {
440        Self {
441            re: -self.re,
442            im: -self.im,
443        }
444    }
445}
446
447// ---------------------------------------------------------------------------
448// Sum
449// ---------------------------------------------------------------------------
450
451impl<T: Float> Sum for Complex<T> {
452    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
453        iter.fold(Complex::new(T::zero(), T::zero()), |acc, x| acc + x)
454    }
455}
456
457// ---------------------------------------------------------------------------
458// Scalar impl
459// ---------------------------------------------------------------------------
460
461impl<T: Float> Scalar for Complex<T> {
462    #[inline]
463    fn zero() -> Self {
464        Complex::new(T::zero(), T::zero())
465    }
466
467    #[inline]
468    fn one() -> Self {
469        Complex::new(T::one(), T::zero())
470    }
471
472    #[inline]
473    fn from_usize(v: usize) -> Self {
474        Complex::new(T::from_usize(v), T::zero())
475    }
476}
477
478// ---------------------------------------------------------------------------
479// Interleaved conversion utilities
480// ---------------------------------------------------------------------------
481
482/// Convert interleaved real data `[re0, im0, re1, im1, ...]` to a `Vec<Complex<T>>`.
483///
484/// The input slice length must be even. If it is odd the last element is ignored.
485///
486/// # Examples
487///
488/// ```
489/// # use scivex_core::complex::{Complex, from_interleaved};
490/// let data = [1.0_f64, 2.0, 3.0, 4.0];
491/// let v = from_interleaved(&data);
492/// assert_eq!(v.len(), 2);
493/// assert_eq!(v[0], Complex::new(1.0, 2.0));
494/// assert_eq!(v[1], Complex::new(3.0, 4.0));
495/// ```
496pub fn from_interleaved<T: Float>(data: &[T]) -> Vec<Complex<T>> {
497    let n = data.len() / 2;
498    let mut out = Vec::with_capacity(n);
499    for i in 0..n {
500        out.push(Complex::new(data[i * 2], data[i * 2 + 1]));
501    }
502    out
503}
504
505/// Convert a slice of `Complex<T>` to interleaved real data `[re0, im0, re1, im1, ...]`.
506///
507/// # Examples
508///
509/// ```
510/// # use scivex_core::complex::{Complex, to_interleaved};
511/// let v = vec![Complex::new(1.0_f64, 2.0), Complex::new(3.0, 4.0)];
512/// let flat = to_interleaved(&v);
513/// assert_eq!(flat, vec![1.0_f64, 2.0, 3.0, 4.0]);
514/// ```
515pub fn to_interleaved<T: Float>(data: &[Complex<T>]) -> Vec<T> {
516    let mut out = Vec::with_capacity(data.len() * 2);
517    for c in data {
518        out.push(c.re);
519        out.push(c.im);
520    }
521    out
522}
523
524// ===========================================================================
525// Tests
526// ===========================================================================
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in this module.
    const EPS: f64 = 1e-10;

    /// `true` when `a` and `b` differ by less than [`EPS`].
    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < EPS
    }

    /// Component-wise [`approx`] for complex values.
    fn approx_c(a: Complex<f64>, b: Complex<f64>) -> bool {
        approx(a.re, b.re) && approx(a.im, b.im)
    }

    #[test]
    fn test_complex_arithmetic() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);

        // add
        let s = a + b;
        assert!(approx(s.re, 4.0));
        assert!(approx(s.im, 6.0));

        // sub
        let d = a - b;
        assert!(approx(d.re, -2.0));
        assert!(approx(d.im, -2.0));

        // mul: (1+2i)(3+4i) = 3+4i+6i+8i² = 3+10i-8 = -5+10i
        let m = a * b;
        assert!(approx(m.re, -5.0));
        assert!(approx(m.im, 10.0));

        // div: (1+2i)/(3+4i) = (1+2i)(3-4i)/25 = (3-4i+6i-8i²)/25 = (11+2i)/25
        let q = a / b;
        assert!(approx(q.re, 11.0 / 25.0));
        assert!(approx(q.im, 2.0 / 25.0));
    }

    #[test]
    fn test_complex_conjugate() {
        let z = Complex::new(3.0, -7.0);
        let c = z.conj();
        assert!(approx(c.re, 3.0));
        assert!(approx(c.im, 7.0));
    }

    #[test]
    fn test_complex_norm() {
        let z = Complex::new(3.0, 4.0);
        assert!(approx(z.norm(), 5.0));
        assert!(approx(z.norm_sqr(), 25.0));
    }

    #[test]
    fn test_complex_arg() {
        let z = Complex::new(1.0, 1.0);
        let expected = std::f64::consts::FRAC_PI_4;
        assert!(approx(z.arg(), expected));
    }

    #[test]
    fn test_complex_exp() {
        // e^(i*pi) ≈ -1 + 0i
        let z = Complex::new(0.0, std::f64::consts::PI);
        let r = z.exp();
        assert!(approx(r.re, -1.0));
        assert!(approx(r.im, 0.0));
    }

    #[test]
    fn test_complex_ln() {
        // ln(e) = 1
        let l = Complex::new(std::f64::consts::E, 0.0).ln();
        assert!(approx_c(l, Complex::new(1.0, 0.0)));

        // ln(-1) = i*pi on the principal branch
        let l2 = Complex::new(-1.0, 0.0).ln();
        assert!(approx_c(l2, Complex::new(0.0, std::f64::consts::PI)));
    }

    #[test]
    fn test_complex_pow() {
        // i² = -1
        let p = Complex::new(0.0, 1.0).pow(2.0);
        assert!(approx_c(p, Complex::new(-1.0, 0.0)));

        // 2³ = 8
        let p2 = Complex::new(2.0, 0.0).pow(3.0);
        assert!(approx_c(p2, Complex::new(8.0, 0.0)));
    }

    #[test]
    fn test_complex_from_polar() {
        let r = 5.0;
        let theta = std::f64::consts::FRAC_PI_4;
        let z = Complex::from_polar(r, theta);

        // Roundtrip: norm and arg should recover r and theta
        assert!(approx(z.norm(), r));
        assert!(approx(z.arg(), theta));
    }

    #[test]
    fn test_complex_imaginary_unit() {
        let i = Complex::<f64>::i();
        assert!(approx_c(i, Complex::new(0.0, 1.0)));
        // i * i = -1
        assert!(approx_c(i * i, Complex::new(-1.0, 0.0)));
    }

    #[test]
    fn test_complex_scalar_mul() {
        let z = Complex::new(2.0, 3.0);
        let scaled = z * 4.0;
        assert!(approx(scaled.re, 8.0));
        assert!(approx(scaled.im, 12.0));
    }

    #[test]
    fn test_complex_scalar_add_sub_div() {
        let z = Complex::new(2.0, 3.0);
        assert!(approx_c(z + 1.0, Complex::new(3.0, 3.0)));
        assert!(approx_c(z - 1.0, Complex::new(1.0, 3.0)));
        assert!(approx_c(z / 2.0, Complex::new(1.0, 1.5)));
    }

    #[test]
    fn test_complex_assign_ops() {
        let a = Complex::new(1.0, 2.0);
        let b = Complex::new(3.0, 4.0);

        let mut acc = a;
        acc += b;
        assert!(approx_c(acc, a + b));

        let mut acc = a;
        acc -= b;
        assert!(approx_c(acc, a - b));

        let mut acc = a;
        acc *= b;
        assert!(approx_c(acc, a * b));

        let mut acc = a;
        acc /= b;
        assert!(approx_c(acc, a / b));
    }

    #[test]
    fn test_complex_neg() {
        let z = Complex::new(1.0, -2.0);
        assert!(approx_c(-z, Complex::new(-1.0, 2.0)));
    }

    #[test]
    fn test_complex_ordering_by_norm() {
        let small = Complex::new(1.0, 0.0);
        let large = Complex::new(0.0, 2.0);
        assert!(small < large);
        assert!(large > small);
        // NaN components make the comparison undefined.
        assert!(Complex::new(f64::NAN, 0.0).partial_cmp(&small).is_none());
    }

    #[test]
    fn test_complex_finite_and_nan() {
        assert!(Complex::new(1.0, 2.0).is_finite());
        assert!(!Complex::new(1.0, f64::INFINITY).is_finite());
        assert!(Complex::new(0.0, f64::NAN).is_nan());
        assert!(!Complex::new(1.0, 2.0).is_nan());
    }

    #[test]
    fn test_complex_sqrt() {
        // sqrt(-1) = i
        let z = Complex::new(-1.0, 0.0);
        let s = z.sqrt();
        assert!(approx_c(s, Complex::new(0.0, 1.0)));

        // sqrt(4) = 2
        let z2 = Complex::new(4.0, 0.0);
        let s2 = z2.sqrt();
        assert!(approx_c(s2, Complex::new(2.0, 0.0)));
    }

    #[test]
    fn test_complex_display() {
        // Pin the exact rendering rather than just the presence of a sign.
        let a = Complex::new(3.0_f64, 4.0_f64);
        assert_eq!(format!("{a}"), "3+4i");

        let b = Complex::new(1.0_f64, -2.0_f64);
        assert_eq!(format!("{b}"), "1-2i");

        let c = Complex::new(-1.5_f64, 0.25_f64);
        assert_eq!(format!("{c}"), "-1.5+0.25i");
    }

    #[test]
    fn test_complex_tensor() {
        use crate::tensor::Tensor;

        let data = vec![
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(1.0, 1.0),
            Complex::new(2.0, -1.0),
        ];
        let t = Tensor::from_vec(data, vec![2, 2]).unwrap();
        assert_eq!(t.shape(), &[2, 2]);

        // Element-wise add with itself
        let t2 = &t + &t;
        let elem = t2.get(&[0, 1]).unwrap();
        assert!(approx_c(*elem, Complex::new(0.0, 2.0)));
    }

    #[test]
    fn test_complex_interleaved_roundtrip() {
        let original = vec![
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ];
        let interleaved = to_interleaved(&original);
        assert_eq!(interleaved, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

        let recovered = from_interleaved(&interleaved);
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_complex_interleaved_edge_cases() {
        // Odd length: the dangling real part is dropped.
        let odd = from_interleaved(&[1.0_f64, 2.0, 3.0]);
        assert_eq!(odd, vec![Complex::new(1.0, 2.0)]);

        // Empty in, empty out, in both directions.
        let empty: Vec<Complex<f64>> = from_interleaved(&[]);
        assert!(empty.is_empty());
        assert!(to_interleaved(&empty).is_empty());
    }

    #[test]
    fn test_complex_sum() {
        let vals = vec![
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ];
        let total: Complex<f64> = vals.into_iter().sum();
        assert!(approx(total.re, 9.0));
        assert!(approx(total.im, 12.0));
    }

    #[test]
    fn test_complex_scalar_trait() {
        // Verify Scalar trait methods work
        let z: Complex<f64> = Scalar::zero();
        assert!(approx(z.re, 0.0));
        assert!(approx(z.im, 0.0));

        let o: Complex<f64> = Scalar::one();
        assert!(approx(o.re, 1.0));
        assert!(approx(o.im, 0.0));

        let from_5: Complex<f64> = Scalar::from_usize(5);
        assert!(approx(from_5.re, 5.0));
        assert!(approx(from_5.im, 0.0));
    }
}