1use scirs2_core::numeric::Float;
16use std::sync::Arc;
17use torsh_core::{
18 dtype::{ComplexElement, TensorElement},
19 error::{Result, TorshError},
20};
21
22use crate::core_ops::{Operation, Tensor};
23
24impl<T: ComplexElement + Copy> Tensor<T> {
25 pub fn complex_conj(&self) -> Result<Self>
27 where
28 T: Copy,
29 {
30 let data = self.to_vec()?;
31 let conj_data: Vec<T> = data.iter().map(|&z| z.conj()).collect();
32 let mut result = Self::from_data(conj_data, self.shape().dims().to_vec(), self.device)?;
33 result.requires_grad = self.requires_grad;
34
35 if self.requires_grad {
37 result.operation = Operation::Custom(
38 "complex_conj".to_string(),
39 vec![Arc::downgrade(&Arc::new(self.clone()))],
40 );
41 }
42
43 Ok(result)
44 }
45
46 pub fn real(&self) -> Result<Tensor<T::Real>>
48 where
49 T::Real: TensorElement + Copy,
50 {
51 let data = self.to_vec()?;
52 let real_data: Vec<T::Real> = data.iter().map(|x| x.real()).collect();
53 Tensor::from_data(real_data, self.shape().dims().to_vec(), self.device)
54 }
55
56 pub fn imag(&self) -> Result<Tensor<T::Real>>
58 where
59 T::Real: TensorElement + Copy,
60 {
61 let data = self.to_vec()?;
62 let imag_data: Vec<T::Real> = data.iter().map(|x| x.imag()).collect();
63 Tensor::from_data(imag_data, self.shape().dims().to_vec(), self.device)
64 }
65
66 pub fn abs(&self) -> Result<Tensor<T::Real>>
68 where
69 T::Real: TensorElement + Copy + num_traits::Float,
70 {
71 let data = self.to_vec()?;
72 let abs_data: Vec<T::Real> = data.iter().map(|x| x.abs()).collect();
73 Tensor::from_data(abs_data, self.shape().dims().to_vec(), self.device)
74 }
75
76 pub fn angle(&self) -> Result<Tensor<T::Real>>
78 where
79 T::Real: TensorElement + Copy + num_traits::Float,
80 {
81 let data = self.to_vec()?;
82 let angle_data: Vec<T::Real> = data.iter().map(|x| x.arg()).collect();
83 Tensor::from_data(angle_data, self.shape().dims().to_vec(), self.device)
84 }
85
86 pub fn complex(real: &Tensor<T::Real>, imag: &Tensor<T::Real>) -> Result<Self>
88 where
89 T::Real: TensorElement + Copy,
90 {
91 if real.shape() != imag.shape() {
92 return Err(TorshError::ShapeMismatch {
93 expected: real.shape().dims().to_vec(),
94 got: imag.shape().dims().to_vec(),
95 });
96 }
97
98 let real_data = real.to_vec()?;
99 let imag_data = imag.to_vec()?;
100
101 let complex_data: Vec<T> = real_data
102 .iter()
103 .zip(imag_data.iter())
104 .map(|(&r, &i)| T::new(r, i))
105 .collect();
106
107 Self::from_data(complex_data, real.shape().dims().to_vec(), real.device)
108 }
109
110 pub fn polar(magnitude: &Tensor<T::Real>, phase: &Tensor<T::Real>) -> Result<Self>
112 where
113 T::Real: TensorElement + Copy + num_traits::Float,
114 {
115 if magnitude.shape() != phase.shape() {
116 return Err(TorshError::ShapeMismatch {
117 expected: magnitude.shape().dims().to_vec(),
118 got: phase.shape().dims().to_vec(),
119 });
120 }
121
122 let mag_data = magnitude.to_vec()?;
123 let phase_data = phase.to_vec()?;
124
125 let complex_data: Vec<T> = mag_data
126 .iter()
127 .zip(phase_data.iter())
128 .map(|(&mag, &phase)| {
129 let real = mag * phase.cos();
130 let imag = mag * phase.sin();
131 T::new(real, imag)
132 })
133 .collect();
134
135 Self::from_data(
136 complex_data,
137 magnitude.shape().dims().to_vec(),
138 magnitude.device,
139 )
140 }
141
142 pub fn backward_complex(&self) -> Result<()>
147 where
148 T: Copy
149 + Default
150 + std::ops::Add<Output = T>
151 + std::ops::Sub<Output = T>
152 + std::ops::Mul<Output = T>
153 + std::ops::Div<Output = T>,
154 {
155 if !self.requires_grad {
156 return Err(TorshError::AutogradError(
157 "Called backward on tensor that doesn't require grad".to_string(),
158 ));
159 }
160
161 if self.shape().numel() != 1 {
162 return Err(TorshError::AutogradError(
163 "Gradient can only be computed for scalar outputs".to_string(),
164 ));
165 }
166
167 let output_grad_data = vec![T::new(
169 <T::Real as TensorElement>::one(),
170 <T::Real as TensorElement>::zero(),
171 )];
172 let output_grad = Self::from_data(output_grad_data, vec![], self.device)?;
173
174 self.backward_complex_impl(&output_grad)?;
176
177 Ok(())
178 }
179
180 fn backward_complex_impl(&self, grad_output: &Self) -> Result<()>
182 where
183 T: Copy
184 + Default
185 + std::ops::Add<Output = T>
186 + std::ops::Sub<Output = T>
187 + std::ops::Mul<Output = T>
188 + std::ops::Div<Output = T>,
189 {
190 match &self.operation {
191 Operation::Leaf => {
192 let mut grad_lock = self.grad.write().expect("lock should not be poisoned");
194 if let Some(existing_grad) = grad_lock.as_ref() {
195 let new_grad = existing_grad.add_op(grad_output)?;
197 *grad_lock = Some(new_grad);
198 } else {
199 *grad_lock = Some(grad_output.clone());
201 }
202 }
203 Operation::Add { lhs, rhs } => {
204 if lhs.requires_grad {
206 lhs.backward_complex_impl(grad_output)?;
207 }
208 if rhs.requires_grad {
209 rhs.backward_complex_impl(grad_output)?;
210 }
211 }
212 Operation::Mul { lhs, rhs } => {
213 if lhs.requires_grad {
215 let lhs_grad = (**rhs).mul_op(grad_output)?;
216 lhs.backward_complex_impl(&lhs_grad)?;
217 }
218 if rhs.requires_grad {
219 let rhs_grad = (**lhs).mul_op(grad_output)?;
220 rhs.backward_complex_impl(&rhs_grad)?;
221 }
222 }
223 Operation::Custom(op_name, inputs) => {
224 match op_name.as_str() {
225 "complex_conj" => {
226 if let Some(weak_input) = inputs.first() {
228 if let Some(input) = weak_input.upgrade() {
229 if input.requires_grad {
230 let conj_grad = grad_output.complex_conj()?;
231 input.backward_complex_impl(&conj_grad)?;
232 }
233 }
234 }
235 }
236 "complex_abs" => {
237 if let Some(weak_input) = inputs.first() {
239 if let Some(input) = weak_input.upgrade() {
240 if input.requires_grad {
241 let input_data = input.to_vec()?;
242 let grad_data = grad_output.to_vec()?;
243
244 let input_grad_data: Vec<T> = input_data
245 .iter()
246 .zip(grad_data.iter())
247 .map(|(&z, &grad)| {
248 let abs_z = z.abs();
249 if abs_z > T::Real::zero() {
250 let z_normalized =
252 T::new(z.real() / abs_z, z.imag() / abs_z);
253 T::new(
254 z_normalized.real() * grad.real()
255 - z_normalized.imag() * grad.imag(),
256 z_normalized.real() * grad.imag()
257 + z_normalized.imag() * grad.real(),
258 )
259 } else {
260 T::new(T::Real::zero(), T::Real::zero())
261 }
262 })
263 .collect();
264
265 let input_grad = Self::from_data(
266 input_grad_data,
267 input.shape().dims().to_vec(),
268 input.device,
269 )?;
270 input.backward_complex_impl(&input_grad)?;
271 }
272 }
273 }
274 }
275 _ => {
276 for weak_input in inputs {
278 if let Some(input) = weak_input.upgrade() {
279 if input.requires_grad {
280 input.backward_complex_impl(grad_output)?;
281 }
282 }
283 }
284 }
285 }
286 }
287 _ => {
288 }
293 }
294
295 Ok(())
296 }
297
298 pub fn complex_mul(&self, other: &Self) -> Result<Self>
300 where
301 T: std::ops::Mul<Output = T> + std::ops::Add<Output = T> + std::ops::Sub<Output = T>,
302 {
303 if self.shape() != other.shape() {
304 return Err(TorshError::ShapeMismatch {
305 expected: self.shape().dims().to_vec(),
306 got: other.shape().dims().to_vec(),
307 });
308 }
309
310 let self_data = self.to_vec()?;
311 let other_data = other.to_vec()?;
312
313 let result_data: Vec<T> = self_data
314 .iter()
315 .zip(other_data.iter())
316 .map(|(&a, &b)| {
317 T::new(
319 a.real() * b.real() - a.imag() * b.imag(),
320 a.real() * b.imag() + a.imag() * b.real(),
321 )
322 })
323 .collect();
324
325 let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
326
327 if self.requires_grad || other.requires_grad {
329 result.requires_grad = true;
330 result.operation = Operation::Mul {
331 lhs: Arc::new(self.clone()),
332 rhs: Arc::new(other.clone()),
333 };
334 }
335
336 Ok(result)
337 }
338
339 pub fn complex_add(&self, other: &Self) -> Result<Self>
341 where
342 T: std::ops::Add<Output = T>,
343 {
344 if self.shape() != other.shape() {
345 return Err(TorshError::ShapeMismatch {
346 expected: self.shape().dims().to_vec(),
347 got: other.shape().dims().to_vec(),
348 });
349 }
350
351 let self_data = self.to_vec()?;
352 let other_data = other.to_vec()?;
353
354 let result_data: Vec<T> = self_data
355 .iter()
356 .zip(other_data.iter())
357 .map(|(&a, &b)| T::new(a.real() + b.real(), a.imag() + b.imag()))
358 .collect();
359
360 let mut result = Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
361
362 if self.requires_grad || other.requires_grad {
364 result.requires_grad = true;
365 result.operation = Operation::Add {
366 lhs: Arc::new(self.clone()),
367 rhs: Arc::new(other.clone()),
368 };
369 }
370
371 Ok(result)
372 }
373
374 pub fn is_real(&self) -> Result<bool>
376 where
377 T::Real: PartialEq + num_traits::Zero,
378 {
379 let data = self.to_vec()?;
380 Ok(data.iter().all(|&z| z.imag() == T::Real::zero()))
381 }
382
383 pub fn is_complex(&self) -> Result<bool>
385 where
386 T::Real: PartialEq + num_traits::Zero,
387 {
388 Ok(!self.is_real()?)
389 }
390}
391
/// Unit tests for the complex-tensor helpers, run on CPU tensors of `Complex32`.
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex32;
    use torsh_core::device::DeviceType;

    type C32 = Complex32;

    /// Conjugation flips the sign of every imaginary part and leaves real parts alone.
    #[test]
    fn test_complex_conjugate() {
        let data = vec![C32::new(1.0, 2.0), C32::new(3.0, -4.0), C32::new(-1.0, 1.0)];
        let tensor =
            Tensor::from_data(data, vec![3], DeviceType::Cpu).expect("operation should succeed");

        let conj_tensor = tensor
            .complex_conj()
            .expect("complex conjugate should succeed");
        let conj_data = conj_tensor.to_vec().expect("to_vec should succeed");

        assert_eq!(conj_data[0], C32::new(1.0, -2.0));
        assert_eq!(conj_data[1], C32::new(3.0, 4.0));
        assert_eq!(conj_data[2], C32::new(-1.0, -1.0));
    }

    /// `real()` and `imag()` split a complex tensor into its two components.
    #[test]
    fn test_real_imag_extraction() {
        let data = vec![C32::new(1.0, 2.0), C32::new(3.0, -4.0)];
        let tensor =
            Tensor::from_data(data, vec![2], DeviceType::Cpu).expect("operation should succeed");

        let real_part = tensor.real().expect("real extraction should succeed");
        let imag_part = tensor.imag().expect("imag extraction should succeed");

        assert_eq!(
            real_part.to_vec().expect("to_vec should succeed"),
            vec![1.0, 3.0]
        );
        assert_eq!(
            imag_part.to_vec().expect("to_vec should succeed"),
            vec![2.0, -4.0]
        );
    }

    /// `abs()` and `angle()` return magnitude and phase for each element.
    #[test]
    fn test_magnitude_and_phase() {
        let data = vec![
            C32::new(3.0, 4.0), // |z| = 5, arg = atan2(4, 3)
            C32::new(1.0, 0.0), // |z| = 1, arg = 0
        ];
        let tensor =
            Tensor::from_data(data, vec![2], DeviceType::Cpu).expect("operation should succeed");

        let magnitude = tensor.abs().expect("abs computation should succeed");
        let phase = tensor.angle().expect("angle computation should succeed");

        let mag_data = magnitude.to_vec().expect("to_vec should succeed");
        let phase_data = phase.to_vec().expect("to_vec should succeed");

        assert!((mag_data[0] - 5.0).abs() < 1e-6);
        assert!((mag_data[1] - 1.0).abs() < 1e-6);
        // Check the non-trivial phase too, not only the zero-phase element.
        assert!((phase_data[0] - 4.0f32.atan2(3.0)).abs() < 1e-6);
        assert!((phase_data[1] - 0.0).abs() < 1e-6);
    }

    /// `complex()` zips a real tensor and an imaginary tensor into one complex tensor.
    #[test]
    fn test_complex_from_components() {
        let real_data = vec![1.0f32, 2.0, 3.0];
        let imag_data = vec![4.0f32, 5.0, 6.0];

        let real_tensor = Tensor::from_data(real_data, vec![3], DeviceType::Cpu)
            .expect("operation should succeed");
        let imag_tensor = Tensor::from_data(imag_data, vec![3], DeviceType::Cpu)
            .expect("operation should succeed");

        let complex_tensor =
            Tensor::<C32>::complex(&real_tensor, &imag_tensor).expect("operation should succeed");
        let result_data = complex_tensor.to_vec().expect("to_vec should succeed");

        assert_eq!(result_data[0], C32::new(1.0, 4.0));
        assert_eq!(result_data[1], C32::new(2.0, 5.0));
        assert_eq!(result_data[2], C32::new(3.0, 6.0));
    }

    /// Element-wise addition and multiplication produce the expected complex values.
    #[test]
    fn test_complex_arithmetic() {
        let a_data = vec![C32::new(1.0, 2.0), C32::new(3.0, 4.0)];
        let b_data = vec![C32::new(2.0, 1.0), C32::new(1.0, -1.0)];

        let a =
            Tensor::from_data(a_data, vec![2], DeviceType::Cpu).expect("operation should succeed");
        let b =
            Tensor::from_data(b_data, vec![2], DeviceType::Cpu).expect("operation should succeed");

        let sum = a.complex_add(&b).expect("operation should succeed");
        let sum_data = sum.to_vec().expect("to_vec should succeed");
        assert_eq!(sum_data[0], C32::new(3.0, 3.0));
        assert_eq!(sum_data[1], C32::new(4.0, 3.0));

        let product = a.complex_mul(&b).expect("operation should succeed");
        let prod_data = product.to_vec().expect("to_vec should succeed");
        // (1 + 2i)(2 + i) = 2 + i + 4i - 2 = 5i
        assert_eq!(prod_data[0], C32::new(0.0, 5.0));
        // (3 + 4i)(1 - i) = 3 - 3i + 4i + 4 = 7 + i
        assert_eq!(prod_data[1], C32::new(7.0, 1.0));
    }

    /// `polar()` converts (magnitude, phase) pairs into rectangular form.
    #[test]
    fn test_polar_construction() {
        let mag_data = vec![1.0f32, 2.0];
        let phase_data = vec![0.0f32, std::f32::consts::PI / 2.0];

        let mag_tensor = Tensor::from_data(mag_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let phase_tensor = Tensor::from_data(phase_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");

        let complex_tensor =
            Tensor::<C32>::polar(&mag_tensor, &phase_tensor).expect("operation should succeed");
        let result_data = complex_tensor.to_vec().expect("to_vec should succeed");

        // Magnitude 1 at phase 0 -> 1 + 0i
        assert!((result_data[0].re - 1.0).abs() < 1e-6);
        assert!((result_data[0].im - 0.0).abs() < 1e-6);

        // Magnitude 2 at phase pi/2 -> 0 + 2i
        assert!((result_data[1].re - 0.0).abs() < 1e-6);
        assert!((result_data[1].im - 2.0).abs() < 1e-6);
    }

    /// `is_real()` / `is_complex()` depend on whether any imaginary part is non-zero.
    #[test]
    fn test_is_real_complex() {
        let real_data = vec![C32::new(1.0, 0.0), C32::new(2.0, 0.0)];
        let complex_data = vec![C32::new(1.0, 1.0), C32::new(2.0, 0.0)];

        let real_tensor = Tensor::from_data(real_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");
        let complex_tensor = Tensor::from_data(complex_data, vec![2], DeviceType::Cpu)
            .expect("operation should succeed");

        assert!(real_tensor.is_real().expect("is_real check should succeed"));
        assert!(!real_tensor
            .is_complex()
            .expect("is_complex check should succeed"));

        assert!(!complex_tensor
            .is_real()
            .expect("is_real check should succeed"));
        assert!(complex_tensor
            .is_complex()
            .expect("is_complex check should succeed"));
    }

    /// Every binary constructor/operation rejects operands of different shapes.
    #[test]
    fn test_shape_mismatch_errors() {
        let a = Tensor::<C32>::zeros(&[2], DeviceType::Cpu).expect("operation should succeed");
        let b = Tensor::<C32>::zeros(&[3], DeviceType::Cpu).expect("operation should succeed");

        assert!(a.complex_add(&b).is_err());
        assert!(a.complex_mul(&b).is_err());

        let real_2 = Tensor::<f32>::zeros(&[2], DeviceType::Cpu).expect("operation should succeed");
        let imag_3 = Tensor::<f32>::zeros(&[3], DeviceType::Cpu).expect("operation should succeed");

        assert!(Tensor::<C32>::complex(&real_2, &imag_3).is_err());
        // `polar` performs its own shape check, so cover it as well.
        assert!(Tensor::<C32>::polar(&real_2, &imag_3).is_err());
    }
}