Skip to main content

scirs2_core/ndarray/
mod.rs

1//! Complete ndarray re-export for SciRS2 ecosystem
2//!
3//! This module provides a single, unified access point for ALL ndarray functionality,
4//! ensuring SciRS2 POLICY compliance across the entire ecosystem.
5//!
6//! ## Design Philosophy
7//!
8//! 1. **Complete Feature Parity**: All ndarray functionality available through scirs2-core
9//! 2. **Zero Breaking Changes**: Existing ndarray_ext continues to work
10//! 3. **Policy Compliance**: No need for direct ndarray imports anywhere
11//! 4. **Single Source of Truth**: One place for all array operations
12//!
13//! ## Usage
14//!
15//! ```rust
16//! // Instead of:
17//! use ndarray::{Array, array, s, Axis};  // ❌ POLICY violation
18//!
19//! // Use:
20//! use scirs2_core::ndarray::*;  // ✅ POLICY compliant
21//!
22//! let arr = array![[1, 2], [3, 4]];
23//! let slice = arr.slice(s![.., 0]);
24//! ```
25
26// ========================================
27// COMPLETE NDARRAY RE-EXPORT
28// ========================================
29
30// Complete ndarray 0.17 re-export (no version switching needed anymore)
31// Use ::ndarray to refer to the external crate (not this module)
32pub use ::ndarray::*;
33
34// Note: All macros (array!, s!, azip!, etc.) are already included via `pub use ::ndarray::*;`
35
36// ========================================
37// NDARRAY-RELATED CRATE RE-EXPORTS
38// ========================================
39
40#[cfg(feature = "random")]
41pub use ndarray_rand::{rand_distr as distributions, RandomExt, SamplingStrategy};
42
43// Note: ndarray_rand is compatible with both ndarray 0.16 and 0.17
44
45// NOTE: ndarray_linalg removed - using OxiBLAS via scirs2_core::linalg module
46
47#[cfg(feature = "array_stats")]
48pub use ndarray_stats::{
49    errors as stats_errors, interpolate, CorrelationExt, DeviationExt, MaybeNan, QuantileExt,
50    Sort1dExt, SummaryStatisticsExt,
51};
52
53// NOTE: ndarray_npy removed to eliminate `zip` crate from dependency tree (COOLJAPAN Pure Rust Policy)
54
55// ========================================
56// ENHANCED FUNCTIONALITY
57// ========================================
58
59/// Additional utilities for SciRS2 ecosystem
60pub mod utils {
61    use super::*;
62
63    /// Create an identity matrix
64    pub fn eye<A>(n: usize) -> Array2<A>
65    where
66        A: Clone + num_traits::Zero + num_traits::One,
67    {
68        let mut arr = Array2::zeros((n, n));
69        for i in 0..n {
70            arr[[i, i]] = A::one();
71        }
72        arr
73    }
74
75    /// Create a diagonal matrix from a vector
76    pub fn diag<A>(v: &Array1<A>) -> Array2<A>
77    where
78        A: Clone + num_traits::Zero,
79    {
80        let n = v.len();
81        let mut arr = Array2::zeros((n, n));
82        for i in 0..n {
83            arr[[i, i]] = v[i].clone();
84        }
85        arr
86    }
87
88    /// Check if arrays are approximately equal
89    pub fn allclose<A, D>(
90        a: &ArrayBase<impl Data<Elem = A>, D>,
91        b: &ArrayBase<impl Data<Elem = A>, D>,
92        rtol: A,
93        atol: A,
94    ) -> bool
95    where
96        A: PartialOrd
97            + std::ops::Sub<Output = A>
98            + std::ops::Mul<Output = A>
99            + std::ops::Add<Output = A>
100            + Clone,
101        D: Dimension,
102    {
103        if a.shape() != b.shape() {
104            return false;
105        }
106
107        a.iter().zip(b.iter()).all(|(a_val, b_val)| {
108            let diff = if a_val > b_val {
109                a_val.clone() - b_val.clone()
110            } else {
111                b_val.clone() - a_val.clone()
112            };
113
114            let threshold = atol.clone()
115                + rtol.clone()
116                    * (if a_val > b_val {
117                        a_val.clone()
118                    } else {
119                        b_val.clone()
120                    });
121
122            diff <= threshold
123        })
124    }
125
126    /// Concatenate arrays along an axis
127    pub fn concatenate<A, D>(
128        axis: Axis,
129        arrays: &[ArrayView<A, D>],
130    ) -> Result<Array<A, D>, ShapeError>
131    where
132        A: Clone,
133        D: Dimension + RemoveAxis,
134    {
135        ndarray::concatenate(axis, arrays)
136    }
137
138    /// Stack arrays along a new axis
139    pub fn stack<A, D>(
140        axis: Axis,
141        arrays: &[ArrayView<A, D>],
142    ) -> Result<Array<A, D::Larger>, ShapeError>
143    where
144        A: Clone,
145        D: Dimension,
146        D::Larger: RemoveAxis,
147    {
148        ndarray::stack(axis, arrays)
149    }
150}
151
152// ========================================
153// COMPATIBILITY LAYER
154// ========================================
155
156/// Compatibility module for smooth migration from fragmented imports
157/// and ndarray version changes (SciRS2 POLICY compliance)
158pub mod compat {
159    pub use super::*;
160    use crate::numeric::{Float, FromPrimitive};
161
162    /// Alias for commonly used types to match existing usage patterns
163    pub type DynArray<T> = ArrayD<T>;
164    pub type Matrix<T> = Array2<T>;
165    pub type Vector<T> = Array1<T>;
166    pub type Tensor3<T> = Array3<T>;
167    pub type Tensor4<T> = Array4<T>;
168
169    /// Compatibility extensions for ndarray statistical operations
170    ///
171    /// This trait provides stable statistical operation APIs that remain consistent
172    /// across ndarray version updates, implementing the SciRS2 POLICY principle
173    /// of isolating external dependency changes to scirs2-core only.
174    ///
175    /// ## Rationale
176    ///
177    /// ndarray's statistical methods have changed across versions:
178    /// - v0.16: `.mean()` returns `Option<T>`
179    /// - v0.17: `.mean()` returns `T` directly (may be NaN for invalid operations)
180    ///
181    /// This trait provides a consistent API regardless of the underlying ndarray version.
182    ///
183    /// ## Example
184    ///
185    /// ```rust
186    /// use scirs2_core::ndarray::{Array1, compat::ArrayStatCompat};
187    ///
188    /// let data = Array1::from(vec![1.0, 2.0, 3.0]);
189    /// let mean = data.mean_or(0.0);  // Stable API across ndarray versions
190    /// ```
191    pub trait ArrayStatCompat<T> {
192        /// Compute the mean of the array, returning a default value if computation fails
193        ///
194        /// This method abstracts over ndarray version differences:
195        /// - For ndarray 0.16: Unwraps the Option, using default if None
196        /// - For ndarray 0.17+: Returns the value, using default if NaN
197        fn mean_or(&self, default: T) -> T;
198
199        /// Compute the variance with optional default
200        fn var_or(&self, ddof: T, default: T) -> T;
201
202        /// Compute the standard deviation with optional default
203        fn std_or(&self, ddof: T, default: T) -> T;
204    }
205
206    impl<T, S, D> ArrayStatCompat<T> for ArrayBase<S, D>
207    where
208        T: Float + FromPrimitive,
209        S: Data<Elem = T>,
210        D: Dimension,
211    {
212        fn mean_or(&self, default: T) -> T {
213            // ndarray returns Option<T> in both 0.16 and 0.17
214            self.mean().unwrap_or(default)
215        }
216
217        fn var_or(&self, ddof: T, default: T) -> T {
218            // ndarray returns T directly (may be NaN for invalid inputs)
219            let v = self.var(ddof);
220            if v.is_nan() {
221                default
222            } else {
223                v
224            }
225        }
226
227        fn std_or(&self, ddof: T, default: T) -> T {
228            // ndarray returns T directly (may be NaN for invalid inputs)
229            let s = self.std(ddof);
230            if s.is_nan() {
231                default
232            } else {
233                s
234            }
235        }
236    }
237
238    /// Re-export from ndarray_ext for backward compatibility
239    pub use crate::ndarray_ext::{
240        broadcast_1d_to_2d,
241        broadcast_apply,
242        fancy_index_2d,
243        // Keep existing extended functionality
244        indexing,
245        is_broadcast_compatible,
246        manipulation,
247        mask_select,
248        matrix,
249        reshape_2d,
250        split_2d,
251        stack_2d,
252        stats,
253        take_2d,
254        transpose_2d,
255        where_condition,
256    };
257}
258
259// ========================================
260// PRELUDE MODULE
261// ========================================
262
263/// Prelude module with most commonly used items
264pub mod prelude {
265    pub use super::{
266        arr1,
267        arr2,
268        // Essential macros
269        array,
270        azip,
271        // Utilities
272        concatenate,
273        s,
274        stack,
275
276        stack as stack_fn,
277        // Essential types
278        Array,
279        Array0,
280        Array1,
281        Array2,
282        Array3,
283        ArrayD,
284        ArrayView,
285        ArrayView1,
286        ArrayView2,
287        ArrayViewMut,
288
289        // Common operations
290        Axis,
291        // Essential traits
292        Dimension,
293        Ix1,
294        Ix2,
295        Ix3,
296        IxDyn,
297        ScalarOperand,
298        ShapeBuilder,
299
300        Zip,
301    };
302
303    #[cfg(feature = "random")]
304    pub use super::RandomExt;
305
306    // Useful type aliases
307    pub type Matrix<T> = super::Array2<T>;
308    pub type Vector<T> = super::Array1<T>;
309}
310
311// ========================================
312// EXAMPLES MODULE
313// ========================================
314
#[cfg(test)]
pub mod examples {
    //! Examples demonstrating unified ndarray access through scirs2-core

    use super::*;

    /// Walk-through of everyday ndarray features reached through scirs2-core.
    ///
    /// The snippet below is illustrative only. The enclosing module is
    /// `#[cfg(test)]`, so rustdoc does not collect it as a doctest.
    ///
    /// ```
    /// use scirs2_core::ndarray::*;
    ///
    /// // Create arrays using the array! macro
    /// let a = array![[1, 2, 3], [4, 5, 6]];
    ///
    /// // Use the s! macro for slicing
    /// let row = a.slice(s![0, ..]);
    /// let col = a.slice(s![.., 1]);
    ///
    /// // Use Axis for operations
    /// let sum_axis0 = a.sum_axis(Axis(0));
    /// let mean_axis1 = a.mean_axis(Axis(1));
    ///
    /// // Stack and concatenate
    /// let b = array![[7, 8, 9], [10, 11, 12]];
    /// let stacked = stack![Axis(0), a, b];
    ///
    /// // Views and iteration
    /// for row in a.axis_iter(Axis(0)) {
    ///     println!("Row: {:?}", row);
    /// }
    /// ```
    #[test]
    fn test_complete_functionality() {
        let matrix = array![[1., 2.], [3., 4.]];
        assert_eq!(matrix.shape(), &[2, 2]);

        // `s!` slicing: the first column holds two entries.
        let first_column = matrix.slice(s![.., 0]);
        assert_eq!(first_column.len(), 2);

        // Element-wise arithmetic through references.
        let doubled = &matrix + &matrix;
        assert_eq!(doubled[[0, 0]], 2.);

        // Reduction along the row axis leaves one value per column.
        let column_sums = matrix.sum_axis(Axis(0));
        assert_eq!(column_sums.len(), 2);

        // A 1-D row broadcasts across every row of the matrix.
        let offsets = array![1., 2.];
        let shifted = &matrix + &offsets;
        assert_eq!(shifted[[0, 0]], 2.);
    }
}
370
371// ========================================
372// MIGRATION GUIDE
373// ========================================
374
pub mod migration_guide {
    //! # Migration Guide: From Fragmented to Unified ndarray Access
    //!
    //! ## Before (Fragmented, Policy-Violating)
    //!
    //! ```rust,ignore
    //! // Different files used different imports
    //! use scirs2_autograd::ndarray::{Array1, array};
    //! use scirs2_core::ndarray_ext::{ArrayView};
    //! use ndarray::{s, Axis};  // POLICY VIOLATION!
    //! ```
    //!
    //! ## After (Unified, Policy-Compliant)
    //!
    //! ```rust,ignore
    //! // Single, consistent import
    //! use scirs2_core::ndarray::*;
    //!
    //! // Everything works:
    //! let arr = array![[1, 2], [3, 4]];
    //! let slice = arr.slice(s![.., 0]);
    //! let view: ArrayView<_, _> = arr.view();
    //! let sum = arr.sum_axis(Axis(0));
    //! ```
    //!
    //! ## Benefits
    //!
    //! 1. **Single Import Path**: No more confusion about where to import from
    //! 2. **Complete Functionality**: All ndarray features available
    //! 3. **Policy Compliance**: No direct ndarray imports needed
    //! 4. **Future-Proof**: Centralized control over array functionality
    //!
    //! ## Quick Reference
    //!
    //! Macros are imported by name without the `!` (e.g. `s`, `array`).
    //!
    //! | Old Import | New Import |
    //! |------------|------------|
    //! | `use ndarray::{Array, array}` | `use scirs2_core::ndarray::{Array, array}` |
    //! | `use scirs2_autograd::ndarray::*` | `use scirs2_core::ndarray::*` |
    //! | `use scirs2_core::ndarray_ext::*` | `use scirs2_core::ndarray::compat::*` (ndarray items plus the re-exported `ndarray_ext` helpers) |
    //! | `use ndarray::{s, Axis}` | `use scirs2_core::ndarray::{s, Axis}` |
}
416
417// Re-export compatibility traits for easy access
418pub use compat::ArrayStatCompat;
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_macro_available() {
        let grid = array![[1, 2], [3, 4]];
        assert_eq!(grid.shape(), &[2, 2]);
        assert_eq!(grid[[0, 0]], 1);
    }

    #[test]
    fn test_s_macro_available() {
        let grid = array![[1, 2, 3], [4, 5, 6]];
        let trailing_columns = grid.slice(s![.., 1..]);
        assert_eq!(trailing_columns.shape(), &[2, 2]);
    }

    #[test]
    fn test_axis_operations() {
        let grid = array![[1., 2.], [3., 4.]];
        assert_eq!(grid.sum_axis(Axis(0)), array![4., 6.]);
    }

    #[test]
    fn test_views_and_iteration() {
        let mut grid = array![[1, 2], [3, 4]];

        // Shared view: every element is strictly positive.
        {
            let shared: ArrayView<_, _> = grid.view();
            assert!(shared.iter().all(|&value| value > 0));
        }

        // Exclusive view, taken once the shared borrow has ended: double in place.
        {
            let mut exclusive: ArrayViewMut<_, _> = grid.view_mut();
            exclusive.iter_mut().for_each(|value| *value *= 2);
        }

        assert_eq!(grid[[0, 0]], 2);
    }

    #[test]
    fn test_concatenate_and_stack() {
        let top = array![[1, 2], [3, 4]];
        let bottom = array![[5, 6], [7, 8]];

        // Joining along an existing axis grows the row count.
        let joined = concatenate(Axis(0), &[top.view(), bottom.view()]).expect("Operation failed");
        assert_eq!(joined.shape(), &[4, 2]);

        // Stacking introduces a brand-new leading axis.
        let layered = crate::ndarray::stack(Axis(0), &[top.view(), bottom.view()])
            .expect("Operation failed");
        assert_eq!(layered.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_zip_operations() {
        let lhs = array![1, 2, 3];
        let rhs = array![4, 5, 6];
        let mut total = array![0, 0, 0];

        azip!((out in &mut total, &l in &lhs, &r in &rhs) {
            *out = l + r;
        });

        assert_eq!(total, array![5, 7, 9]);
    }

    #[test]
    fn test_array_stat_compat() {
        use compat::ArrayStatCompat;

        let samples = array![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(samples.mean_or(0.0), 3.0);
        assert!(samples.var_or(1.0, 0.0) > 0.0);
        assert!(samples.std_or(1.0, 0.0) > 0.0);

        // An empty array has no mean, so the fallback is returned.
        let nothing = Array1::<f64>::from(Vec::new());
        assert_eq!(nothing.mean_or(0.0), 0.0);
    }
}
517
518// ========================================
519// NDARRAY-LINALG COMPATIBILITY LAYER
520// ========================================
521
/// ndarray-linalg compatibility layer for backward compatibility
///
/// Provides traits shaped after the ndarray-linalg API, backed by OxiBLAS v0.1.2+.
///
/// Differences from ndarray-linalg worth knowing:
/// - [`SVD::svd`] ignores `compute_u` / `compute_vt`; the backend's `u` and `vt`
///   are always returned.
/// - [`Eigh::eigh`] ignores `uplo` and hands the whole matrix to the backend, so
///   the input should be genuinely symmetric/Hermitian.
/// - [`Eig`] for complex matrices is implemented natively in this module
///   (Householder-Hessenberg reduction followed by shifted complex QR).
#[cfg(feature = "linalg")]
pub mod ndarray_linalg {
    use crate::linalg::prelude::*;
    use crate::ndarray::*;
    use num_complex::Complex;

    // Import OxiBLAS v0.1.2+ Complex functions
    use oxiblas_ndarray::lapack::{
        cholesky_hermitian_ndarray, eig_hermitian_ndarray, qr_complex_ndarray, svd_complex_ndarray,
    };

    // Re-export error types
    pub use crate::linalg::{LapackError, LapackResult};

    /// UPLO enum for triangular matrix specification
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum UPLO {
        /// Upper triangle.
        Upper,
        /// Lower triangle.
        Lower,
    }

    /// Linear system solver trait.
    pub trait Solve<A> {
        /// Solve the linear system `self · x = b` (delegates to the OxiBLAS backend).
        fn solve_into(&self, b: &Array1<A>) -> Result<Array1<A>, LapackError>;
    }

    impl Solve<f64> for Array2<f64> {
        #[inline]
        fn solve_into(&self, b: &Array1<f64>) -> Result<Array1<f64>, LapackError> {
            solve_ndarray(self, b)
        }
    }

    impl Solve<Complex<f64>> for Array2<Complex<f64>> {
        #[inline]
        fn solve_into(
            &self,
            b: &Array1<Complex<f64>>,
        ) -> Result<Array1<Complex<f64>>, LapackError> {
            solve_ndarray(self, b)
        }
    }

    /// SVD trait
    pub trait SVD {
        type Elem;
        type Real;

        /// Returns `(U, singular values, Vᴴ)` from the backend.
        ///
        /// `compute_u` / `compute_vt` are accepted for API compatibility but ignored.
        fn svd(
            &self,
            compute_u: bool,
            compute_vt: bool,
        ) -> Result<(Array2<Self::Elem>, Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
    }

    impl SVD for Array2<f64> {
        type Elem = f64;
        type Real = f64;

        #[inline]
        fn svd(
            &self,
            _compute_u: bool,
            _compute_vt: bool,
        ) -> Result<(Array2<f64>, Array1<f64>, Array2<f64>), LapackError> {
            let result = svd_ndarray(self)?;
            Ok((result.u, result.s, result.vt))
        }
    }

    impl SVD for Array2<Complex<f64>> {
        type Elem = Complex<f64>;
        type Real = f64;

        #[inline]
        fn svd(
            &self,
            _compute_u: bool,
            _compute_vt: bool,
        ) -> Result<(Array2<Complex<f64>>, Array1<f64>, Array2<Complex<f64>>), LapackError>
        {
            let result = svd_complex_ndarray(self)?;
            Ok((result.u, result.s, result.vt))
        }
    }

    /// Hermitian/symmetric eigenvalue decomposition trait
    pub trait Eigh {
        type Elem;
        type Real;

        /// Returns `(eigenvalues, eigenvectors)`. `uplo` is accepted for API
        /// compatibility but ignored: the full matrix is passed to the backend.
        fn eigh(&self, uplo: UPLO)
            -> Result<(Array1<Self::Real>, Array2<Self::Elem>), LapackError>;
    }

    impl Eigh for Array2<f64> {
        type Elem = f64;
        type Real = f64;

        #[inline]
        fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<f64>), LapackError> {
            let result = eig_symmetric(self)?;
            Ok((result.eigenvalues, result.eigenvectors))
        }
    }

    impl Eigh for Array2<Complex<f64>> {
        type Elem = Complex<f64>;
        type Real = f64;

        #[inline]
        fn eigh(&self, _uplo: UPLO) -> Result<(Array1<f64>, Array2<Complex<f64>>), LapackError> {
            eig_hermitian_ndarray(self)
        }
    }

    /// Matrix norm trait
    pub trait Norm {
        type Real;

        /// Square root of the sum of squared element magnitudes
        /// (Frobenius norm for matrices, Euclidean norm for vectors).
        fn norm_l2(&self) -> Result<Self::Real, LapackError>;
    }

    impl Norm for Array2<f64> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x * x).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array2<Complex<f64>> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
            Ok(sum_sq.sqrt())
        }
    }

    // Norm trait for Array1 (vectors)
    impl Norm for Array1<f64> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x * x).sum();
            Ok(sum_sq.sqrt())
        }
    }

    impl Norm for Array1<Complex<f64>> {
        type Real = f64;

        #[inline]
        fn norm_l2(&self) -> Result<f64, LapackError> {
            let sum_sq: f64 = self.iter().map(|x| x.norm_sqr()).sum();
            Ok(sum_sq.sqrt())
        }
    }

    /// QR decomposition trait
    pub trait QR {
        type Elem;

        /// Returns `(Q, R)` from the backend.
        fn qr(&self) -> Result<(Array2<Self::Elem>, Array2<Self::Elem>), LapackError>;
    }

    impl QR for Array2<f64> {
        type Elem = f64;

        #[inline]
        fn qr(&self) -> Result<(Array2<f64>, Array2<f64>), LapackError> {
            let result = qr_ndarray(self)?;
            Ok((result.q, result.r))
        }
    }

    impl QR for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        #[inline]
        fn qr(&self) -> Result<(Array2<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
            let result = qr_complex_ndarray(self)?;
            Ok((result.q, result.r))
        }
    }

    /// Eigenvalue decomposition trait (general matrices)
    pub trait Eig {
        type Elem;

        /// Returns `(eigenvalues, right eigenvectors)`; eigenvector `i` is column `i`.
        fn eig(&self) -> Result<(Array1<Self::Elem>, Array2<Self::Elem>), LapackError>;
    }

    // General complex matrices use the complex Schur decomposition:
    //  1. Reduce A to upper Hessenberg form H = Qᴴ A Q (Householder reflections).
    //  2. Implicit single-shift QR sweeps (Wilkinson shifts, bulge chasing) drive H
    //     to upper-triangular Schur form T while accumulating Q.
    //  3. Eigenvalues are diag(T).
    //  4. Eigenvectors of T come from back-substitution; X = Q · V maps them back.
    impl Eig for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        fn eig(&self) -> Result<(Array1<Complex<f64>>, Array2<Complex<f64>>), LapackError> {
            let (m, n) = self.dim();
            if m != n {
                return Err(LapackError::DimensionMismatch(
                    "Matrix must be square for eigendecomposition".to_string(),
                ));
            }
            if n == 0 {
                return Ok((
                    Array1::<Complex<f64>>::zeros(0),
                    Array2::<Complex<f64>>::zeros((0, 0)),
                ));
            }
            if n == 1 {
                let eigenvalue = self[[0, 0]];
                let eigenvector = Array2::from_elem((1, 1), Complex::new(1.0, 0.0));
                return Ok((Array1::from_vec(vec![eigenvalue]), eigenvector));
            }

            let (mut schur, mut q) = complex_hessenberg(self);
            complex_schur_qr(&mut schur, &mut q);

            let eigenvalues: Array1<Complex<f64>> =
                Array1::from_iter((0..n).map(|i| schur[[i, i]]));
            let eigenvectors = q.dot(&triangular_eigenvectors(&schur));
            Ok((eigenvalues, eigenvectors))
        }
    }

    /// Householder reduction of a square complex matrix to upper Hessenberg form.
    ///
    /// Returns `(H, Q)` with `A = Q H Qᴴ` and `Q` unitary.
    fn complex_hessenberg(
        a: &Array2<Complex<f64>>,
    ) -> (Array2<Complex<f64>>, Array2<Complex<f64>>) {
        let n = a.nrows();
        let two = Complex::new(2.0, 0.0);
        let mut h = a.clone();
        let mut q = Array2::<Complex<f64>>::eye(n);

        for col in 0..n.saturating_sub(2) {
            // Entries below the diagonal in this column; all but the first are annihilated.
            let mut x: Vec<Complex<f64>> = (col + 1..n).map(|r| h[[r, col]]).collect();

            let norm_x = x.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
            if norm_x < 1e-300 {
                continue;
            }

            // Match x[0]'s phase so x[0] + phase·‖x‖ does not suffer cancellation.
            let phase = if x[0].norm() > 1e-300 {
                x[0] / x[0].norm()
            } else {
                Complex::new(1.0, 0.0)
            };
            x[0] += phase * norm_x;

            let norm_v = x.iter().map(|z| z.norm_sqr()).sum::<f64>().sqrt();
            if norm_v < 1e-300 {
                continue;
            }
            let v: Vec<Complex<f64>> = x.iter().map(|z| *z / norm_v).collect();

            // H <- P H with P = I - 2 v vᴴ acting on rows col+1..n.
            for c in col..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| vi.conj() * h[[col + 1 + i, c]])
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    h[[col + 1 + i, c]] -= two * vi * dot;
                }
            }

            // H <- H P
            for r in 0..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| h[[r, col + 1 + i]] * vi)
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    h[[r, col + 1 + i]] -= two * dot * vi.conj();
                }
            }

            // Q <- Q P
            for r in 0..n {
                let dot: Complex<f64> = v
                    .iter()
                    .enumerate()
                    .map(|(i, &vi)| q[[r, col + 1 + i]] * vi)
                    .sum();
                for (i, &vi) in v.iter().enumerate() {
                    q[[r, col + 1 + i]] -= two * dot * vi.conj();
                }
            }

            // These entries are zero in exact arithmetic; store exact zeros.
            for r in col + 2..n {
                h[[r, col]] = Complex::new(0.0, 0.0);
            }
        }

        (h, q)
    }

    /// `true` when the subdiagonal entry `h[l, l-1]` is negligible relative to its
    /// diagonal neighbours (or below an absolute underflow floor).
    fn subdiagonal_negligible(h: &Array2<Complex<f64>>, l: usize) -> bool {
        let sub = h[[l, l - 1]].norm();
        let diag = h[[l - 1, l - 1]].norm() + h[[l, l]].norm();
        sub <= 1e-14 * diag || sub <= f64::MIN_POSITIVE.sqrt()
    }

    /// Eigenvalue of the trailing 2x2 block of `h[..hi, ..hi]` closest to `h[hi-1, hi-1]`.
    fn wilkinson_shift(h: &Array2<Complex<f64>>, hi: usize) -> Complex<f64> {
        let a = h[[hi - 2, hi - 2]];
        let b = h[[hi - 2, hi - 1]];
        let c = h[[hi - 1, hi - 2]];
        let d = h[[hi - 1, hi - 1]];
        let tr = a + d;
        let det = a * d - b * c;
        let disc = (tr * tr - Complex::new(4.0, 0.0) * det).sqrt();
        let mu1 = (tr + disc) * Complex::new(0.5, 0.0);
        let mu2 = (tr - disc) * Complex::new(0.5, 0.0);
        if (mu1 - d).norm() < (mu2 - d).norm() {
            mu1
        } else {
            mu2
        }
    }

    /// Drives the upper Hessenberg matrix `h` to upper-triangular Schur form with
    /// implicit single-shift QR. Every rotation is also applied to `q` from the right.
    ///
    /// At most `MAX_ITER` sweeps are spent per eigenvalue. If that budget runs out,
    /// the trailing eigenvalue is deflated anyway (accuracy is then degraded).
    fn complex_schur_qr(h: &mut Array2<Complex<f64>>, q: &mut Array2<Complex<f64>>) {
        const MAX_ITER: usize = 30;
        let zero = Complex::new(0.0, 0.0);
        let mut hi = h.nrows(); // active window is rows/cols lo..hi
        let mut iter = 0;

        while hi > 1 {
            // Find the start of the unreduced Hessenberg block that ends at hi-1.
            let mut lo = hi - 1;
            while lo > 0 && !subdiagonal_negligible(h, lo) {
                lo -= 1;
            }
            if lo > 0 {
                h[[lo, lo - 1]] = zero;
            }

            if lo == hi - 1 {
                // Trailing 1x1 block converged.
                hi -= 1;
                iter = 0;
                continue;
            }

            if iter == MAX_ITER {
                // No convergence: force deflation so the loop terminates.
                h[[hi - 1, hi - 2]] = zero;
                hi -= 1;
                iter = 0;
                continue;
            }
            iter += 1;

            // Occasional exceptional shift (in the spirit of LAPACK ZLAHQR) breaks cycles.
            let shift = if iter % 10 == 0 {
                h[[hi - 1, hi - 1]] + Complex::new(0.75 * h[[hi - 1, hi - 2]].norm(), 0.0)
            } else {
                wilkinson_shift(h, hi)
            };

            qr_sweep(h, q, lo, hi, shift);
        }
    }

    /// One implicit single-shift QR sweep (bulge chase) on the block `lo..hi`.
    ///
    /// Rotations are applied to the full rows/columns of `h` so that the final
    /// matrix is the Schur factor of the original matrix, not only of the block.
    fn qr_sweep(
        h: &mut Array2<Complex<f64>>,
        q: &mut Array2<Complex<f64>>,
        lo: usize,
        hi: usize,
        shift: Complex<f64>,
    ) {
        let n = h.nrows();
        let zero = Complex::new(0.0, 0.0);

        // First rotation comes from the first column of (H - shift·I).
        let mut x = h[[lo, lo]] - shift;
        let mut y = h[[lo + 1, lo]];

        for k in lo..hi - 1 {
            let r = (x.norm_sqr() + y.norm_sqr()).sqrt();
            let (c, s) = if r > 0.0 {
                (x / r, y / r)
            } else {
                (Complex::new(1.0, 0.0), zero)
            };

            // Rows k, k+1 <- Gᴴ · rows. Earlier columns are already zero in both rows.
            let col_start = if k > lo { k - 1 } else { lo };
            for j in col_start..n {
                let (a, b) = (h[[k, j]], h[[k + 1, j]]);
                h[[k, j]] = c.conj() * a + s.conj() * b;
                h[[k + 1, j]] = c * b - s * a;
            }
            if k > lo {
                // The previous bulge is now annihilated; store an exact zero.
                h[[k + 1, k - 1]] = zero;
            }

            // Columns k, k+1 <- columns · G. Row k+2 (if inside the block) gets the new bulge.
            let row_end = (k + 3).min(hi);
            for i in 0..row_end {
                let (a, b) = (h[[i, k]], h[[i, k + 1]]);
                h[[i, k]] = a * c + b * s;
                h[[i, k + 1]] = b * c.conj() - a * s.conj();
            }
            for i in 0..n {
                let (a, b) = (q[[i, k]], q[[i, k + 1]]);
                q[[i, k]] = a * c + b * s;
                q[[i, k + 1]] = b * c.conj() - a * s.conj();
            }

            // Next rotation chases the bulge at (k+2, k) down one row.
            if k + 2 < hi {
                x = h[[k + 1, k]];
                y = h[[k + 2, k]];
            }
        }
    }

    /// Unit-norm right eigenvectors of the upper-triangular matrix `t`.
    ///
    /// Column `i` belongs to eigenvalue `t[i, i]` and is found by back-substitution.
    /// As in LAPACK ZTREVC, near-zero pivots `t[r, r] - λ` are replaced by a small
    /// positive value (so defective blocks still give a true eigenvector). Vectors
    /// are rescaled during the solve to avoid overflow.
    fn triangular_eigenvectors(t: &Array2<Complex<f64>>) -> Array2<Complex<f64>> {
        const GROWTH_LIMIT: f64 = 1e150;
        let n = t.nrows();
        let zero = Complex::new(0.0, 0.0);
        let one = Complex::new(1.0, 0.0);
        let underflow_floor = f64::MIN_POSITIVE * (n as f64 / f64::EPSILON);

        let mut vecs = Array2::<Complex<f64>>::zeros((n, n));
        for ei in 0..n {
            let lambda = t[[ei, ei]];
            let smin = (f64::EPSILON * lambda.norm()).max(underflow_floor);

            let mut v = vec![zero; n];
            v[ei] = one;

            for row in (0..ei).rev() {
                let mut sum: Complex<f64> = (row + 1..=ei).map(|col| t[[row, col]] * v[col]).sum();

                let mut denom = t[[row, row]] - lambda;
                if denom.norm() < smin {
                    denom = Complex::new(smin, 0.0);
                }

                // Keep |v[row]| <= GROWTH_LIMIT by scaling the partial solution first.
                let limit = GROWTH_LIMIT * denom.norm();
                let sum_mag = sum.norm();
                if sum_mag > limit {
                    let scale = limit / sum_mag;
                    for vi in &mut v[row + 1..=ei] {
                        *vi *= scale;
                    }
                    sum *= scale;
                }

                v[row] = -sum / denom;
            }

            let norm = v.iter().map(|vi| vi.norm_sqr()).sum::<f64>().sqrt();
            if norm > 1e-300 {
                for vi in &mut v {
                    *vi /= norm;
                }
            } else {
                v[ei] = one;
            }

            for (row, value) in v.into_iter().enumerate() {
                vecs[[row, ei]] = value;
            }
        }
        vecs
    }

    /// Cholesky decomposition trait
    pub trait Cholesky {
        type Elem;

        /// Cholesky factor of a positive-definite matrix.
        ///
        /// `UPLO::Lower` returns `L` with `A = L Lᴴ`. `UPLO::Upper` returns
        /// `U = Lᴴ` with `A = Uᴴ U`, matching ndarray-linalg. This assumes the
        /// backend's `L` has a zero strict upper triangle.
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<Self::Elem>, LapackError>;
    }

    impl Cholesky for Array2<f64> {
        type Elem = f64;

        #[inline]
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<f64>, LapackError> {
            let lower = cholesky_ndarray(self)?.l;
            Ok(match uplo {
                UPLO::Lower => lower,
                UPLO::Upper => {
                    let (rows, cols) = lower.dim();
                    Array2::from_shape_fn((cols, rows), |(i, j)| lower[[j, i]])
                }
            })
        }
    }

    impl Cholesky for Array2<Complex<f64>> {
        type Elem = Complex<f64>;

        #[inline]
        fn cholesky(&self, uplo: UPLO) -> Result<Array2<Complex<f64>>, LapackError> {
            let lower = cholesky_hermitian_ndarray(self)?.l;
            Ok(match uplo {
                UPLO::Lower => lower,
                UPLO::Upper => {
                    let (rows, cols) = lower.dim();
                    Array2::from_shape_fn((cols, rows), |(i, j)| lower[[j, i]].conj())
                }
            })
        }
    }
}