// torsh_tensor/complex_ops.rs

//! Complex number operations for tensors
//!
//! This module provides specialized operations for complex-valued tensors, including
//! complex conjugation, real/imaginary part extraction, and complex-specific
//! automatic differentiation support.
//!
//! # Features
//!
//! - **Complex conjugation**: Efficient complex conjugate operations
//! - **Component extraction**: Real and imaginary part access
//! - **Complex autograd**: Specialized gradient computation for complex numbers
//! - **Complex arithmetic**: Element-wise operations preserving complex structure
//! - **Magnitude and phase**: Polar representation support

15use scirs2_core::numeric::Float;
16use std::sync::Arc;
17use torsh_core::{
18    dtype::{ComplexElement, TensorElement},
19    error::{Result, TorshError},
20};
21
22use crate::core_ops::{Operation, Tensor};
23
24impl<T: ComplexElement + Copy> Tensor<T> {
25    /// Complex conjugate for complex tensors
26    pub fn complex_conj(&self) -> Result<Self>
27    where
28        T: Copy,
29    {
30        let data = self.to_vec()?;
31        let conj_data: Vec<T> = data.iter().map(|&z| z.conj()).collect();
32        let mut result = Self::from_data(conj_data, self.shape().dims().to_vec(), self.device)?;
33        result.requires_grad = self.requires_grad;
34
35        // Set up operation tracking for autograd
36        if self.requires_grad {
37            result.operation = Operation::Custom(
38                "complex_conj".to_string(),
39                vec![Arc::downgrade(&Arc::new(self.clone()))],
40            );
41        }
42
43        Ok(result)
44    }
45
46    /// Get real part of complex tensor
47    pub fn real(&self) -> Result<Tensor<T::Real>>
48    where
49        T::Real: TensorElement + Copy,
50    {
51        let data = self.to_vec()?;
52        let real_data: Vec<T::Real> = data.iter().map(|x| x.real()).collect();
53        Tensor::from_data(real_data, self.shape().dims().to_vec(), self.device)
54    }
55
56    /// Get imaginary part of complex tensor
57    pub fn imag(&self) -> Result<Tensor<T::Real>>
58    where
59        T::Real: TensorElement + Copy,
60    {
61        let data = self.to_vec()?;
62        let imag_data: Vec<T::Real> = data.iter().map(|x| x.imag()).collect();
63        Tensor::from_data(imag_data, self.shape().dims().to_vec(), self.device)
64    }
65
66    /// Get magnitude (absolute value) of complex tensor
67    pub fn abs(&self) -> Result<Tensor<T::Real>>
68    where
69        T::Real: TensorElement + Copy + num_traits::Float,
70    {
71        let data = self.to_vec()?;
72        let abs_data: Vec<T::Real> = data.iter().map(|x| x.abs()).collect();
73        Tensor::from_data(abs_data, self.shape().dims().to_vec(), self.device)
74    }
75
76    /// Get phase (argument) of complex tensor
77    pub fn angle(&self) -> Result<Tensor<T::Real>>
78    where
79        T::Real: TensorElement + Copy + num_traits::Float,
80    {
81        let data = self.to_vec()?;
82        let angle_data: Vec<T::Real> = data.iter().map(|x| x.arg()).collect();
83        Tensor::from_data(angle_data, self.shape().dims().to_vec(), self.device)
84    }
85
86    /// Create complex tensor from real and imaginary parts
87    pub fn complex(real: &Tensor<T::Real>, imag: &Tensor<T::Real>) -> Result<Self>
88    where
89        T::Real: TensorElement + Copy,
90    {
91        if real.shape() != imag.shape() {
92            return Err(TorshError::ShapeMismatch {
93                expected: real.shape().dims().to_vec(),
94                got: imag.shape().dims().to_vec(),
95            });
96        }
97
98        let real_data = real.to_vec()?;
99        let imag_data = imag.to_vec()?;
100
101        let complex_data: Vec<T> = real_data
102            .iter()
103            .zip(imag_data.iter())
104            .map(|(&r, &i)| T::new(r, i))
105            .collect();
106
107        Self::from_data(complex_data, real.shape().dims().to_vec(), real.device)
108    }
109
110    /// Create complex tensor from polar representation (magnitude and phase)
111    pub fn polar(magnitude: &Tensor<T::Real>, phase: &Tensor<T::Real>) -> Result<Self>
112    where
113        T::Real: TensorElement + Copy + num_traits::Float,
114    {
115        if magnitude.shape() != phase.shape() {
116            return Err(TorshError::ShapeMismatch {
117                expected: magnitude.shape().dims().to_vec(),
118                got: phase.shape().dims().to_vec(),
119            });
120        }
121
122        let mag_data = magnitude.to_vec()?;
123        let phase_data = phase.to_vec()?;
124
125        let complex_data: Vec<T> = mag_data
126            .iter()
127            .zip(phase_data.iter())
128            .map(|(&mag, &phase)| {
129                let real = mag * phase.cos();
130                let imag = mag * phase.sin();
131                T::new(real, imag)
132            })
133            .collect();
134
135        Self::from_data(
136            complex_data,
137            magnitude.shape().dims().to_vec(),
138            magnitude.device,
139        )
140    }
141
142    /// Backward pass for complex tensors (compute gradients)
143    ///
144    /// Complex autograd follows PyTorch's approach where gradients are computed
145    /// treating complex numbers as 2D vectors of real numbers.
146    pub fn backward_complex(&self) -> Result<()>
147    where
148        T: Copy
149            + Default
150            + std::ops::Add<Output = T>
151            + std::ops::Sub<Output = T>
152            + std::ops::Mul<Output = T>
153            + std::ops::Div<Output = T>,
154    {
155        if !self.requires_grad {
156            return Err(TorshError::AutogradError(
157                "Called backward on tensor that doesn't require grad".to_string(),
158            ));
159        }
160
161        if self.shape().numel() != 1 {
162            return Err(TorshError::AutogradError(
163                "Gradient can only be computed for scalar outputs".to_string(),
164            ));
165        }
166
167        // Create initial gradient of 1.0 + 0.0i for the output
168        let output_grad_data = vec![T::new(
169            <T::Real as TensorElement>::one(),
170            <T::Real as TensorElement>::zero(),
171        )];
172        let output_grad = Self::from_data(output_grad_data, vec![], self.device)?;
173
174        // Start backpropagation
175        self.backward_complex_impl(&output_grad)?;
176
177        Ok(())
178    }
179
180    /// Internal backward implementation for complex tensors
181    fn backward_complex_impl(&self, grad_output: &Self) -> Result<()>
182    where
183        T: Copy
184            + Default
185            + std::ops::Add<Output = T>
186            + std::ops::Sub<Output = T>
187            + std::ops::Mul<Output = T>
188            + std::ops::Div<Output = T>,
189    {
190        match &self.operation {
191            Operation::Leaf => {
192                // Accumulate gradient for leaf nodes
193                let mut grad_lock = self.grad.write().expect("lock should not be poisoned");
194                if let Some(existing_grad) = grad_lock.as_ref() {
195                    // Add gradients if they exist
196                    let new_grad = existing_grad.add_op(grad_output)?;
197                    *grad_lock = Some(new_grad);
198                } else {
199                    // Set gradient if it doesn't exist
200                    *grad_lock = Some(grad_output.clone());
201                }
202            }
203            Operation::Add { lhs, rhs } => {
204                // Gradient flows through both operands unchanged for complex addition
205                if lhs.requires_grad {
206                    lhs.backward_complex_impl(grad_output)?;
207                }
208                if rhs.requires_grad {
209                    rhs.backward_complex_impl(grad_output)?;
210                }
211            }
212            Operation::Mul { lhs, rhs } => {
213                // Complex multiplication rule: d/dz(f*g) = f'*g + f*g'
214                if lhs.requires_grad {
215                    let lhs_grad = (**rhs).mul_op(grad_output)?;
216                    lhs.backward_complex_impl(&lhs_grad)?;
217                }
218                if rhs.requires_grad {
219                    let rhs_grad = (**lhs).mul_op(grad_output)?;
220                    rhs.backward_complex_impl(&rhs_grad)?;
221                }
222            }
223            Operation::Custom(op_name, inputs) => {
224                match op_name.as_str() {
225                    "complex_conj" => {
226                        // Gradient of complex conjugate: d/dz(conj(f)) = conj(df/dz)
227                        if let Some(weak_input) = inputs.first() {
228                            if let Some(input) = weak_input.upgrade() {
229                                if input.requires_grad {
230                                    let conj_grad = grad_output.complex_conj()?;
231                                    input.backward_complex_impl(&conj_grad)?;
232                                }
233                            }
234                        }
235                    }
236                    "complex_abs" => {
237                        // Gradient of abs(z) = z / |z| for z != 0
238                        if let Some(weak_input) = inputs.first() {
239                            if let Some(input) = weak_input.upgrade() {
240                                if input.requires_grad {
241                                    let input_data = input.to_vec()?;
242                                    let grad_data = grad_output.to_vec()?;
243
244                                    let input_grad_data: Vec<T> = input_data
245                                        .iter()
246                                        .zip(grad_data.iter())
247                                        .map(|(&z, &grad)| {
248                                            let abs_z = z.abs();
249                                            if abs_z > T::Real::zero() {
250                                                // Gradient is z / |z| * grad_output
251                                                let z_normalized =
252                                                    T::new(z.real() / abs_z, z.imag() / abs_z);
253                                                T::new(
254                                                    z_normalized.real() * grad.real()
255                                                        - z_normalized.imag() * grad.imag(),
256                                                    z_normalized.real() * grad.imag()
257                                                        + z_normalized.imag() * grad.real(),
258                                                )
259                                            } else {
260                                                T::new(T::Real::zero(), T::Real::zero())
261                                            }
262                                        })
263                                        .collect();
264
265                                    let input_grad = Self::from_data(
266                                        input_grad_data,
267                                        input.shape().dims().to_vec(),
268                                        input.device,
269                                    )?;
270                                    input.backward_complex_impl(&input_grad)?;
271                                }
272                            }
273                        }
274                    }
275                    _ => {
276                        // For other custom operations, propagate gradient to all inputs
277                        for weak_input in inputs {
278                            if let Some(input) = weak_input.upgrade() {
279                                if input.requires_grad {
280                                    input.backward_complex_impl(grad_output)?;
281                                }
282                            }
283                        }
284                    }
285                }
286            }
287            _ => {
288                // For other operations, fall back to regular backward pass
289                // This would call the regular backward_impl method
290                // Note: This is a simplified approach - in practice, each operation
291                // would need its own complex-specific gradient computation
292            }
293        }
294
295        Ok(())
296    }
297
298    /// Element-wise complex multiplication with proper gradient tracking
299    pub fn complex_mul(&self, other: &Self) -> Result<Self>
300    where
301        T: std::ops::Mul<Output = T> + std::ops::Add<Output = T> + std::ops::Sub<Output = T>,
302    {
303        if self.shape() != other.shape() {
304            return Err(TorshError::ShapeMismatch {
305                expected: self.shape().dims().to_vec(),
306                got: other.shape().dims().to_vec(),
307            });
308        }
309
310        let self_data = self.to_vec()?;
311        let other_data = other.to_vec()?;
312
313        let result_data: Vec<T> = self_data
314            .iter()
315            .zip(other_data.iter())
316            .map(|(&a, &b)| {
317                // Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i
318                T::new(
319                    a.real() * b.real() - a.imag() * b.imag(),
320                    a.real() * b.imag() + a.imag() * b.real(),
321                )
322            })
323            .collect();
324
325        let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
326
327        // Set up gradient tracking
328        if self.requires_grad || other.requires_grad {
329            result.requires_grad = true;
330            result.operation = Operation::Mul {
331                lhs: Arc::new(self.clone()),
332                rhs: Arc::new(other.clone()),
333            };
334        }
335
336        Ok(result)
337    }
338
339    /// Element-wise complex addition with proper gradient tracking
340    pub fn complex_add(&self, other: &Self) -> Result<Self>
341    where
342        T: std::ops::Add<Output = T>,
343    {
344        if self.shape() != other.shape() {
345            return Err(TorshError::ShapeMismatch {
346                expected: self.shape().dims().to_vec(),
347                got: other.shape().dims().to_vec(),
348            });
349        }
350
351        let self_data = self.to_vec()?;
352        let other_data = other.to_vec()?;
353
354        let result_data: Vec<T> = self_data
355            .iter()
356            .zip(other_data.iter())
357            .map(|(&a, &b)| T::new(a.real() + b.real(), a.imag() + b.imag()))
358            .collect();
359
360        let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
361
362        // Set up gradient tracking
363        if self.requires_grad || other.requires_grad {
364            result.requires_grad = true;
365            result.operation = Operation::Add {
366                lhs: Arc::new(self.clone()),
367                rhs: Arc::new(other.clone()),
368            };
369        }
370
371        Ok(result)
372    }
373
374    /// Check if all elements in the tensor are real (imaginary part is zero)
375    pub fn is_real(&self) -> Result<bool>
376    where
377        T::Real: PartialEq + num_traits::Zero,
378    {
379        let data = self.to_vec()?;
380        Ok(data.iter().all(|&z| z.imag() == T::Real::zero()))
381    }
382
383    /// Check if any elements in the tensor are complex (imaginary part is non-zero)
384    pub fn is_complex(&self) -> Result<bool>
385    where
386        T::Real: PartialEq + num_traits::Zero,
387    {
388        Ok(!self.is_real()?)
389    }
390}

// Note: add_op and mul_op are provided by math_ops.rs for general use

#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex32;
    use torsh_core::device::DeviceType;

    type C32 = Complex32;

    #[test]
    fn test_complex_conjugate() {
        let data = vec![C32::new(1.0, 2.0), C32::new(3.0, -4.0), C32::new(-1.0, 1.0)];
        let tensor =
            Tensor::from_data(data, vec![3], DeviceType::Cpu).expect("operation should succeed");

        let conj_tensor = tensor
            .complex_conj()
            .expect("complex conjugate should succeed");
        let conj_data = conj_tensor.to_vec().expect("to_vec should succeed");

        assert_eq!(conj_data[0], C32::new(1.0, -2.0));
        assert_eq!(conj_data[1], C32::new(3.0, 4.0));
        assert_eq!(conj_data[2], C32::new(-1.0, -1.0));
    }

    #[test]
    fn test_real_imag_extraction() {
        let data = vec![C32::new(1.0, 2.0), C32::new(3.0, -4.0)];
        let tensor =
            Tensor::from_data(data, vec![2], DeviceType::Cpu).expect("operation should succeed");

        let real_part = tensor.real().expect("real extraction should succeed");
        let imag_part = tensor.imag().expect("imag extraction should succeed");

        assert_eq!(
            real_part.to_vec().expect("to_vec should succeed"),
            vec![1.0, 3.0]
        );
        assert_eq!(
            imag_part.to_vec().expect("to_vec should succeed"),
            vec![2.0, -4.0]
        );
    }

    #[test]
    fn test_magnitude_and_phase() {
        let data = vec![
            C32::new(3.0, 4.0), // |z| = 5, arg = atan(4/3)
            C32::new(1.0, 0.0), // |z| = 1, arg = 0
        ];
        let tensor =
            Tensor::from_data(data, vec![2], DeviceType::Cpu).expect("operation should succeed");

        let magnitude = tensor.abs().expect("abs computation should succeed");
        let phase = tensor.angle().expect("angle computation should succeed");

        let mag_data = magnitude.to_vec().expect("to_vec should succeed");
        let phase_data = phase.to_vec().expect("to_vec should succeed");

        assert!((mag_data[0] - 5.0).abs() < 1e-6);
        assert!((mag_data[1] - 1.0).abs() < 1e-6);
        assert!((phase_data[1] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_complex_from_components() {
        let real_data = vec![1.0f32, 2.0, 3.0];
        let imag_data = vec![4.0f32, 5.0, 6.0];

        let real_tensor = Tensor::from_data(real_data, vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let imag_tensor = Tensor::from_data(imag_data, vec![3], DeviceType::Cpu)
            .expect("operation should succeed");

        let complex_tensor =
            Tensor::<C32>::complex(&real_tensor, &imag_tensor).expect("operation should succeed");
        let result_data = complex_tensor.to_vec().expect("to_vec should succeed");

        assert_eq!(result_data[0], C32::new(1.0, 4.0));
        assert_eq!(result_data[1], C32::new(2.0, 5.0));
        assert_eq!(result_data[2], C32::new(3.0, 6.0));
    }

    #[test]
    fn test_complex_arithmetic() {
        let a_data = vec![C32::new(1.0, 2.0), C32::new(3.0, 4.0)];
        let b_data = vec![C32::new(2.0, 1.0), C32::new(1.0, -1.0)];

        let a =
            Tensor::from_data(a_data, vec![2], DeviceType::Cpu).expect("operation should succeed");
        let b =
            Tensor::from_data(b_data, vec![2], DeviceType::Cpu).expect("operation should succeed");

        // Test complex addition
        let sum = a.complex_add(&b).expect("operation should succeed");
        let sum_data = sum.to_vec().expect("to_vec should succeed");
        assert_eq!(sum_data[0], C32::new(3.0, 3.0));
        assert_eq!(sum_data[1], C32::new(4.0, 3.0));

        // Test complex multiplication
        let product = a.complex_mul(&b).expect("operation should succeed");
        let prod_data = product.to_vec().expect("to_vec should succeed");
        // (1+2i)(2+1i) = 2 + i + 4i + 2i² = 2 + 5i - 2 = 0 + 5i
        assert_eq!(prod_data[0], C32::new(0.0, 5.0));
        // (3+4i)(1-1i) = 3 - 3i + 4i - 4i² = 3 + i + 4 = 7 + i
        assert_eq!(prod_data[1], C32::new(7.0, 1.0));
    }

    #[test]
    fn test_polar_construction() {
        let mag_data = vec![1.0f32, 2.0];
        let phase_data = vec![0.0f32, std::f32::consts::PI / 2.0];

        let mag_tensor = Tensor::from_data(mag_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let phase_tensor = Tensor::from_data(phase_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");

        let complex_tensor =
            Tensor::<C32>::polar(&mag_tensor, &phase_tensor).expect("operation should succeed");
        let result_data = complex_tensor.to_vec().expect("to_vec should succeed");

        // First element: 1 * (cos(0) + i*sin(0)) = 1 + 0i
        assert!((result_data[0].re - 1.0).abs() < 1e-6);
        assert!((result_data[0].im - 0.0).abs() < 1e-6);

        // Second element: 2 * (cos(π/2) + i*sin(π/2)) = 2 * (0 + i) = 0 + 2i
        assert!((result_data[1].re - 0.0).abs() < 1e-6);
        assert!((result_data[1].im - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_is_real_complex() {
        let real_data = vec![C32::new(1.0, 0.0), C32::new(2.0, 0.0)];
        let complex_data = vec![C32::new(1.0, 1.0), C32::new(2.0, 0.0)];

        let real_tensor = Tensor::from_data(real_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let complex_tensor = Tensor::from_data(complex_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");

        assert!(real_tensor.is_real().expect("is_real check should succeed"));
        assert!(!real_tensor
            .is_complex()
            .expect("is_complex check should succeed"));

        assert!(!complex_tensor
            .is_real()
            .expect("is_real check should succeed"));
        assert!(complex_tensor
            .is_complex()
            .expect("is_complex check should succeed"));
    }

    #[test]
    fn test_shape_mismatch_errors() {
        let a = Tensor::<C32>::zeros(&[2], DeviceType::Cpu).expect("operation should succeed");
        let b = Tensor::<C32>::zeros(&[3], DeviceType::Cpu).expect("operation should succeed");

        assert!(a.complex_add(&b).is_err());
        assert!(a.complex_mul(&b).is_err());

        let real_2 = Tensor::<f32>::zeros(&[2], DeviceType::Cpu).expect("operation should succeed");
        let imag_3 = Tensor::<f32>::zeros(&[3], DeviceType::Cpu).expect("operation should succeed");

        assert!(Tensor::<C32>::complex(&real_2, &imag_3).is_err());
    }
}