// trustformers_core/tensor/complex.rs

use super::Tensor;
use crate::errors::{Result, TrustformersError};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use scirs2_core::{Complex, Complex32, Complex64};

/// Smallest f32 complex magnitude treated as numerically meaningful;
/// values with a smaller (non-zero) norm fail the stability checks below.
const STABILITY_EPSILON_F32: f32 = 1e-7;
/// f64 counterpart of `STABILITY_EPSILON_F32`.
const STABILITY_EPSILON_F64: f64 = 1e-15;
/// Largest f32 complex magnitude considered safe; larger norms are treated
/// as at risk of overflow and are rescaled by `stabilize_c32`.
const MAX_SAFE_MAGNITUDE_F32: f32 = 1e30;
/// f64 counterpart of `MAX_SAFE_MAGNITUDE_F32`.
const MAX_SAFE_MAGNITUDE_F64: f64 = 1e300;
17
18fn is_stable_c32(z: Complex32) -> bool {
20 z.re.is_finite()
21 && z.im.is_finite()
22 && z.norm() < MAX_SAFE_MAGNITUDE_F32
23 && z.norm() > STABILITY_EPSILON_F32
24}
25
26fn is_stable_c64(z: Complex64) -> bool {
28 z.re.is_finite()
29 && z.im.is_finite()
30 && z.norm() < MAX_SAFE_MAGNITUDE_F64
31 && z.norm() > STABILITY_EPSILON_F64
32}
33
34fn stabilize_c32(z: Complex32) -> Complex32 {
36 if !z.re.is_finite() || !z.im.is_finite() {
37 return Complex32::new(0.0, 0.0);
38 }
39 let magnitude = z.norm();
40 if magnitude > MAX_SAFE_MAGNITUDE_F32 {
41 let scale = MAX_SAFE_MAGNITUDE_F32 / magnitude;
42 Complex32::new(z.re * scale, z.im * scale)
43 } else if magnitude < STABILITY_EPSILON_F32 && magnitude > 0.0 {
44 let scale = STABILITY_EPSILON_F32 / magnitude;
45 Complex32::new(z.re * scale, z.im * scale)
46 } else {
47 z
48 }
49}
50
51fn stabilize_c64(z: Complex64) -> Complex64 {
53 if !z.re.is_finite() || !z.im.is_finite() {
54 return Complex64::new(0.0, 0.0);
55 }
56 let magnitude = z.norm();
57 if magnitude > MAX_SAFE_MAGNITUDE_F64 {
58 let scale = MAX_SAFE_MAGNITUDE_F64 / magnitude;
59 Complex64::new(z.re * scale, z.im * scale)
60 } else if magnitude < STABILITY_EPSILON_F64 && magnitude > 0.0 {
61 let scale = STABILITY_EPSILON_F64 / magnitude;
62 Complex64::new(z.re * scale, z.im * scale)
63 } else {
64 z
65 }
66}
67
68impl Tensor {
69 pub fn real(&self) -> Result<Tensor> {
75 match self {
76 Tensor::C32(a) => {
77 let result = a.mapv(|x| x.re);
78 Ok(Tensor::F32(result))
79 },
80 Tensor::C64(a) => {
81 let result = a.mapv(|x| x.re);
82 Ok(Tensor::F64(result))
83 },
84 Tensor::CF16(a) => {
85 let result = a.mapv(|x| x.re);
86 Ok(Tensor::F16(result))
87 },
88 Tensor::CBF16(a) => {
89 let result = a.mapv(|x| x.re);
90 Ok(Tensor::BF16(result))
91 },
92 Tensor::F32(_) | Tensor::F64(_) | Tensor::F16(_) | Tensor::BF16(_) | Tensor::I64(_) => {
93 Ok(self.clone())
95 },
96 _ => Err(TrustformersError::tensor_op_error(
97 "Real part extraction not supported for this tensor type",
98 "complex real part extraction",
99 )),
100 }
101 }
102
103 pub fn imag(&self) -> Result<Tensor> {
109 match self {
110 Tensor::C32(a) => {
111 let result = a.mapv(|x| x.im);
112 Ok(Tensor::F32(result))
113 },
114 Tensor::C64(a) => {
115 let result = a.mapv(|x| x.im);
116 Ok(Tensor::F64(result))
117 },
118 Tensor::CF16(a) => {
119 let result = a.mapv(|x| x.im);
120 Ok(Tensor::F16(result))
121 },
122 Tensor::CBF16(a) => {
123 let result = a.mapv(|x| x.im);
124 Ok(Tensor::BF16(result))
125 },
126 Tensor::F32(a) => {
127 let result = ArrayD::zeros(a.raw_dim());
129 Ok(Tensor::F32(result))
130 },
131 Tensor::F64(a) => {
132 let result = ArrayD::zeros(a.raw_dim());
134 Ok(Tensor::F64(result))
135 },
136 Tensor::F16(a) => {
137 let size = a.len();
139 let data = vec![half::f16::ZERO; size];
140 let result = ArrayD::from_shape_vec(a.raw_dim(), data)
141 .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
142 Ok(Tensor::F16(result))
143 },
144 Tensor::BF16(a) => {
145 let size = a.len();
147 let data = vec![half::bf16::ZERO; size];
148 let result = ArrayD::from_shape_vec(a.raw_dim(), data)
149 .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
150 Ok(Tensor::BF16(result))
151 },
152 Tensor::I64(a) => {
153 let result = ArrayD::zeros(a.raw_dim());
155 Ok(Tensor::F32(result))
156 },
157 _ => Err(TrustformersError::tensor_op_error(
158 "Imaginary part extraction not supported for this tensor type",
159 "complex imaginary part extraction",
160 )),
161 }
162 }
163
164 pub fn magnitude(&self) -> Result<Tensor> {
172 match self {
173 Tensor::C32(a) => {
174 let result = a.mapv(|x| {
175 if !is_stable_c32(x) {
176 let stabilized = stabilize_c32(x);
177 stabilized.norm()
178 } else {
179 let abs_re = x.re.abs();
181 let abs_im = x.im.abs();
182 if abs_re == 0.0 {
183 abs_im
184 } else if abs_im == 0.0 {
185 abs_re
186 } else if abs_re > abs_im {
187 let ratio = abs_im / abs_re;
188 abs_re * (1.0 + ratio * ratio).sqrt()
189 } else {
190 let ratio = abs_re / abs_im;
191 abs_im * (1.0 + ratio * ratio).sqrt()
192 }
193 }
194 });
195 Ok(Tensor::F32(result))
196 },
197 Tensor::C64(a) => {
198 let result = a.mapv(|x| {
199 if !is_stable_c64(x) {
200 let stabilized = stabilize_c64(x);
201 stabilized.norm()
202 } else {
203 let abs_re = x.re.abs();
205 let abs_im = x.im.abs();
206 if abs_re == 0.0 {
207 abs_im
208 } else if abs_im == 0.0 {
209 abs_re
210 } else if abs_re > abs_im {
211 let ratio = abs_im / abs_re;
212 abs_re * (1.0 + ratio * ratio).sqrt()
213 } else {
214 let ratio = abs_re / abs_im;
215 abs_im * (1.0 + ratio * ratio).sqrt()
216 }
217 }
218 });
219 Ok(Tensor::F64(result))
220 },
221 Tensor::CF16(a) => {
222 let result = a.mapv(|x| {
223 let re_f32 = x.re.to_f32();
224 let im_f32 = x.im.to_f32();
225
226 if !re_f32.is_finite() || !im_f32.is_finite() {
228 return half::f16::from_f32(0.0);
229 }
230
231 let abs_re = re_f32.abs();
233 let abs_im = im_f32.abs();
234 let norm = if abs_re == 0.0 {
235 abs_im
236 } else if abs_im == 0.0 {
237 abs_re
238 } else if abs_re > abs_im {
239 let ratio = abs_im / abs_re;
240 abs_re * (1.0 + ratio * ratio).sqrt()
241 } else {
242 let ratio = abs_re / abs_im;
243 abs_im * (1.0 + ratio * ratio).sqrt()
244 };
245
246 half::f16::from_f32(norm.min(half::f16::MAX.to_f32()))
247 });
248 Ok(Tensor::F16(result))
249 },
250 Tensor::CBF16(a) => {
251 let result = a.mapv(|x| {
252 let re_f32 = x.re.to_f32();
253 let im_f32 = x.im.to_f32();
254
255 if !re_f32.is_finite() || !im_f32.is_finite() {
257 return half::bf16::from_f32(0.0);
258 }
259
260 let abs_re = re_f32.abs();
262 let abs_im = im_f32.abs();
263 let norm = if abs_re == 0.0 {
264 abs_im
265 } else if abs_im == 0.0 {
266 abs_re
267 } else if abs_re > abs_im {
268 let ratio = abs_im / abs_re;
269 abs_re * (1.0 + ratio * ratio).sqrt()
270 } else {
271 let ratio = abs_re / abs_im;
272 abs_im * (1.0 + ratio * ratio).sqrt()
273 };
274
275 half::bf16::from_f32(norm.min(half::bf16::MAX.to_f32()))
276 });
277 Ok(Tensor::BF16(result))
278 },
279 Tensor::F32(a) => {
280 let result = a.mapv(|x| x.abs());
282 Ok(Tensor::F32(result))
283 },
284 Tensor::F64(a) => {
285 let result = a.mapv(|x| x.abs());
287 Ok(Tensor::F64(result))
288 },
289 Tensor::F16(a) => {
290 let result = a.mapv(|x| {
292 let val = x.to_f32();
293 half::f16::from_f32(val.abs())
294 });
295 Ok(Tensor::F16(result))
296 },
297 Tensor::BF16(a) => {
298 let result = a.mapv(|x| {
300 let val = x.to_f32();
301 half::bf16::from_f32(val.abs())
302 });
303 Ok(Tensor::BF16(result))
304 },
305 Tensor::I64(a) => {
306 let result = a.mapv(|x| x.abs() as f32);
308 Ok(Tensor::F32(result))
309 },
310 _ => Err(TrustformersError::tensor_op_error(
311 "Magnitude not supported for this tensor type",
312 "complex magnitude calculation",
313 )),
314 }
315 }
316
317 pub fn phase(&self) -> Result<Tensor> {
323 match self {
324 Tensor::C32(a) => {
325 let result = a.mapv(|x| x.arg());
326 Ok(Tensor::F32(result))
327 },
328 Tensor::C64(a) => {
329 let result = a.mapv(|x| x.arg());
330 Ok(Tensor::F64(result))
331 },
332 Tensor::CF16(a) => {
333 let result = a.mapv(|x| {
334 let re_f32 = x.re.to_f32();
335 let im_f32 = x.im.to_f32();
336 let phase = im_f32.atan2(re_f32);
337 half::f16::from_f32(phase)
338 });
339 Ok(Tensor::F16(result))
340 },
341 Tensor::CBF16(a) => {
342 let result = a.mapv(|x| {
343 let re_f32 = x.re.to_f32();
344 let im_f32 = x.im.to_f32();
345 let phase = im_f32.atan2(re_f32);
346 half::bf16::from_f32(phase)
347 });
348 Ok(Tensor::BF16(result))
349 },
350 Tensor::F32(a) => {
351 let result = a.mapv(|x| if x >= 0.0 { 0.0 } else { std::f32::consts::PI });
353 Ok(Tensor::F32(result))
354 },
355 Tensor::F64(a) => {
356 let result = a.mapv(|x| if x >= 0.0 { 0.0 } else { std::f64::consts::PI });
358 Ok(Tensor::F64(result))
359 },
360 Tensor::F16(a) => {
361 let result = a.mapv(|x| {
363 let val = x.to_f32();
364 if val >= 0.0 {
365 half::f16::from_f32(0.0)
366 } else {
367 half::f16::from_f32(std::f32::consts::PI)
368 }
369 });
370 Ok(Tensor::F16(result))
371 },
372 Tensor::BF16(a) => {
373 let result = a.mapv(|x| {
375 let val = x.to_f32();
376 if val >= 0.0 {
377 half::bf16::from_f32(0.0)
378 } else {
379 half::bf16::from_f32(std::f32::consts::PI)
380 }
381 });
382 Ok(Tensor::BF16(result))
383 },
384 _ => Err(TrustformersError::tensor_op_error(
385 "Phase not supported for this tensor type",
386 "complex phase calculation",
387 )),
388 }
389 }
390
391 pub fn conj(&self) -> Result<Tensor> {
397 match self {
398 Tensor::C32(a) => {
399 let result = a.mapv(|x| x.conj());
400 Ok(Tensor::C32(result))
401 },
402 Tensor::C64(a) => {
403 let result = a.mapv(|x| x.conj());
404 Ok(Tensor::C64(result))
405 },
406 Tensor::CF16(a) => {
407 let result = a.mapv(|x| Complex::new(x.re, -x.im));
408 Ok(Tensor::CF16(result))
409 },
410 Tensor::CBF16(a) => {
411 let result = a.mapv(|x| Complex::new(x.re, -x.im));
412 Ok(Tensor::CBF16(result))
413 },
414 Tensor::F32(_) | Tensor::F64(_) | Tensor::F16(_) | Tensor::BF16(_) | Tensor::I64(_) => {
415 Ok(self.clone())
417 },
418 _ => Err(TrustformersError::tensor_op_error(
419 "Complex conjugate not supported for this tensor type",
420 "complex conjugate operation",
421 )),
422 }
423 }
424
425 pub fn to_complex(&self) -> Result<Tensor> {
431 match self {
432 Tensor::F32(a) => {
433 let result = a.mapv(|x| Complex32::new(x, 0.0));
434 Ok(Tensor::C32(result))
435 },
436 Tensor::F64(a) => {
437 let result = a.mapv(|x| Complex64::new(x, 0.0));
438 Ok(Tensor::C64(result))
439 },
440 Tensor::F16(a) => {
441 let result = a.mapv(|x| Complex::new(x, half::f16::from_f32(0.0)));
442 Ok(Tensor::CF16(result))
443 },
444 Tensor::BF16(a) => {
445 let result = a.mapv(|x| Complex::new(x, half::bf16::from_f32(0.0)));
446 Ok(Tensor::CBF16(result))
447 },
448 Tensor::I64(a) => {
449 let result = a.mapv(|x| Complex32::new(x as f32, 0.0));
450 Ok(Tensor::C32(result))
451 },
452 Tensor::C32(_) | Tensor::C64(_) | Tensor::CF16(_) | Tensor::CBF16(_) => {
453 Ok(self.clone())
455 },
456 _ => Err(TrustformersError::tensor_op_error(
457 "Cannot convert this tensor type to complex",
458 "complex tensor conversion",
459 )),
460 }
461 }
462
463 pub fn complex_hadamard(&self, other: &Tensor) -> Result<Tensor> {
476 match (self, other) {
477 (Tensor::C32(a), Tensor::C32(b)) => {
478 let result = a * b;
479 Ok(Tensor::C32(result))
480 },
481 (Tensor::C64(a), Tensor::C64(b)) => {
482 let result = a * b;
483 Ok(Tensor::C64(result))
484 },
485 (Tensor::CF16(a), Tensor::CF16(b)) => {
486 let result = a
488 .iter()
489 .zip(b.iter())
490 .map(|(a_val, b_val)| {
491 Complex::new(
492 a_val.re * b_val.re - a_val.im * b_val.im,
493 a_val.re * b_val.im + a_val.im * b_val.re,
494 )
495 })
496 .collect::<Vec<_>>();
497
498 Ok(Tensor::CF16(
499 ArrayD::from_shape_vec(a.raw_dim(), result)
500 .map_err(|e| TrustformersError::shape_error(e.to_string()))?,
501 ))
502 },
503 (Tensor::CBF16(a), Tensor::CBF16(b)) => {
504 let result = a
506 .iter()
507 .zip(b.iter())
508 .map(|(a_val, b_val)| {
509 Complex::new(
510 a_val.re * b_val.re - a_val.im * b_val.im,
511 a_val.re * b_val.im + a_val.im * b_val.re,
512 )
513 })
514 .collect::<Vec<_>>();
515
516 Ok(Tensor::CBF16(
517 ArrayD::from_shape_vec(a.raw_dim(), result)
518 .map_err(|e| TrustformersError::shape_error(e.to_string()))?,
519 ))
520 },
521 _ => Err(TrustformersError::tensor_op_error(
522 "Complex Hadamard product requires matching complex tensor types",
523 "complex Hadamard product",
524 )),
525 }
526 }
527
528 pub fn fft(&self) -> Result<Tensor> {
537 match self {
538 Tensor::C32(a) => {
539 if a.shape().len() != 1 {
540 return Err(TrustformersError::tensor_op_error(
541 "FFT currently only supports 1D tensors",
542 "complex FFT operation",
543 ));
544 }
545
546 let n = a.len();
547 if n == 0 {
548 return Err(TrustformersError::tensor_op_error(
549 "FFT requires non-empty tensor",
550 "complex FFT operation",
551 ));
552 }
553
554 let mut result = ArrayD::zeros(IxDyn(&[n]));
555 let n_f32 = n as f32;
556
557 let scale_factor = 1.0 / n_f32.sqrt();
559
560 for k in 0..n {
561 let mut sum = Complex32::new(0.0, 0.0);
562 let mut overflow_detected = false;
563
564 for j in 0..n {
565 if !is_stable_c32(a[[j]]) {
567 continue; }
569
570 let angle = -2.0 * std::f32::consts::PI * (k * j) as f32 / n_f32;
571 let twiddle = Complex32::new(angle.cos(), angle.sin());
572
573 let product = a[[j]] * twiddle;
574
575 if !is_stable_c32(sum + product) {
577 overflow_detected = true;
578 break;
579 }
580
581 sum += product;
582 }
583
584 if overflow_detected {
586 result[[k]] = stabilize_c32(sum * scale_factor);
587 } else {
588 result[[k]] = sum;
589 }
590 }
591
592 Ok(Tensor::C32(result))
593 },
594 Tensor::C64(a) => {
595 if a.shape().len() != 1 {
596 return Err(TrustformersError::tensor_op_error(
597 "FFT currently only supports 1D tensors",
598 "complex FFT operation",
599 ));
600 }
601
602 let n = a.len();
603 if n == 0 {
604 return Err(TrustformersError::tensor_op_error(
605 "FFT requires non-empty tensor",
606 "complex FFT operation",
607 ));
608 }
609
610 let mut result = ArrayD::zeros(IxDyn(&[n]));
611 let n_f64 = n as f64;
612
613 let scale_factor = 1.0 / n_f64.sqrt();
615
616 for k in 0..n {
617 let mut sum = Complex64::new(0.0, 0.0);
618 let mut overflow_detected = false;
619
620 for j in 0..n {
621 if !is_stable_c64(a[[j]]) {
623 continue; }
625
626 let angle = -2.0 * std::f64::consts::PI * (k * j) as f64 / n_f64;
627 let twiddle = Complex64::new(angle.cos(), angle.sin());
628
629 let product = a[[j]] * twiddle;
630
631 if !is_stable_c64(sum + product) {
633 overflow_detected = true;
634 break;
635 }
636
637 sum += product;
638 }
639
640 if overflow_detected {
642 result[[k]] = stabilize_c64(sum * scale_factor);
643 } else {
644 result[[k]] = sum;
645 }
646 }
647
648 Ok(Tensor::C64(result))
649 },
650 _ => Err(TrustformersError::tensor_op_error(
651 "FFT only supports complex tensors",
652 "complex FFT operation",
653 )),
654 }
655 }
656
657 pub fn complex_matmul(&self, other: &Tensor) -> Result<Tensor> {
670 match (self, other) {
671 (Tensor::C32(a), Tensor::C32(b)) => {
672 if a.shape().len() != 2 || b.shape().len() != 2 {
673 return Err(TrustformersError::tensor_op_error(
674 "Complex matrix multiplication requires 2D tensors",
675 "complex matrix multiplication",
676 ));
677 }
678
679 let a_rows = a.shape()[0];
680 let a_cols = a.shape()[1];
681 let b_rows = b.shape()[0];
682 let b_cols = b.shape()[1];
683
684 if a_cols != b_rows {
685 return Err(TrustformersError::tensor_op_error(
686 "Matrix dimensions incompatible for multiplication",
687 "complex matrix multiplication",
688 ));
689 }
690
691 if a_rows == 0 || a_cols == 0 || b_cols == 0 {
693 return Err(TrustformersError::tensor_op_error(
694 "Matrix multiplication requires non-zero dimensions",
695 "complex matrix multiplication",
696 ));
697 }
698
699 let mut result = ArrayD::zeros(IxDyn(&[a_rows, b_cols]));
700
701 for i in 0..a_rows {
703 for j in 0..b_cols {
704 let mut sum = Complex32::new(0.0, 0.0);
705 let mut compensation = Complex32::new(0.0, 0.0); let mut unstable_count = 0;
707
708 for k in 0..a_cols {
709 let a_val = a[[i, k]];
710 let b_val = b[[k, j]];
711
712 if !is_stable_c32(a_val) || !is_stable_c32(b_val) {
714 unstable_count += 1;
715 continue;
716 }
717
718 let product = a_val * b_val;
719
720 let y = product - compensation;
722 let t = sum + y;
723 compensation = (t - sum) - y;
724 sum = t;
725
726 if !is_stable_c32(sum) {
728 sum = stabilize_c32(sum);
729 break;
730 }
731 }
732
733 if unstable_count > a_cols / 2 {
735 sum = stabilize_c32(sum * Complex32::new(0.5, 0.0));
736 }
737
738 result[[i, j]] = sum;
739 }
740 }
741
742 Ok(Tensor::C32(result))
743 },
744 (Tensor::C64(a), Tensor::C64(b)) => {
745 if a.shape().len() != 2 || b.shape().len() != 2 {
746 return Err(TrustformersError::tensor_op_error(
747 "Complex matrix multiplication requires 2D tensors",
748 "complex matrix multiplication",
749 ));
750 }
751
752 let a_rows = a.shape()[0];
753 let a_cols = a.shape()[1];
754 let b_rows = b.shape()[0];
755 let b_cols = b.shape()[1];
756
757 if a_cols != b_rows {
758 return Err(TrustformersError::tensor_op_error(
759 "Matrix dimensions incompatible for multiplication",
760 "complex matrix multiplication",
761 ));
762 }
763
764 if a_rows == 0 || a_cols == 0 || b_cols == 0 {
766 return Err(TrustformersError::tensor_op_error(
767 "Matrix multiplication requires non-zero dimensions",
768 "complex matrix multiplication",
769 ));
770 }
771
772 let mut result = ArrayD::zeros(IxDyn(&[a_rows, b_cols]));
773
774 for i in 0..a_rows {
776 for j in 0..b_cols {
777 let mut sum = Complex64::new(0.0, 0.0);
778 let mut compensation = Complex64::new(0.0, 0.0); let mut unstable_count = 0;
780
781 for k in 0..a_cols {
782 let a_val = a[[i, k]];
783 let b_val = b[[k, j]];
784
785 if !is_stable_c64(a_val) || !is_stable_c64(b_val) {
787 unstable_count += 1;
788 continue;
789 }
790
791 let product = a_val * b_val;
792
793 let y = product - compensation;
795 let t = sum + y;
796 compensation = (t - sum) - y;
797 sum = t;
798
799 if !is_stable_c64(sum) {
801 sum = stabilize_c64(sum);
802 break;
803 }
804 }
805
806 if unstable_count > a_cols / 2 {
808 sum = stabilize_c64(sum * Complex64::new(0.5, 0.0));
809 }
810
811 result[[i, j]] = sum;
812 }
813 }
814
815 Ok(Tensor::C64(result))
816 },
817 _ => Err(TrustformersError::tensor_op_error(
818 "Complex matrix multiplication requires matching complex tensor types",
819 "complex matrix multiplication",
820 )),
821 }
822 }
823
824 pub fn complex_relu(&self) -> Result<Tensor> {
833 match self {
834 Tensor::C32(a) => {
835 let result = a.mapv(|x| Complex32::new(x.re.max(0.0), x.im.max(0.0)));
836 Ok(Tensor::C32(result))
837 },
838 Tensor::C64(a) => {
839 let result = a.mapv(|x| Complex64::new(x.re.max(0.0), x.im.max(0.0)));
840 Ok(Tensor::C64(result))
841 },
842 Tensor::CF16(a) => {
843 let result = a.mapv(|x| {
844 let re_f32 = x.re.to_f32().max(0.0);
845 let im_f32 = x.im.to_f32().max(0.0);
846 Complex::new(half::f16::from_f32(re_f32), half::f16::from_f32(im_f32))
847 });
848 Ok(Tensor::CF16(result))
849 },
850 Tensor::CBF16(a) => {
851 let result = a.mapv(|x| {
852 let re_f32 = x.re.to_f32().max(0.0);
853 let im_f32 = x.im.to_f32().max(0.0);
854 Complex::new(half::bf16::from_f32(re_f32), half::bf16::from_f32(im_f32))
855 });
856 Ok(Tensor::CBF16(result))
857 },
858 _ => Err(TrustformersError::tensor_op_error(
859 "Complex ReLU only supports complex tensors",
860 "complex ReLU activation",
861 )),
862 }
863 }
864}