scivex_core/complex.rs
1//! Native complex number type for generic tensor operations.
2//!
3//! Provides [`Complex<T>`] which implements [`Scalar`] so it can be used
4//! directly as a tensor element type: `Tensor<Complex<f64>>`.
5
6use crate::dtype::{Float, Scalar};
7use core::fmt;
8use core::iter::Sum;
9use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
10
11// ---------------------------------------------------------------------------
12// Complex<T> struct
13// ---------------------------------------------------------------------------
14
/// Complex number `re + im·i` whose two components share the element type `T`.
///
/// Both components are public fields, so a value can be written as a struct
/// literal as well as through constructors such as `Complex::new`.
///
/// # Examples
///
/// ```
/// # use scivex_core::complex::Complex;
/// let z = Complex::new(3.0_f64, 4.0);
/// assert!((z.norm() - 5.0).abs() < 1e-12);
/// ```
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Complex<T> {
    /// Real component.
    pub re: T,
    /// Imaginary component (the coefficient of `i`).
    pub im: T,
}
31
32// ---------------------------------------------------------------------------
33// Constructors and methods
34// ---------------------------------------------------------------------------
35
36impl<T: Float> Complex<T> {
37 /// Create a new complex number from real and imaginary parts.
38 ///
39 /// # Examples
40 ///
41 /// ```
42 /// # use scivex_core::complex::Complex;
43 /// let z = Complex::new(1.0_f64, 2.0);
44 /// assert_eq!(z.re, 1.0);
45 /// assert_eq!(z.im, 2.0);
46 /// ```
47 #[inline]
48 pub fn new(re: T, im: T) -> Self {
49 Self { re, im }
50 }
51
52 /// Create a complex number from a real value (imaginary part is zero).
53 ///
54 /// # Examples
55 ///
56 /// ```
57 /// # use scivex_core::complex::Complex;
58 /// let z = Complex::from_real(3.0_f64);
59 /// assert_eq!(z.im, 0.0);
60 /// ```
61 #[inline]
62 pub fn from_real(re: T) -> Self {
63 Self { re, im: T::zero() }
64 }
65
66 /// Create a complex number from polar form: `r * (cos(theta) + i*sin(theta))`.
67 ///
68 /// # Examples
69 ///
70 /// ```
71 /// # use scivex_core::complex::Complex;
72 /// let z = Complex::from_polar(1.0_f64, 0.0_f64);
73 /// assert!((z.re - 1.0).abs() < 1e-15);
74 /// assert!(z.im.abs() < 1e-15);
75 /// ```
76 #[inline]
77 pub fn from_polar(r: T, theta: T) -> Self {
78 Self {
79 re: r * theta.cos(),
80 im: r * theta.sin(),
81 }
82 }
83
84 /// Complex conjugate: flips the sign of the imaginary part.
85 ///
86 /// # Examples
87 ///
88 /// ```
89 /// # use scivex_core::complex::Complex;
90 /// let z = Complex::new(3.0_f64, 4.0);
91 /// let c = z.conj();
92 /// assert_eq!(c.re, 3.0);
93 /// assert_eq!(c.im, -4.0);
94 /// ```
95 #[inline]
96 pub fn conj(self) -> Self {
97 Self {
98 re: self.re,
99 im: -self.im,
100 }
101 }
102
103 /// Squared modulus: `re² + im²`. Avoids the `sqrt` in [`norm`](Self::norm).
104 ///
105 /// # Examples
106 ///
107 /// ```
108 /// # use scivex_core::complex::Complex;
109 /// let z = Complex::new(3.0_f64, 4.0);
110 /// assert_eq!(z.norm_sqr(), 25.0);
111 /// ```
112 #[inline]
113 pub fn norm_sqr(self) -> T {
114 self.re * self.re + self.im * self.im
115 }
116
117 /// Modulus (absolute value): `sqrt(re² + im²)`.
118 #[inline]
119 pub fn norm(self) -> T {
120 self.norm_sqr().sqrt()
121 }
122
123 /// Phase angle (argument): `atan2(im, re)`.
124 ///
125 /// # Examples
126 ///
127 /// ```
128 /// # use scivex_core::complex::Complex;
129 /// let z = Complex::new(0.0_f64, 1.0); // i
130 /// let angle = z.arg();
131 /// assert!((angle - std::f64::consts::FRAC_PI_2).abs() < 1e-10);
132 /// ```
133 #[inline]
134 pub fn arg(self) -> T {
135 // atan2 is not on Float trait; compute via the identity:
136 // atan2(y, x) = 2 * atan(y / (sqrt(x²+y²) + x)) for x > -|z|
137 // Fallback: use to_f64 / from_f64 for atan2.
138 let angle = self.im.to_f64().atan2(self.re.to_f64());
139 T::from_f64(angle)
140 }
141
142 /// Complex exponential: `e^(a+bi) = e^a * (cos b + i sin b)`.
143 ///
144 /// # Examples
145 ///
146 /// ```
147 /// # use scivex_core::complex::Complex;
148 /// // e^0 = 1
149 /// let z = Complex::new(0.0_f64, 0.0);
150 /// let e = z.exp();
151 /// assert!((e.re - 1.0).abs() < 1e-15);
152 /// assert!(e.im.abs() < 1e-15);
153 /// ```
154 #[inline]
155 pub fn exp(self) -> Self {
156 let ea = self.re.exp();
157 Self {
158 re: ea * self.im.cos(),
159 im: ea * self.im.sin(),
160 }
161 }
162
163 /// Complex natural logarithm: `ln|z| + i*arg(z)`.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// # use scivex_core::complex::Complex;
169 /// // ln(1) = 0
170 /// let z = Complex::new(1.0_f64, 0.0);
171 /// let l = z.ln();
172 /// assert!(l.re.abs() < 1e-15);
173 /// assert!(l.im.abs() < 1e-15);
174 /// ```
175 #[inline]
176 pub fn ln(self) -> Self {
177 Self {
178 re: self.norm().ln(),
179 im: self.arg(),
180 }
181 }
182
183 /// Complex square root.
184 ///
185 /// Uses the principal branch: `sqrt(r) * (cos(theta/2) + i*sin(theta/2))`.
186 #[inline]
187 pub fn sqrt(self) -> Self {
188 let r = self.norm().sqrt();
189 let half_theta = self.arg() * T::from_f64(0.5);
190 Self {
191 re: r * half_theta.cos(),
192 im: r * half_theta.sin(),
193 }
194 }
195
196 /// Complex power: `z^n = e^(n * ln(z))`.
197 ///
198 /// # Examples
199 ///
200 /// ```
201 /// # use scivex_core::complex::Complex;
202 /// // (1+0i)^5 = 1
203 /// let z = Complex::new(1.0_f64, 0.0);
204 /// let p = z.pow(5.0_f64);
205 /// assert!((p.re - 1.0).abs() < 1e-10);
206 /// assert!(p.im.abs() < 1e-10);
207 /// ```
208 #[inline]
209 pub fn pow(self, n: T) -> Self {
210 let ln_z = self.ln();
211 let scaled = Self {
212 re: ln_z.re * n,
213 im: ln_z.im * n,
214 };
215 scaled.exp()
216 }
217
218 /// Returns `true` if both real and imaginary parts are finite.
219 ///
220 /// # Examples
221 ///
222 /// ```
223 /// # use scivex_core::complex::Complex;
224 /// assert!(Complex::new(1.0_f64, 2.0).is_finite());
225 /// assert!(!Complex::new(f64::INFINITY, 0.0).is_finite());
226 /// ```
227 #[inline]
228 pub fn is_finite(self) -> bool {
229 self.re.is_finite() && self.im.is_finite()
230 }
231
232 /// Returns `true` if either part is NaN.
233 ///
234 /// # Examples
235 ///
236 /// ```
237 /// # use scivex_core::complex::Complex;
238 /// assert!(!Complex::new(1.0_f64, 2.0).is_nan());
239 /// assert!(Complex::new(f64::NAN, 0.0).is_nan());
240 /// ```
241 #[inline]
242 pub fn is_nan(self) -> bool {
243 self.re.is_nan() || self.im.is_nan()
244 }
245}
246
247// ---------------------------------------------------------------------------
248// The imaginary unit
249// ---------------------------------------------------------------------------
250
251impl<T: Float> Complex<T> {
252 /// The imaginary unit `i`.
253 ///
254 /// # Examples
255 ///
256 /// ```
257 /// # use scivex_core::complex::Complex;
258 /// let i = Complex::<f64>::i();
259 /// assert_eq!(i.re, 0.0);
260 /// assert_eq!(i.im, 1.0);
261 /// ```
262 #[inline]
263 pub fn i() -> Self {
264 Self {
265 re: T::zero(),
266 im: T::one(),
267 }
268 }
269}
270
271// ---------------------------------------------------------------------------
272// Display
273// ---------------------------------------------------------------------------
274
275impl<T: Float + fmt::Display> fmt::Display for Complex<T> {
276 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
277 // We need to determine the sign of the imaginary part.
278 // Compare im >= zero using to_f64.
279 let im_f64 = self.im.to_f64();
280 if im_f64 < 0.0 {
281 write!(f, "{}-{}i", self.re, Float::abs(self.im))
282 } else {
283 write!(f, "{}+{}i", self.re, self.im)
284 }
285 }
286}
287
288// ---------------------------------------------------------------------------
289// PartialOrd — compare by norm (non-standard but required by Scalar)
290// ---------------------------------------------------------------------------
291
292impl<T: Float> PartialOrd for Complex<T> {
293 #[inline]
294 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
295 self.norm_sqr().partial_cmp(&other.norm_sqr())
296 }
297}
298
299// ---------------------------------------------------------------------------
300// Arithmetic: Complex + Complex
301// ---------------------------------------------------------------------------
302
303impl<T: Float> Add for Complex<T> {
304 type Output = Self;
305 #[inline]
306 fn add(self, rhs: Self) -> Self {
307 Self {
308 re: self.re + rhs.re,
309 im: self.im + rhs.im,
310 }
311 }
312}
313
314impl<T: Float> Sub for Complex<T> {
315 type Output = Self;
316 #[inline]
317 fn sub(self, rhs: Self) -> Self {
318 Self {
319 re: self.re - rhs.re,
320 im: self.im - rhs.im,
321 }
322 }
323}
324
325impl<T: Float> Mul for Complex<T> {
326 type Output = Self;
327 #[inline]
328 fn mul(self, rhs: Self) -> Self {
329 // (a+bi)(c+di) = (ac-bd) + (ad+bc)i
330 Self {
331 re: self.re * rhs.re - self.im * rhs.im,
332 im: self.re * rhs.im + self.im * rhs.re,
333 }
334 }
335}
336
337impl<T: Float> Div for Complex<T> {
338 type Output = Self;
339 #[inline]
340 fn div(self, rhs: Self) -> Self {
341 // (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c²+d²)
342 let denom = rhs.norm_sqr();
343 Self {
344 re: (self.re * rhs.re + self.im * rhs.im) / denom,
345 im: (self.im * rhs.re - self.re * rhs.im) / denom,
346 }
347 }
348}
349
350// ---------------------------------------------------------------------------
351// Arithmetic: Complex + T scalar
352// ---------------------------------------------------------------------------
353
354impl<T: Float> Add<T> for Complex<T> {
355 type Output = Self;
356 #[inline]
357 fn add(self, rhs: T) -> Self {
358 Self {
359 re: self.re + rhs,
360 im: self.im,
361 }
362 }
363}
364
365impl<T: Float> Sub<T> for Complex<T> {
366 type Output = Self;
367 #[inline]
368 fn sub(self, rhs: T) -> Self {
369 Self {
370 re: self.re - rhs,
371 im: self.im,
372 }
373 }
374}
375
376impl<T: Float> Mul<T> for Complex<T> {
377 type Output = Self;
378 #[inline]
379 fn mul(self, rhs: T) -> Self {
380 Self {
381 re: self.re * rhs,
382 im: self.im * rhs,
383 }
384 }
385}
386
387impl<T: Float> Div<T> for Complex<T> {
388 type Output = Self;
389 #[inline]
390 fn div(self, rhs: T) -> Self {
391 Self {
392 re: self.re / rhs,
393 im: self.im / rhs,
394 }
395 }
396}
397
398// ---------------------------------------------------------------------------
399// Assign ops
400// ---------------------------------------------------------------------------
401
402impl<T: Float> AddAssign for Complex<T> {
403 #[inline]
404 fn add_assign(&mut self, rhs: Self) {
405 self.re += rhs.re;
406 self.im += rhs.im;
407 }
408}
409
410impl<T: Float> SubAssign for Complex<T> {
411 #[inline]
412 fn sub_assign(&mut self, rhs: Self) {
413 self.re -= rhs.re;
414 self.im -= rhs.im;
415 }
416}
417
418impl<T: Float> MulAssign for Complex<T> {
419 #[inline]
420 fn mul_assign(&mut self, rhs: Self) {
421 *self = *self * rhs;
422 }
423}
424
425impl<T: Float> DivAssign for Complex<T> {
426 #[inline]
427 fn div_assign(&mut self, rhs: Self) {
428 *self = *self / rhs;
429 }
430}
431
432// ---------------------------------------------------------------------------
433// Neg
434// ---------------------------------------------------------------------------
435
436impl<T: Float> Neg for Complex<T> {
437 type Output = Self;
438 #[inline]
439 fn neg(self) -> Self {
440 Self {
441 re: -self.re,
442 im: -self.im,
443 }
444 }
445}
446
447// ---------------------------------------------------------------------------
448// Sum
449// ---------------------------------------------------------------------------
450
451impl<T: Float> Sum for Complex<T> {
452 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
453 iter.fold(Complex::new(T::zero(), T::zero()), |acc, x| acc + x)
454 }
455}
456
457// ---------------------------------------------------------------------------
458// Scalar impl
459// ---------------------------------------------------------------------------
460
461impl<T: Float> Scalar for Complex<T> {
462 #[inline]
463 fn zero() -> Self {
464 Complex::new(T::zero(), T::zero())
465 }
466
467 #[inline]
468 fn one() -> Self {
469 Complex::new(T::one(), T::zero())
470 }
471
472 #[inline]
473 fn from_usize(v: usize) -> Self {
474 Complex::new(T::from_usize(v), T::zero())
475 }
476}
477
478// ---------------------------------------------------------------------------
479// Interleaved conversion utilities
480// ---------------------------------------------------------------------------
481
482/// Convert interleaved real data `[re0, im0, re1, im1, ...]` to a `Vec<Complex<T>>`.
483///
484/// The input slice length must be even. If it is odd the last element is ignored.
485///
486/// # Examples
487///
488/// ```
489/// # use scivex_core::complex::{Complex, from_interleaved};
490/// let data = [1.0_f64, 2.0, 3.0, 4.0];
491/// let v = from_interleaved(&data);
492/// assert_eq!(v.len(), 2);
493/// assert_eq!(v[0], Complex::new(1.0, 2.0));
494/// assert_eq!(v[1], Complex::new(3.0, 4.0));
495/// ```
496pub fn from_interleaved<T: Float>(data: &[T]) -> Vec<Complex<T>> {
497 let n = data.len() / 2;
498 let mut out = Vec::with_capacity(n);
499 for i in 0..n {
500 out.push(Complex::new(data[i * 2], data[i * 2 + 1]));
501 }
502 out
503}
504
505/// Convert a slice of `Complex<T>` to interleaved real data `[re0, im0, re1, im1, ...]`.
506///
507/// # Examples
508///
509/// ```
510/// # use scivex_core::complex::{Complex, to_interleaved};
511/// let v = vec![Complex::new(1.0_f64, 2.0), Complex::new(3.0, 4.0)];
512/// let flat = to_interleaved(&v);
513/// assert_eq!(flat, vec![1.0_f64, 2.0, 3.0, 4.0]);
514/// ```
515pub fn to_interleaved<T: Float>(data: &[Complex<T>]) -> Vec<T> {
516 let mut out = Vec::with_capacity(data.len() * 2);
517 for c in data {
518 out.push(c.re);
519 out.push(c.im);
520 }
521 out
522}
523
524// ===========================================================================
525// Tests
526// ===========================================================================
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// `true` when `x` and `y` differ by less than [`TOL`].
    fn close(x: f64, y: f64) -> bool {
        (x - y).abs() < TOL
    }

    /// Component-wise [`close`] for complex values.
    fn close_c(x: Complex<f64>, y: Complex<f64>) -> bool {
        close(x.re, y.re) && close(x.im, y.im)
    }

    #[test]
    fn test_complex_arithmetic() {
        let lhs = Complex::new(1.0, 2.0);
        let rhs = Complex::new(3.0, 4.0);

        // (1+2i) + (3+4i) = 4+6i
        assert!(close_c(lhs + rhs, Complex::new(4.0, 6.0)));
        // (1+2i) - (3+4i) = -2-2i
        assert!(close_c(lhs - rhs, Complex::new(-2.0, -2.0)));
        // (1+2i)(3+4i) = 3 + 4i + 6i + 8i² = -5+10i
        assert!(close_c(lhs * rhs, Complex::new(-5.0, 10.0)));
        // (1+2i)/(3+4i) = (1+2i)(3-4i)/25 = (11+2i)/25
        assert!(close_c(lhs / rhs, Complex::new(11.0 / 25.0, 2.0 / 25.0)));
    }

    #[test]
    fn test_complex_conjugate() {
        let conj = Complex::new(3.0, -7.0).conj();
        assert!(close_c(conj, Complex::new(3.0, 7.0)));
    }

    #[test]
    fn test_complex_norm() {
        let z = Complex::new(3.0, 4.0);
        assert!(close(z.norm_sqr(), 25.0));
        assert!(close(z.norm(), 5.0));
    }

    #[test]
    fn test_complex_arg() {
        // 1+i lies on the first-quadrant diagonal.
        let angle = Complex::new(1.0, 1.0).arg();
        assert!(close(angle, std::f64::consts::FRAC_PI_4));
    }

    #[test]
    fn test_complex_exp() {
        // Euler's identity: e^(iπ) = -1 + 0i.
        let e = Complex::new(0.0, std::f64::consts::PI).exp();
        assert!(close_c(e, Complex::new(-1.0, 0.0)));
    }

    #[test]
    fn test_complex_from_polar() {
        let (radius, angle) = (5.0, std::f64::consts::FRAC_PI_4);
        let z = Complex::from_polar(radius, angle);

        // norm/arg must recover the polar coordinates.
        assert!(close(z.norm(), radius));
        assert!(close(z.arg(), angle));
    }

    #[test]
    fn test_complex_scalar_mul() {
        let scaled = Complex::new(2.0, 3.0) * 4.0;
        assert!(close_c(scaled, Complex::new(8.0, 12.0)));
    }

    #[test]
    fn test_complex_sqrt() {
        // Principal roots: sqrt(-1) = i and sqrt(4) = 2.
        assert!(close_c(Complex::new(-1.0, 0.0).sqrt(), Complex::new(0.0, 1.0)));
        assert!(close_c(Complex::new(4.0, 0.0).sqrt(), Complex::new(2.0, 0.0)));
    }

    #[test]
    fn test_complex_display() {
        let plus = format!("{}", Complex::new(3.0_f64, 4.0_f64));
        assert!(plus.contains('+') && plus.contains('i'));

        let minus = format!("{}", Complex::new(1.0_f64, -2.0_f64));
        assert!(minus.contains('-') && minus.contains('i'));
    }

    #[test]
    fn test_complex_tensor() {
        use crate::tensor::Tensor;

        let values = vec![
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(1.0, 1.0),
            Complex::new(2.0, -1.0),
        ];
        let tensor = Tensor::from_vec(values, vec![2, 2]).unwrap();
        assert_eq!(tensor.shape(), &[2, 2]);

        // Element-wise self-addition doubles each entry: [0, 1] is i + i = 2i.
        let doubled = &tensor + &tensor;
        let entry = doubled.get(&[0, 1]).unwrap();
        assert!(close_c(*entry, Complex::new(0.0, 2.0)));
    }

    #[test]
    fn test_complex_interleaved_roundtrip() {
        let original = vec![
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ];
        let flat = to_interleaved(&original);
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(from_interleaved(&flat), original);
    }

    #[test]
    fn test_complex_sum() {
        let values = vec![
            Complex::new(1.0, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ];
        let total: Complex<f64> = values.into_iter().sum();
        assert!(close_c(total, Complex::new(9.0, 12.0)));
    }

    #[test]
    fn test_complex_scalar_trait() {
        // Scalar's constructors must produce purely real values.
        let zero: Complex<f64> = Scalar::zero();
        let one: Complex<f64> = Scalar::one();
        let five: Complex<f64> = Scalar::from_usize(5);
        assert!(close_c(zero, Complex::new(0.0, 0.0)));
        assert!(close_c(one, Complex::new(1.0, 0.0)));
        assert!(close_c(five, Complex::new(5.0, 0.0)));
    }
}