// trustformers_core/tensor/complex.rs
1//! Complex number tensor operations with numerical stability enhancements.
2//!
3//! This module contains functions for working with complex-valued tensors with
4//! advanced numerical stability features including overflow/underflow protection,
5//! NaN/infinity detection, and optimized algorithms for modern architectures.
6
7use super::Tensor;
8use crate::errors::{Result, TrustformersError};
9use scirs2_core::ndarray::{ArrayD, IxDyn};
10use scirs2_core::{Complex, Complex32, Complex64};
11
/// Numerical stability constants for complex operations
// Magnitudes below these thresholds are treated as numerically unstable
// (too close to underflow for safe reciprocal/ratio arithmetic).
const STABILITY_EPSILON_F32: f32 = 1e-7;
const STABILITY_EPSILON_F64: f64 = 1e-15;
// Magnitudes above these bounds risk overflow in downstream arithmetic
// (kept well inside f32::MAX ≈ 3.4e38 and f64::MAX ≈ 1.8e308).
const MAX_SAFE_MAGNITUDE_F32: f32 = 1e30;
const MAX_SAFE_MAGNITUDE_F64: f64 = 1e300;
17
18/// Check if a complex number is numerically stable (no NaN/infinity, within safe magnitude range)
19fn is_stable_c32(z: Complex32) -> bool {
20    z.re.is_finite()
21        && z.im.is_finite()
22        && z.norm() < MAX_SAFE_MAGNITUDE_F32
23        && z.norm() > STABILITY_EPSILON_F32
24}
25
26/// Check if a complex number is numerically stable (64-bit version)
27fn is_stable_c64(z: Complex64) -> bool {
28    z.re.is_finite()
29        && z.im.is_finite()
30        && z.norm() < MAX_SAFE_MAGNITUDE_F64
31        && z.norm() > STABILITY_EPSILON_F64
32}
33
34/// Stabilize a complex number by clamping to safe ranges
35fn stabilize_c32(z: Complex32) -> Complex32 {
36    if !z.re.is_finite() || !z.im.is_finite() {
37        return Complex32::new(0.0, 0.0);
38    }
39    let magnitude = z.norm();
40    if magnitude > MAX_SAFE_MAGNITUDE_F32 {
41        let scale = MAX_SAFE_MAGNITUDE_F32 / magnitude;
42        Complex32::new(z.re * scale, z.im * scale)
43    } else if magnitude < STABILITY_EPSILON_F32 && magnitude > 0.0 {
44        let scale = STABILITY_EPSILON_F32 / magnitude;
45        Complex32::new(z.re * scale, z.im * scale)
46    } else {
47        z
48    }
49}
50
51/// Stabilize a complex number by clamping to safe ranges (64-bit version)
52fn stabilize_c64(z: Complex64) -> Complex64 {
53    if !z.re.is_finite() || !z.im.is_finite() {
54        return Complex64::new(0.0, 0.0);
55    }
56    let magnitude = z.norm();
57    if magnitude > MAX_SAFE_MAGNITUDE_F64 {
58        let scale = MAX_SAFE_MAGNITUDE_F64 / magnitude;
59        Complex64::new(z.re * scale, z.im * scale)
60    } else if magnitude < STABILITY_EPSILON_F64 && magnitude > 0.0 {
61        let scale = STABILITY_EPSILON_F64 / magnitude;
62        Complex64::new(z.re * scale, z.im * scale)
63    } else {
64        z
65    }
66}
67
68impl Tensor {
69    /// Get the real part of a complex tensor.
70    ///
71    /// # Returns
72    ///
73    /// A tensor containing the real parts.
74    pub fn real(&self) -> Result<Tensor> {
75        match self {
76            Tensor::C32(a) => {
77                let result = a.mapv(|x| x.re);
78                Ok(Tensor::F32(result))
79            },
80            Tensor::C64(a) => {
81                let result = a.mapv(|x| x.re);
82                Ok(Tensor::F64(result))
83            },
84            Tensor::CF16(a) => {
85                let result = a.mapv(|x| x.re);
86                Ok(Tensor::F16(result))
87            },
88            Tensor::CBF16(a) => {
89                let result = a.mapv(|x| x.re);
90                Ok(Tensor::BF16(result))
91            },
92            Tensor::F32(_) | Tensor::F64(_) | Tensor::F16(_) | Tensor::BF16(_) | Tensor::I64(_) => {
93                // Real tensors return themselves
94                Ok(self.clone())
95            },
96            _ => Err(TrustformersError::tensor_op_error(
97                "Real part extraction not supported for this tensor type",
98                "complex real part extraction",
99            )),
100        }
101    }
102
    /// Get the imaginary part of a complex tensor.
    ///
    /// Complex inputs yield a real-valued tensor of matching precision; real
    /// inputs yield an all-zero tensor of the same shape.
    ///
    /// # Returns
    ///
    /// A tensor containing the imaginary parts.
    pub fn imag(&self) -> Result<Tensor> {
        match self {
            Tensor::C32(a) => {
                let result = a.mapv(|x| x.im);
                Ok(Tensor::F32(result))
            },
            Tensor::C64(a) => {
                let result = a.mapv(|x| x.im);
                Ok(Tensor::F64(result))
            },
            Tensor::CF16(a) => {
                let result = a.mapv(|x| x.im);
                Ok(Tensor::F16(result))
            },
            Tensor::CBF16(a) => {
                let result = a.mapv(|x| x.im);
                Ok(Tensor::BF16(result))
            },
            Tensor::F32(a) => {
                // Real tensors have zero imaginary part
                let result = ArrayD::zeros(a.raw_dim());
                Ok(Tensor::F32(result))
            },
            Tensor::F64(a) => {
                // Real tensors have zero imaginary part
                let result = ArrayD::zeros(a.raw_dim());
                Ok(Tensor::F64(result))
            },
            Tensor::F16(a) => {
                // Real tensors have zero imaginary part.
                // Zeros are built via an explicit Vec + from_shape_vec here,
                // presumably because ArrayD::zeros can't be used for half::f16
                // directly — TODO confirm against the crate's trait bounds.
                let size = a.len();
                let data = vec![half::f16::ZERO; size];
                let result = ArrayD::from_shape_vec(a.raw_dim(), data)
                    .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
                Ok(Tensor::F16(result))
            },
            Tensor::BF16(a) => {
                // Real tensors have zero imaginary part (same Vec-based
                // construction as the F16 arm above).
                let size = a.len();
                let data = vec![half::bf16::ZERO; size];
                let result = ArrayD::from_shape_vec(a.raw_dim(), data)
                    .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
                Ok(Tensor::BF16(result))
            },
            Tensor::I64(a) => {
                // Real tensors have zero imaginary part.
                // NOTE(review): real() preserves I64, but imag() promotes I64
                // to F32 zeros — confirm callers expect this dtype change.
                let result = ArrayD::zeros(a.raw_dim());
                Ok(Tensor::F32(result))
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Imaginary part extraction not supported for this tensor type",
                "complex imaginary part extraction",
            )),
        }
    }
163
164    /// Get the magnitude of a complex tensor with numerical stability enhancements.
165    ///
166    /// Uses numerically stable algorithms to avoid overflow/underflow in intermediate calculations.
167    ///
168    /// # Returns
169    ///
170    /// A tensor containing the magnitudes.
171    pub fn magnitude(&self) -> Result<Tensor> {
172        match self {
173            Tensor::C32(a) => {
174                let result = a.mapv(|x| {
175                    if !is_stable_c32(x) {
176                        let stabilized = stabilize_c32(x);
177                        stabilized.norm()
178                    } else {
179                        // Use numerically stable magnitude calculation
180                        let abs_re = x.re.abs();
181                        let abs_im = x.im.abs();
182                        if abs_re == 0.0 {
183                            abs_im
184                        } else if abs_im == 0.0 {
185                            abs_re
186                        } else if abs_re > abs_im {
187                            let ratio = abs_im / abs_re;
188                            abs_re * (1.0 + ratio * ratio).sqrt()
189                        } else {
190                            let ratio = abs_re / abs_im;
191                            abs_im * (1.0 + ratio * ratio).sqrt()
192                        }
193                    }
194                });
195                Ok(Tensor::F32(result))
196            },
197            Tensor::C64(a) => {
198                let result = a.mapv(|x| {
199                    if !is_stable_c64(x) {
200                        let stabilized = stabilize_c64(x);
201                        stabilized.norm()
202                    } else {
203                        // Use numerically stable magnitude calculation
204                        let abs_re = x.re.abs();
205                        let abs_im = x.im.abs();
206                        if abs_re == 0.0 {
207                            abs_im
208                        } else if abs_im == 0.0 {
209                            abs_re
210                        } else if abs_re > abs_im {
211                            let ratio = abs_im / abs_re;
212                            abs_re * (1.0 + ratio * ratio).sqrt()
213                        } else {
214                            let ratio = abs_re / abs_im;
215                            abs_im * (1.0 + ratio * ratio).sqrt()
216                        }
217                    }
218                });
219                Ok(Tensor::F64(result))
220            },
221            Tensor::CF16(a) => {
222                let result = a.mapv(|x| {
223                    let re_f32 = x.re.to_f32();
224                    let im_f32 = x.im.to_f32();
225
226                    // Check for NaN/infinity
227                    if !re_f32.is_finite() || !im_f32.is_finite() {
228                        return half::f16::from_f32(0.0);
229                    }
230
231                    // Use numerically stable magnitude calculation
232                    let abs_re = re_f32.abs();
233                    let abs_im = im_f32.abs();
234                    let norm = if abs_re == 0.0 {
235                        abs_im
236                    } else if abs_im == 0.0 {
237                        abs_re
238                    } else if abs_re > abs_im {
239                        let ratio = abs_im / abs_re;
240                        abs_re * (1.0 + ratio * ratio).sqrt()
241                    } else {
242                        let ratio = abs_re / abs_im;
243                        abs_im * (1.0 + ratio * ratio).sqrt()
244                    };
245
246                    half::f16::from_f32(norm.min(half::f16::MAX.to_f32()))
247                });
248                Ok(Tensor::F16(result))
249            },
250            Tensor::CBF16(a) => {
251                let result = a.mapv(|x| {
252                    let re_f32 = x.re.to_f32();
253                    let im_f32 = x.im.to_f32();
254
255                    // Check for NaN/infinity
256                    if !re_f32.is_finite() || !im_f32.is_finite() {
257                        return half::bf16::from_f32(0.0);
258                    }
259
260                    // Use numerically stable magnitude calculation
261                    let abs_re = re_f32.abs();
262                    let abs_im = im_f32.abs();
263                    let norm = if abs_re == 0.0 {
264                        abs_im
265                    } else if abs_im == 0.0 {
266                        abs_re
267                    } else if abs_re > abs_im {
268                        let ratio = abs_im / abs_re;
269                        abs_re * (1.0 + ratio * ratio).sqrt()
270                    } else {
271                        let ratio = abs_re / abs_im;
272                        abs_im * (1.0 + ratio * ratio).sqrt()
273                    };
274
275                    half::bf16::from_f32(norm.min(half::bf16::MAX.to_f32()))
276                });
277                Ok(Tensor::BF16(result))
278            },
279            Tensor::F32(a) => {
280                // For real tensors, magnitude is absolute value
281                let result = a.mapv(|x| x.abs());
282                Ok(Tensor::F32(result))
283            },
284            Tensor::F64(a) => {
285                // For real tensors, magnitude is absolute value
286                let result = a.mapv(|x| x.abs());
287                Ok(Tensor::F64(result))
288            },
289            Tensor::F16(a) => {
290                // For real tensors, magnitude is absolute value
291                let result = a.mapv(|x| {
292                    let val = x.to_f32();
293                    half::f16::from_f32(val.abs())
294                });
295                Ok(Tensor::F16(result))
296            },
297            Tensor::BF16(a) => {
298                // For real tensors, magnitude is absolute value
299                let result = a.mapv(|x| {
300                    let val = x.to_f32();
301                    half::bf16::from_f32(val.abs())
302                });
303                Ok(Tensor::BF16(result))
304            },
305            Tensor::I64(a) => {
306                // For real tensors, magnitude is absolute value
307                let result = a.mapv(|x| x.abs() as f32);
308                Ok(Tensor::F32(result))
309            },
310            _ => Err(TrustformersError::tensor_op_error(
311                "Magnitude not supported for this tensor type",
312                "complex magnitude calculation",
313            )),
314        }
315    }
316
317    /// Get the phase of a complex tensor.
318    ///
319    /// # Returns
320    ///
321    /// A tensor containing the phases.
322    pub fn phase(&self) -> Result<Tensor> {
323        match self {
324            Tensor::C32(a) => {
325                let result = a.mapv(|x| x.arg());
326                Ok(Tensor::F32(result))
327            },
328            Tensor::C64(a) => {
329                let result = a.mapv(|x| x.arg());
330                Ok(Tensor::F64(result))
331            },
332            Tensor::CF16(a) => {
333                let result = a.mapv(|x| {
334                    let re_f32 = x.re.to_f32();
335                    let im_f32 = x.im.to_f32();
336                    let phase = im_f32.atan2(re_f32);
337                    half::f16::from_f32(phase)
338                });
339                Ok(Tensor::F16(result))
340            },
341            Tensor::CBF16(a) => {
342                let result = a.mapv(|x| {
343                    let re_f32 = x.re.to_f32();
344                    let im_f32 = x.im.to_f32();
345                    let phase = im_f32.atan2(re_f32);
346                    half::bf16::from_f32(phase)
347                });
348                Ok(Tensor::BF16(result))
349            },
350            Tensor::F32(a) => {
351                // For real tensors, phase is 0 for positive, π for negative
352                let result = a.mapv(|x| if x >= 0.0 { 0.0 } else { std::f32::consts::PI });
353                Ok(Tensor::F32(result))
354            },
355            Tensor::F64(a) => {
356                // For real tensors, phase is 0 for positive, π for negative
357                let result = a.mapv(|x| if x >= 0.0 { 0.0 } else { std::f64::consts::PI });
358                Ok(Tensor::F64(result))
359            },
360            Tensor::F16(a) => {
361                // For real tensors, phase is 0 for positive, π for negative
362                let result = a.mapv(|x| {
363                    let val = x.to_f32();
364                    if val >= 0.0 {
365                        half::f16::from_f32(0.0)
366                    } else {
367                        half::f16::from_f32(std::f32::consts::PI)
368                    }
369                });
370                Ok(Tensor::F16(result))
371            },
372            Tensor::BF16(a) => {
373                // For real tensors, phase is 0 for positive, π for negative
374                let result = a.mapv(|x| {
375                    let val = x.to_f32();
376                    if val >= 0.0 {
377                        half::bf16::from_f32(0.0)
378                    } else {
379                        half::bf16::from_f32(std::f32::consts::PI)
380                    }
381                });
382                Ok(Tensor::BF16(result))
383            },
384            _ => Err(TrustformersError::tensor_op_error(
385                "Phase not supported for this tensor type",
386                "complex phase calculation",
387            )),
388        }
389    }
390
391    /// Get the complex conjugate of a complex tensor.
392    ///
393    /// # Returns
394    ///
395    /// A tensor containing the complex conjugates.
396    pub fn conj(&self) -> Result<Tensor> {
397        match self {
398            Tensor::C32(a) => {
399                let result = a.mapv(|x| x.conj());
400                Ok(Tensor::C32(result))
401            },
402            Tensor::C64(a) => {
403                let result = a.mapv(|x| x.conj());
404                Ok(Tensor::C64(result))
405            },
406            Tensor::CF16(a) => {
407                let result = a.mapv(|x| Complex::new(x.re, -x.im));
408                Ok(Tensor::CF16(result))
409            },
410            Tensor::CBF16(a) => {
411                let result = a.mapv(|x| Complex::new(x.re, -x.im));
412                Ok(Tensor::CBF16(result))
413            },
414            Tensor::F32(_) | Tensor::F64(_) | Tensor::F16(_) | Tensor::BF16(_) | Tensor::I64(_) => {
415                // Real tensors are their own conjugate
416                Ok(self.clone())
417            },
418            _ => Err(TrustformersError::tensor_op_error(
419                "Complex conjugate not supported for this tensor type",
420                "complex conjugate operation",
421            )),
422        }
423    }
424
425    /// Convert real tensor to complex tensor.
426    ///
427    /// # Returns
428    ///
429    /// A complex tensor with zero imaginary part.
430    pub fn to_complex(&self) -> Result<Tensor> {
431        match self {
432            Tensor::F32(a) => {
433                let result = a.mapv(|x| Complex32::new(x, 0.0));
434                Ok(Tensor::C32(result))
435            },
436            Tensor::F64(a) => {
437                let result = a.mapv(|x| Complex64::new(x, 0.0));
438                Ok(Tensor::C64(result))
439            },
440            Tensor::F16(a) => {
441                let result = a.mapv(|x| Complex::new(x, half::f16::from_f32(0.0)));
442                Ok(Tensor::CF16(result))
443            },
444            Tensor::BF16(a) => {
445                let result = a.mapv(|x| Complex::new(x, half::bf16::from_f32(0.0)));
446                Ok(Tensor::CBF16(result))
447            },
448            Tensor::I64(a) => {
449                let result = a.mapv(|x| Complex32::new(x as f32, 0.0));
450                Ok(Tensor::C32(result))
451            },
452            Tensor::C32(_) | Tensor::C64(_) | Tensor::CF16(_) | Tensor::CBF16(_) => {
453                // Already complex
454                Ok(self.clone())
455            },
456            _ => Err(TrustformersError::tensor_op_error(
457                "Cannot convert this tensor type to complex",
458                "complex tensor conversion",
459            )),
460        }
461    }
462
463    /// Complex element-wise multiplication (Hadamard product) for two complex tensors.
464    ///
465    /// This operation is crucial for transformer architectures using complex-valued layers.
466    /// Optimized for modern hardware architectures.
467    ///
468    /// # Arguments
469    ///
470    /// * `other` - The other complex tensor to multiply with
471    ///
472    /// # Returns
473    ///
474    /// A tensor containing the element-wise complex multiplication result.
475    pub fn complex_hadamard(&self, other: &Tensor) -> Result<Tensor> {
476        match (self, other) {
477            (Tensor::C32(a), Tensor::C32(b)) => {
478                let result = a * b;
479                Ok(Tensor::C32(result))
480            },
481            (Tensor::C64(a), Tensor::C64(b)) => {
482                let result = a * b;
483                Ok(Tensor::C64(result))
484            },
485            (Tensor::CF16(a), Tensor::CF16(b)) => {
486                // Manual complex multiplication for half::f16
487                let result = a
488                    .iter()
489                    .zip(b.iter())
490                    .map(|(a_val, b_val)| {
491                        Complex::new(
492                            a_val.re * b_val.re - a_val.im * b_val.im,
493                            a_val.re * b_val.im + a_val.im * b_val.re,
494                        )
495                    })
496                    .collect::<Vec<_>>();
497
498                Ok(Tensor::CF16(
499                    ArrayD::from_shape_vec(a.raw_dim(), result)
500                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?,
501                ))
502            },
503            (Tensor::CBF16(a), Tensor::CBF16(b)) => {
504                // Manual complex multiplication for bf16
505                let result = a
506                    .iter()
507                    .zip(b.iter())
508                    .map(|(a_val, b_val)| {
509                        Complex::new(
510                            a_val.re * b_val.re - a_val.im * b_val.im,
511                            a_val.re * b_val.im + a_val.im * b_val.re,
512                        )
513                    })
514                    .collect::<Vec<_>>();
515
516                Ok(Tensor::CBF16(
517                    ArrayD::from_shape_vec(a.raw_dim(), result)
518                        .map_err(|e| TrustformersError::shape_error(e.to_string()))?,
519                ))
520            },
521            _ => Err(TrustformersError::tensor_op_error(
522                "Complex Hadamard product requires matching complex tensor types",
523                "complex Hadamard product",
524            )),
525        }
526    }
527
    /// Fast Fourier Transform (FFT) for complex tensors with numerical stability enhancements.
    ///
    /// Essential for advanced transformer architectures using frequency domain operations.
    /// Optimized for modern SIMD architectures with overflow/underflow protection.
    ///
    /// Implementation note: despite the name, this is a direct O(n²) DFT
    /// (nested loops over output bin `k` and input sample `j`), not an
    /// O(n log n) FFT — suitable for small 1D inputs only.
    ///
    /// # Returns
    ///
    /// A tensor containing the FFT result.
    ///
    /// # Errors
    ///
    /// Returns a tensor-op error for non-complex inputs, non-1D inputs, or an
    /// empty tensor.
    pub fn fft(&self) -> Result<Tensor> {
        match self {
            Tensor::C32(a) => {
                // Only 1D transforms are implemented.
                if a.shape().len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "FFT currently only supports 1D tensors",
                        "complex FFT operation",
                    ));
                }

                let n = a.len();
                if n == 0 {
                    return Err(TrustformersError::tensor_op_error(
                        "FFT requires non-empty tensor",
                        "complex FFT operation",
                    ));
                }

                let mut result = ArrayD::zeros(IxDyn(&[n]));
                let n_f32 = n as f32;

                // Pre-compute normalization factor to prevent overflow.
                // NOTE(review): this 1/sqrt(n) factor is only applied on the
                // overflow path below, so overflowed bins get normalized while
                // clean bins do not — confirm this asymmetry is intended.
                let scale_factor = 1.0 / n_f32.sqrt();

                for k in 0..n {
                    let mut sum = Complex32::new(0.0, 0.0);
                    let mut overflow_detected = false;

                    for j in 0..n {
                        // Check input stability: non-finite or out-of-range
                        // samples are silently dropped from this bin's sum.
                        if !is_stable_c32(a[[j]]) {
                            continue; // Skip unstable values
                        }

                        // Twiddle factor e^(-2πi·kj/n) for the DFT.
                        let angle = -2.0 * std::f32::consts::PI * (k * j) as f32 / n_f32;
                        let twiddle = Complex32::new(angle.cos(), angle.sin());

                        let product = a[[j]] * twiddle;

                        // Check for overflow in accumulation before committing
                        // the addition; on overflow the remaining samples for
                        // this bin are abandoned.
                        if !is_stable_c32(sum + product) {
                            overflow_detected = true;
                            break;
                        }

                        sum += product;
                    }

                    // Apply numerical stabilization only to bins that overflowed.
                    if overflow_detected {
                        result[[k]] = stabilize_c32(sum * scale_factor);
                    } else {
                        result[[k]] = sum;
                    }
                }

                Ok(Tensor::C32(result))
            },
            Tensor::C64(a) => {
                // Only 1D transforms are implemented.
                if a.shape().len() != 1 {
                    return Err(TrustformersError::tensor_op_error(
                        "FFT currently only supports 1D tensors",
                        "complex FFT operation",
                    ));
                }

                let n = a.len();
                if n == 0 {
                    return Err(TrustformersError::tensor_op_error(
                        "FFT requires non-empty tensor",
                        "complex FFT operation",
                    ));
                }

                let mut result = ArrayD::zeros(IxDyn(&[n]));
                let n_f64 = n as f64;

                // Pre-compute normalization factor to prevent overflow.
                // NOTE(review): same asymmetric use as the C32 arm — the factor
                // is only applied to bins whose accumulation overflowed.
                let scale_factor = 1.0 / n_f64.sqrt();

                for k in 0..n {
                    let mut sum = Complex64::new(0.0, 0.0);
                    let mut overflow_detected = false;

                    for j in 0..n {
                        // Check input stability: non-finite or out-of-range
                        // samples are silently dropped from this bin's sum.
                        if !is_stable_c64(a[[j]]) {
                            continue; // Skip unstable values
                        }

                        // Twiddle factor e^(-2πi·kj/n) for the DFT.
                        let angle = -2.0 * std::f64::consts::PI * (k * j) as f64 / n_f64;
                        let twiddle = Complex64::new(angle.cos(), angle.sin());

                        let product = a[[j]] * twiddle;

                        // Check for overflow in accumulation before committing
                        // the addition.
                        if !is_stable_c64(sum + product) {
                            overflow_detected = true;
                            break;
                        }

                        sum += product;
                    }

                    // Apply numerical stabilization only to bins that overflowed.
                    if overflow_detected {
                        result[[k]] = stabilize_c64(sum * scale_factor);
                    } else {
                        result[[k]] = sum;
                    }
                }

                Ok(Tensor::C64(result))
            },
            _ => Err(TrustformersError::tensor_op_error(
                "FFT only supports complex tensors",
                "complex FFT operation",
            )),
        }
    }
656
    /// Complex matrix multiplication optimized for modern architectures with numerical stability.
    ///
    /// Uses SIMD instructions and parallel processing for maximum performance.
    /// Essential for complex-valued transformer layers with overflow/underflow protection.
    ///
    /// Implementation note: this is a triple-nested loop (O(rows·cols·inner))
    /// with compensated (Kahan) summation applied componentwise to the complex
    /// accumulator. Unstable input elements are skipped, and an ad-hoc damping
    /// heuristic kicks in when more than half of an inner product's terms were
    /// skipped.
    ///
    /// # Arguments
    ///
    /// * `other` - The other complex tensor to multiply with
    ///
    /// # Returns
    ///
    /// A tensor containing the complex matrix multiplication result.
    ///
    /// # Errors
    ///
    /// Returns a tensor-op error for non-2D operands, incompatible inner
    /// dimensions, zero-sized dimensions, or mismatched/non-complex dtypes.
    pub fn complex_matmul(&self, other: &Tensor) -> Result<Tensor> {
        match (self, other) {
            (Tensor::C32(a), Tensor::C32(b)) => {
                if a.shape().len() != 2 || b.shape().len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Complex matrix multiplication requires 2D tensors",
                        "complex matrix multiplication",
                    ));
                }

                let a_rows = a.shape()[0];
                let a_cols = a.shape()[1];
                let b_rows = b.shape()[0];
                let b_cols = b.shape()[1];

                // Inner dimensions must agree: (a_rows × a_cols) · (b_rows × b_cols).
                if a_cols != b_rows {
                    return Err(TrustformersError::tensor_op_error(
                        "Matrix dimensions incompatible for multiplication",
                        "complex matrix multiplication",
                    ));
                }

                // Check for zero dimensions
                if a_rows == 0 || a_cols == 0 || b_cols == 0 {
                    return Err(TrustformersError::tensor_op_error(
                        "Matrix multiplication requires non-zero dimensions",
                        "complex matrix multiplication",
                    ));
                }

                let mut result = ArrayD::zeros(IxDyn(&[a_rows, b_cols]));

                // Numerically stable complex matrix multiplication with Kahan summation
                for i in 0..a_rows {
                    for j in 0..b_cols {
                        let mut sum = Complex32::new(0.0, 0.0);
                        let mut compensation = Complex32::new(0.0, 0.0); // For Kahan summation
                        let mut unstable_count = 0;

                        for k in 0..a_cols {
                            let a_val = a[[i, k]];
                            let b_val = b[[k, j]];

                            // Unstable inputs are dropped from the inner product
                            // (their term contributes nothing).
                            if !is_stable_c32(a_val) || !is_stable_c32(b_val) {
                                unstable_count += 1;
                                continue;
                            }

                            let product = a_val * b_val;

                            // Kahan summation: the compensation term recovers
                            // low-order bits lost in `sum + y`. The statement
                            // order here is load-bearing — do not reorder.
                            let y = product - compensation;
                            let t = sum + y;
                            compensation = (t - sum) - y;
                            sum = t;

                            // On overflow, clamp and abandon the rest of this
                            // inner product.
                            if !is_stable_c32(sum) {
                                sum = stabilize_c32(sum);
                                break;
                            }
                        }

                        // Heuristic damping when most terms were skipped.
                        // NOTE(review): the 0.5 factor appears ad hoc — confirm
                        // this is the intended degradation behavior.
                        if unstable_count > a_cols / 2 {
                            sum = stabilize_c32(sum * Complex32::new(0.5, 0.0));
                        }

                        result[[i, j]] = sum;
                    }
                }

                Ok(Tensor::C32(result))
            },
            (Tensor::C64(a), Tensor::C64(b)) => {
                if a.shape().len() != 2 || b.shape().len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Complex matrix multiplication requires 2D tensors",
                        "complex matrix multiplication",
                    ));
                }

                let a_rows = a.shape()[0];
                let a_cols = a.shape()[1];
                let b_rows = b.shape()[0];
                let b_cols = b.shape()[1];

                // Inner dimensions must agree.
                if a_cols != b_rows {
                    return Err(TrustformersError::tensor_op_error(
                        "Matrix dimensions incompatible for multiplication",
                        "complex matrix multiplication",
                    ));
                }

                // Check for zero dimensions
                if a_rows == 0 || a_cols == 0 || b_cols == 0 {
                    return Err(TrustformersError::tensor_op_error(
                        "Matrix multiplication requires non-zero dimensions",
                        "complex matrix multiplication",
                    ));
                }

                let mut result = ArrayD::zeros(IxDyn(&[a_rows, b_cols]));

                // Numerically stable complex matrix multiplication with Kahan summation
                for i in 0..a_rows {
                    for j in 0..b_cols {
                        let mut sum = Complex64::new(0.0, 0.0);
                        let mut compensation = Complex64::new(0.0, 0.0); // For Kahan summation
                        let mut unstable_count = 0;

                        for k in 0..a_cols {
                            let a_val = a[[i, k]];
                            let b_val = b[[k, j]];

                            // Unstable inputs are dropped from the inner product.
                            if !is_stable_c64(a_val) || !is_stable_c64(b_val) {
                                unstable_count += 1;
                                continue;
                            }

                            let product = a_val * b_val;

                            // Kahan summation — statement order is load-bearing.
                            let y = product - compensation;
                            let t = sum + y;
                            compensation = (t - sum) - y;
                            sum = t;

                            // On overflow, clamp and abandon the rest of this
                            // inner product.
                            if !is_stable_c64(sum) {
                                sum = stabilize_c64(sum);
                                break;
                            }
                        }

                        // Heuristic damping when most terms were skipped
                        // (see NOTE in the C32 arm).
                        if unstable_count > a_cols / 2 {
                            sum = stabilize_c64(sum * Complex64::new(0.5, 0.0));
                        }

                        result[[i, j]] = sum;
                    }
                }

                Ok(Tensor::C64(result))
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Complex matrix multiplication requires matching complex tensor types",
                "complex matrix multiplication",
            )),
        }
    }
823
824    /// Optimized complex activation function for advanced architectures.
825    ///
826    /// Applies complex ReLU activation: ReLU(Re(z)) + i*ReLU(Im(z))
827    /// Optimized for modern SIMD architectures.
828    ///
829    /// # Returns
830    ///
831    /// A tensor with complex ReLU activation applied.
832    pub fn complex_relu(&self) -> Result<Tensor> {
833        match self {
834            Tensor::C32(a) => {
835                let result = a.mapv(|x| Complex32::new(x.re.max(0.0), x.im.max(0.0)));
836                Ok(Tensor::C32(result))
837            },
838            Tensor::C64(a) => {
839                let result = a.mapv(|x| Complex64::new(x.re.max(0.0), x.im.max(0.0)));
840                Ok(Tensor::C64(result))
841            },
842            Tensor::CF16(a) => {
843                let result = a.mapv(|x| {
844                    let re_f32 = x.re.to_f32().max(0.0);
845                    let im_f32 = x.im.to_f32().max(0.0);
846                    Complex::new(half::f16::from_f32(re_f32), half::f16::from_f32(im_f32))
847                });
848                Ok(Tensor::CF16(result))
849            },
850            Tensor::CBF16(a) => {
851                let result = a.mapv(|x| {
852                    let re_f32 = x.re.to_f32().max(0.0);
853                    let im_f32 = x.im.to_f32().max(0.0);
854                    Complex::new(half::bf16::from_f32(re_f32), half::bf16::from_f32(im_f32))
855                });
856                Ok(Tensor::CBF16(result))
857            },
858            _ => Err(TrustformersError::tensor_op_error(
859                "Complex ReLU only supports complex tensors",
860                "complex ReLU activation",
861            )),
862        }
863    }
864}