// u_numflow/matrix.rs
1//! Dense matrix operations.
2//!
3//! A minimal, row-major dense matrix type with fundamental linear algebra
4//! operations required for statistical analysis: multiplication, transpose,
5//! determinant, inverse, Cholesky decomposition, and triangular solves.
6//!
7//! # Design
8//!
9//! - **Row-major storage**: `data[i * cols + j] = A[i, j]`
10//! - **No external dependencies**: Pure Rust, no nalgebra/LAPACK
11//! - **Partial pivoting**: Used in LU and Gauss-Jordan for numerical stability
12//! - **Explicit error handling**: Returns `Result<T, MatrixError>` with descriptive variants
13//!
14//! # Examples
15//!
16//! ```
17//! use u_numflow::matrix::Matrix;
18//!
19//! let a = Matrix::from_rows(&[
20//!     &[1.0, 2.0],
21//!     &[3.0, 4.0],
22//! ]);
23//! let b = a.transpose();
24//! let c = a.mul_mat(&b).unwrap();
25//! assert_eq!(c.rows(), 2);
26//! assert_eq!(c.cols(), 2);
27//! ```
28
/// Error type for matrix operations.
///
/// All variants carry only `usize` payloads, so the type is `Eq` as well as
/// `PartialEq`, which lets callers compare and collect errors exactly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatrixError {
    /// Dimensions do not match for the operation.
    DimensionMismatch {
        expected: (usize, usize),
        got: (usize, usize),
    },
    /// Matrix must be square for this operation.
    NotSquare { rows: usize, cols: usize },
    /// Matrix is singular (zero or near-zero pivot encountered).
    Singular,
    /// Matrix is not symmetric (required for Cholesky).
    NotSymmetric,
    /// Matrix is not positive-definite (required for Cholesky).
    NotPositiveDefinite,
    /// Data length does not match dimensions.
    InvalidData { expected: usize, got: usize },
}
48
49impl std::fmt::Display for MatrixError {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            MatrixError::DimensionMismatch { expected, got } => {
53                write!(
54                    f,
55                    "dimension mismatch: expected {}×{}, got {}×{}",
56                    expected.0, expected.1, got.0, got.1
57                )
58            }
59            MatrixError::NotSquare { rows, cols } => {
60                write!(f, "matrix must be square, got {rows}×{cols}")
61            }
62            MatrixError::Singular => write!(f, "matrix is singular"),
63            MatrixError::NotSymmetric => write!(f, "matrix is not symmetric"),
64            MatrixError::NotPositiveDefinite => write!(f, "matrix is not positive-definite"),
65            MatrixError::InvalidData { expected, got } => {
66                write!(f, "data length mismatch: expected {expected}, got {got}")
67            }
68        }
69    }
70}
71
// Marker impl: `Display` + `Debug` satisfy the default `Error` methods.
impl std::error::Error for MatrixError {}
73
/// A dense matrix stored in row-major order.
///
/// # Storage
///
/// Elements are stored contiguously: `data[i * cols + j]` holds `A[i, j]`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Row-major element storage; invariant: `data.len() == rows * cols`.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
85
86impl Matrix {
87    /// Creates a matrix from raw data in row-major order.
88    ///
89    /// # Errors
90    /// Returns `Err` if `data.len() != rows * cols`.
91    ///
92    /// # Examples
93    /// ```
94    /// use u_numflow::matrix::Matrix;
95    /// let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
96    /// assert_eq!(m.get(0, 2), 3.0);
97    /// assert_eq!(m.get(1, 0), 4.0);
98    /// ```
99    pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Result<Self, MatrixError> {
100        if data.len() != rows * cols {
101            return Err(MatrixError::InvalidData {
102                expected: rows * cols,
103                got: data.len(),
104            });
105        }
106        Ok(Self { data, rows, cols })
107    }
108
109    /// Creates a matrix from row slices.
110    ///
111    /// # Panics
112    /// Panics if rows have inconsistent lengths or `rows` is empty.
113    ///
114    /// # Examples
115    /// ```
116    /// use u_numflow::matrix::Matrix;
117    /// let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
118    /// assert_eq!(m.get(1, 1), 4.0);
119    /// ```
120    pub fn from_rows(rows: &[&[f64]]) -> Self {
121        assert!(!rows.is_empty(), "must have at least one row");
122        let ncols = rows[0].len();
123        assert!(ncols > 0, "must have at least one column");
124        let nrows = rows.len();
125        let mut data = Vec::with_capacity(nrows * ncols);
126        for (i, row) in rows.iter().enumerate() {
127            assert_eq!(
128                row.len(),
129                ncols,
130                "row {i} has {} columns, expected {ncols}",
131                row.len()
132            );
133            data.extend_from_slice(row);
134        }
135        Self {
136            data,
137            rows: nrows,
138            cols: ncols,
139        }
140    }
141
142    /// Creates a zero matrix.
143    pub fn zeros(rows: usize, cols: usize) -> Self {
144        Self {
145            data: vec![0.0; rows * cols],
146            rows,
147            cols,
148        }
149    }
150
151    /// Creates an identity matrix.
152    ///
153    /// # Examples
154    /// ```
155    /// use u_numflow::matrix::Matrix;
156    /// let eye = Matrix::identity(3);
157    /// assert_eq!(eye.get(0, 0), 1.0);
158    /// assert_eq!(eye.get(0, 1), 0.0);
159    /// assert_eq!(eye.get(2, 2), 1.0);
160    /// ```
161    pub fn identity(n: usize) -> Self {
162        let mut m = Self::zeros(n, n);
163        for i in 0..n {
164            m.data[i * n + i] = 1.0;
165        }
166        m
167    }
168
169    /// Creates a column vector (n×1 matrix) from a slice.
170    pub fn from_col(data: &[f64]) -> Self {
171        Self {
172            data: data.to_vec(),
173            rows: data.len(),
174            cols: 1,
175        }
176    }
177
    /// Number of rows (first dimension).
    #[inline]
    pub fn rows(&self) -> usize {
        self.rows
    }
183
    /// Number of columns (second dimension).
    #[inline]
    pub fn cols(&self) -> usize {
        self.cols
    }
189
    /// Returns the element at (row, col) via the row-major index `row * cols + col`.
    ///
    /// # Panics
    /// Panics if the flat index is out of bounds.
    /// NOTE(review): an out-of-range `col` can alias into the next row without
    /// panicking when `row * cols + col` still lands inside `data` — callers
    /// must pass in-range indices.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> f64 {
        self.data[row * self.cols + col]
    }
198
    /// Sets the element at (row, col) via the row-major index `row * cols + col`.
    ///
    /// # Panics
    /// Panics if the flat index is out of bounds.
    /// NOTE(review): as with `get`, an out-of-range `col` can alias into the
    /// next row without panicking when the flat index stays in bounds.
    #[inline]
    pub fn set(&mut self, row: usize, col: usize, value: f64) {
        self.data[row * self.cols + col] = value;
    }
207
    /// Returns the raw row-major data as a slice of length `rows * cols`.
    pub fn data(&self) -> &[f64] {
        &self.data
    }
212
    /// Returns a row as a slice of length `cols`, borrowed directly from the
    /// row-major backing storage.
    ///
    /// # Panics
    /// Panics if the row's slice range falls outside the data
    /// (i.e. `row >= rows`, for a matrix with at least one column).
    #[inline]
    pub fn row(&self, row: usize) -> &[f64] {
        let start = row * self.cols;
        &self.data[start..start + self.cols]
    }
219
220    /// Returns the diagonal elements.
221    pub fn diag(&self) -> Vec<f64> {
222        let n = self.rows.min(self.cols);
223        (0..n).map(|i| self.get(i, i)).collect()
224    }
225
    /// Returns true if the matrix is square (`rows == cols`).
    pub fn is_square(&self) -> bool {
        self.rows == self.cols
    }
230
231    // ========================================================================
232    // Basic Operations
233    // ========================================================================
234
235    /// Transpose: returns Aᵀ.
236    ///
237    /// # Examples
238    /// ```
239    /// use u_numflow::matrix::Matrix;
240    /// let m = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
241    /// let t = m.transpose();
242    /// assert_eq!(t.rows(), 3);
243    /// assert_eq!(t.cols(), 2);
244    /// assert_eq!(t.get(0, 1), 4.0);
245    /// ```
246    pub fn transpose(&self) -> Self {
247        let mut result = Self::zeros(self.cols, self.rows);
248        for i in 0..self.rows {
249            for j in 0..self.cols {
250                result.data[j * self.rows + i] = self.data[i * self.cols + j];
251            }
252        }
253        result
254    }
255
256    /// Matrix addition: A + B.
257    ///
258    /// # Errors
259    /// Returns `Err` if dimensions do not match.
260    pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
261        if self.rows != other.rows || self.cols != other.cols {
262            return Err(MatrixError::DimensionMismatch {
263                expected: (self.rows, self.cols),
264                got: (other.rows, other.cols),
265            });
266        }
267        let data: Vec<f64> = self
268            .data
269            .iter()
270            .zip(&other.data)
271            .map(|(a, b)| a + b)
272            .collect();
273        Ok(Self {
274            data,
275            rows: self.rows,
276            cols: self.cols,
277        })
278    }
279
280    /// Matrix subtraction: A - B.
281    ///
282    /// # Errors
283    /// Returns `Err` if dimensions do not match.
284    pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
285        if self.rows != other.rows || self.cols != other.cols {
286            return Err(MatrixError::DimensionMismatch {
287                expected: (self.rows, self.cols),
288                got: (other.rows, other.cols),
289            });
290        }
291        let data: Vec<f64> = self
292            .data
293            .iter()
294            .zip(&other.data)
295            .map(|(a, b)| a - b)
296            .collect();
297        Ok(Self {
298            data,
299            rows: self.rows,
300            cols: self.cols,
301        })
302    }
303
304    /// Scalar multiplication: c · A.
305    pub fn scale(&self, c: f64) -> Self {
306        let data: Vec<f64> = self.data.iter().map(|x| c * x).collect();
307        Self {
308            data,
309            rows: self.rows,
310            cols: self.cols,
311        }
312    }
313
314    /// Matrix multiplication: A · B.
315    ///
316    /// Uses i-k-j loop order for better cache locality on row-major storage.
317    ///
318    /// # Errors
319    /// Returns `Err` if `self.cols != other.rows`.
320    ///
321    /// # Complexity
322    /// O(n·m·p) where self is n×m and other is m×p.
323    ///
324    /// # Examples
325    /// ```
326    /// use u_numflow::matrix::Matrix;
327    /// let a = Matrix::identity(3);
328    /// let b = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
329    /// let c = a.mul_mat(&b).unwrap();
330    /// assert_eq!(c.get(2, 2), 9.0);
331    /// ```
332    pub fn mul_mat(&self, other: &Self) -> Result<Self, MatrixError> {
333        if self.cols != other.rows {
334            return Err(MatrixError::DimensionMismatch {
335                expected: (self.rows, self.cols),
336                got: (other.rows, other.cols),
337            });
338        }
339        let mut result = Self::zeros(self.rows, other.cols);
340        // i-k-j loop order for row-major cache friendliness
341        for i in 0..self.rows {
342            for k in 0..self.cols {
343                let a_ik = self.data[i * self.cols + k];
344                let row_start = i * other.cols;
345                let other_row_start = k * other.cols;
346                for j in 0..other.cols {
347                    result.data[row_start + j] += a_ik * other.data[other_row_start + j];
348                }
349            }
350        }
351        Ok(result)
352    }
353
354    /// Matrix-vector multiplication: A · v.
355    ///
356    /// # Errors
357    /// Returns `Err` if `self.cols != v.len()`.
358    pub fn mul_vec(&self, v: &[f64]) -> Result<Vec<f64>, MatrixError> {
359        if self.cols != v.len() {
360            return Err(MatrixError::DimensionMismatch {
361                expected: (self.rows, self.cols),
362                got: (v.len(), 1),
363            });
364        }
365        let mut result = vec![0.0; self.rows];
366        for (i, res) in result.iter_mut().enumerate() {
367            let row_start = i * self.cols;
368            *res = self.data[row_start..row_start + self.cols]
369                .iter()
370                .zip(v.iter())
371                .map(|(&a, &b)| a * b)
372                .sum();
373        }
374        Ok(result)
375    }
376
377    /// Frobenius norm: ‖A‖_F = √(Σᵢⱼ aᵢⱼ²).
378    pub fn frobenius_norm(&self) -> f64 {
379        self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
380    }
381
382    /// Checks whether the matrix is symmetric within tolerance.
383    pub fn is_symmetric(&self, tol: f64) -> bool {
384        if self.rows != self.cols {
385            return false;
386        }
387        for i in 0..self.rows {
388            for j in (i + 1)..self.cols {
389                if (self.get(i, j) - self.get(j, i)).abs() > tol {
390                    return false;
391                }
392            }
393        }
394        true
395    }
396
397    fn swap_rows(&mut self, a: usize, b: usize) {
398        if a == b {
399            return;
400        }
401        let cols = self.cols;
402        for j in 0..cols {
403            self.data.swap(a * cols + j, b * cols + j);
404        }
405    }
406
407    // ========================================================================
408    // Decompositions
409    // ========================================================================
410
    /// Determinant via LU decomposition with partial pivoting.
    ///
    /// # Errors
    /// Returns `Err(NotSquare)` if the matrix is not square.
    /// A zero or near-zero pivot does NOT error: the matrix is singular and
    /// `Ok(0.0)` is returned.
    ///
    /// # Complexity
    /// O(n³/3).
    ///
    /// # Examples
    /// ```
    /// use u_numflow::matrix::Matrix;
    /// let m = Matrix::from_rows(&[&[2.0, 3.0], &[1.0, 4.0]]);
    /// assert!((m.determinant().unwrap() - 5.0).abs() < 1e-10);
    /// ```
    pub fn determinant(&self) -> Result<f64, MatrixError> {
        if !self.is_square() {
            return Err(MatrixError::NotSquare {
                rows: self.rows,
                cols: self.cols,
            });
        }
        let n = self.rows;
        // Determinant of the empty (0×0) matrix is 1 by convention.
        if n == 0 {
            return Ok(1.0);
        }
        if n == 1 {
            return Ok(self.data[0]);
        }

        let mut work = self.clone();
        // Each row swap flips the determinant's sign.
        let mut sign = 1.0_f64;
        // Pivot threshold scaled to the matrix magnitude; the .max(1e-300)
        // floor keeps the tolerance nonzero for the all-zero matrix.
        let pivot_tol = 1e-15 * self.frobenius_norm().max(1e-300);

        for k in 0..n {
            // Partial pivoting: largest |entry| in column k at or below row k.
            let mut max_val = work.get(k, k).abs();
            let mut max_row = k;
            for i in (k + 1)..n {
                let v = work.get(i, k).abs();
                if v > max_val {
                    max_val = v;
                    max_row = i;
                }
            }
            if max_val <= pivot_tol {
                return Ok(0.0); // Singular → det = 0
            }
            if max_row != k {
                work.swap_rows(k, max_row);
                sign = -sign;
            }

            // Eliminate below the pivot. The sub-diagonal entries of column k
            // are left stale — only the diagonal is read afterwards.
            let pivot = work.get(k, k);
            for i in (k + 1)..n {
                let factor = work.get(i, k) / pivot;
                for j in (k + 1)..n {
                    let val = work.get(i, j) - factor * work.get(k, j);
                    work.set(i, j, val);
                }
            }
        }

        // det(A) = sign · Π U[i, i]
        let mut det = sign;
        for i in 0..n {
            det *= work.get(i, i);
        }
        Ok(det)
    }
480
    /// Matrix inverse via Gauss-Jordan elimination with partial pivoting.
    ///
    /// # Algorithm
    /// Augments [A | I], reduces to [I | A⁻¹] using row operations.
    ///
    /// Reference: Golub & Van Loan (1996), *Matrix Computations*, §1.2.
    ///
    /// # Errors
    /// Returns `Err(NotSquare)` if the matrix is not square.
    /// Returns `Err(Singular)` if the matrix is singular.
    ///
    /// # Examples
    /// ```
    /// use u_numflow::matrix::Matrix;
    /// let a = Matrix::from_rows(&[&[4.0, 7.0], &[2.0, 6.0]]);
    /// let inv = a.inverse().unwrap();
    /// let eye = a.mul_mat(&inv).unwrap();
    /// assert!((eye.get(0, 0) - 1.0).abs() < 1e-10);
    /// assert!(eye.get(0, 1).abs() < 1e-10);
    /// ```
    pub fn inverse(&self) -> Result<Self, MatrixError> {
        if !self.is_square() {
            return Err(MatrixError::NotSquare {
                rows: self.rows,
                cols: self.cols,
            });
        }
        let n = self.rows;
        // The inverse of the empty matrix is the empty matrix.
        if n == 0 {
            return Ok(Self::zeros(0, 0));
        }

        // Build the augmented matrix [A | I].
        let n2 = 2 * n;
        let mut aug = Self::zeros(n, n2);
        for i in 0..n {
            for j in 0..n {
                aug.set(i, j, self.get(i, j));
            }
            aug.set(i, n + i, 1.0);
        }

        // Pivot threshold scaled to the matrix magnitude; the .max(1e-300)
        // floor keeps the tolerance nonzero for the all-zero matrix.
        let pivot_tol = 1e-14 * self.frobenius_norm().max(1e-300);

        for k in 0..n {
            // Partial pivoting: largest |entry| in column k at or below row k.
            let mut max_val = aug.get(k, k).abs();
            let mut max_row = k;
            for i in (k + 1)..n {
                let v = aug.get(i, k).abs();
                if v > max_val {
                    max_val = v;
                    max_row = i;
                }
            }
            if max_val <= pivot_tol {
                return Err(MatrixError::Singular);
            }
            if max_row != k {
                aug.swap_rows(k, max_row);
            }

            // Normalize the pivot row so the pivot becomes 1.
            let pivot = aug.get(k, k);
            for j in 0..n2 {
                aug.set(k, j, aug.get(k, j) / pivot);
            }

            // Eliminate column k in every other row (full Gauss-Jordan,
            // above and below the pivot).
            for i in 0..n {
                if i != k {
                    let factor = aug.get(i, k);
                    for j in 0..n2 {
                        let val = aug.get(i, j) - factor * aug.get(k, j);
                        aug.set(i, j, val);
                    }
                }
            }
        }

        // The right half of [I | A⁻¹] now holds the inverse.
        let mut inv = Self::zeros(n, n);
        for i in 0..n {
            for j in 0..n {
                inv.set(i, j, aug.get(i, n + j));
            }
        }
        Ok(inv)
    }
569
    /// Cholesky decomposition: returns lower-triangular L such that A = L·Lᵀ.
    ///
    /// # Algorithm
    /// Column-by-column Cholesky-Banachiewicz factorization.
    ///
    /// Reference: Golub & Van Loan (1996), *Matrix Computations*, Algorithm 4.2.1.
    ///
    /// # Requirements
    /// Matrix must be symmetric and positive-definite.
    ///
    /// # Errors
    /// Returns `Err(NotSquare)`, `Err(NotSymmetric)`, or
    /// `Err(NotPositiveDefinite)` when the respective requirement fails.
    ///
    /// # Complexity
    /// O(n³/3).
    ///
    /// # Examples
    /// ```
    /// use u_numflow::matrix::Matrix;
    /// let a = Matrix::from_rows(&[
    ///     &[4.0, 2.0],
    ///     &[2.0, 3.0],
    /// ]);
    /// let l = a.cholesky().unwrap();
    /// let llt = l.mul_mat(&l.transpose()).unwrap();
    /// assert!((llt.get(0, 0) - 4.0).abs() < 1e-10);
    /// assert!((llt.get(0, 1) - 2.0).abs() < 1e-10);
    /// ```
    pub fn cholesky(&self) -> Result<Self, MatrixError> {
        if !self.is_square() {
            return Err(MatrixError::NotSquare {
                rows: self.rows,
                cols: self.cols,
            });
        }
        let n = self.rows;
        // Symmetry tolerance scaled to matrix magnitude, floored so the
        // all-zero matrix gets a nonzero tolerance.
        let sym_tol = 1e-10 * self.frobenius_norm().max(1e-300);
        if !self.is_symmetric(sym_tol) {
            return Err(MatrixError::NotSymmetric);
        }

        let mut l = Self::zeros(n, n);

        for j in 0..n {
            // Diagonal entry: L[j,j] = √(A[j,j] − Σₖ L[j,k]²)
            let mut sum = 0.0;
            for k in 0..j {
                let ljk = l.get(j, k);
                sum += ljk * ljk;
            }
            let diag = self.get(j, j) - sum;
            // A non-positive radicand means A is not positive-definite.
            if diag <= 0.0 {
                return Err(MatrixError::NotPositiveDefinite);
            }
            l.set(j, j, diag.sqrt());

            // Below-diagonal entries: L[i,j] = (A[i,j] − Σₖ L[i,k]·L[j,k]) / L[j,j]
            let ljj = l.get(j, j);
            for i in (j + 1)..n {
                let mut sum = 0.0;
                for k in 0..j {
                    sum += l.get(i, k) * l.get(j, k);
                }
                l.set(i, j, (self.get(i, j) - sum) / ljj);
            }
        }

        Ok(l)
    }
636
637    /// Solves the linear system A·x = b using Cholesky decomposition.
638    ///
639    /// Equivalent to computing x = A⁻¹·b but more efficient and stable.
640    ///
641    /// # Algorithm
642    /// 1. Decompose A = L·Lᵀ via Cholesky
643    /// 2. Solve L·y = b (forward substitution)
644    /// 3. Solve Lᵀ·x = y (backward substitution)
645    ///
646    /// # Requirements
647    /// Matrix must be symmetric positive-definite. `b.len()` must equal `self.rows()`.
648    pub fn cholesky_solve(&self, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
649        if b.len() != self.rows {
650            return Err(MatrixError::DimensionMismatch {
651                expected: (self.rows, 1),
652                got: (b.len(), 1),
653            });
654        }
655        let l = self.cholesky()?;
656        let y = solve_lower_triangular(&l, b)?;
657        let lt = l.transpose();
658        solve_upper_triangular(&lt, &y)
659    }
660
661    /// Eigenvalue decomposition of a real symmetric matrix using the
662    /// classical Jacobi rotation algorithm.
663    ///
664    /// Returns `(eigenvalues, eigenvectors)` where eigenvalues are sorted
665    /// in **descending** order and eigenvectors are the corresponding columns
666    /// of the returned matrix (column `i` is the eigenvector for eigenvalue `i`).
667    ///
668    /// # Algorithm
669    ///
670    /// Cyclic Jacobi rotations zero off-diagonal elements iteratively.
671    /// Converges quadratically for symmetric matrices.
672    ///
673    /// Reference: Golub & Van Loan (1996), "Matrix Computations", §8.4
674    ///
675    /// # Complexity
676    ///
677    /// O(n³) per sweep, typically 5–10 sweeps. Best for n < 200.
678    ///
679    /// # Errors
680    ///
681    /// Returns `NotSquare` if the matrix is not square, or `NotSymmetric`
682    /// if the matrix is not symmetric within tolerance.
683    ///
684    /// # Examples
685    ///
686    /// ```
687    /// use u_numflow::matrix::Matrix;
688    ///
689    /// let a = Matrix::from_rows(&[
690    ///     &[4.0, 1.0],
691    ///     &[1.0, 3.0],
692    /// ]);
693    /// let (eigenvalues, eigenvectors) = a.eigen_symmetric().unwrap();
694    ///
695    /// // Eigenvalues of [[4,1],[1,3]] are (7+√5)/2 ≈ 4.618 and (7-√5)/2 ≈ 2.382
696    /// assert!((eigenvalues[0] - 4.618).abs() < 0.01);
697    /// assert!((eigenvalues[1] - 2.382).abs() < 0.01);
698    ///
699    /// // Eigenvectors are orthonormal
700    /// let dot: f64 = (0..2).map(|i| eigenvectors.get(i, 0) * eigenvectors.get(i, 1)).sum();
701    /// assert!(dot.abs() < 1e-10);
702    /// ```
703    pub fn eigen_symmetric(&self) -> Result<(Vec<f64>, Matrix), MatrixError> {
704        let n = self.rows;
705        if !self.is_square() {
706            return Err(MatrixError::NotSquare {
707                rows: self.rows,
708                cols: self.cols,
709            });
710        }
711        // Symmetry tolerance: relative to matrix scale
712        let sym_tol = 1e-10 * self.frobenius_norm();
713        if !self.is_symmetric(sym_tol) {
714            return Err(MatrixError::NotSymmetric);
715        }
716
717        // Work on a mutable copy of the matrix
718        let mut a = self.data.clone();
719        // Eigenvector accumulator — starts as identity
720        let mut v = vec![0.0; n * n];
721        for i in 0..n {
722            v[i * n + i] = 1.0;
723        }
724
725        let max_sweeps = 100;
726        let tol = 1e-15;
727
728        for _ in 0..max_sweeps {
729            // Compute off-diagonal Frobenius norm
730            let mut off_norm = 0.0;
731            for i in 0..n {
732                for j in (i + 1)..n {
733                    off_norm += 2.0 * a[i * n + j] * a[i * n + j];
734                }
735            }
736            off_norm = off_norm.sqrt();
737
738            if off_norm < tol {
739                break;
740            }
741
742            // One full sweep: rotate each (p, q) pair
743            for p in 0..n {
744                for q in (p + 1)..n {
745                    let apq = a[p * n + q];
746                    if apq.abs() < tol * 0.01 {
747                        continue;
748                    }
749
750                    let app = a[p * n + p];
751                    let aqq = a[q * n + q];
752                    let diff = aqq - app;
753
754                    // Compute rotation angle
755                    let (cos, sin) = if diff.abs() < 1e-300 {
756                        // Special case: diagonal elements equal
757                        let s = std::f64::consts::FRAC_1_SQRT_2;
758                        (s, if apq > 0.0 { s } else { -s })
759                    } else {
760                        let tau = diff / (2.0 * apq);
761                        // t = sign(tau) / (|tau| + sqrt(1 + tau²))
762                        let t = if tau >= 0.0 {
763                            1.0 / (tau + (1.0 + tau * tau).sqrt())
764                        } else {
765                            -1.0 / (-tau + (1.0 + tau * tau).sqrt())
766                        };
767                        let c = 1.0 / (1.0 + t * t).sqrt();
768                        let s = t * c;
769                        (c, s)
770                    };
771
772                    // Apply rotation to matrix A (symmetric, only update needed parts)
773                    a[p * n + p] -=
774                        2.0 * sin * cos * apq + sin * sin * (a[q * n + q] - a[p * n + p]);
775                    a[q * n + q] += 2.0 * sin * cos * apq + sin * sin * (aqq - app); // use original aqq, app
776                    a[p * n + q] = 0.0;
777                    a[q * n + p] = 0.0;
778
779                    // Actually, let's use the standard Jacobi rotation formula properly.
780                    // Reset and recompute.
781                    // Undo the above:
782                    a[p * n + p] = app;
783                    a[q * n + q] = aqq;
784                    a[p * n + q] = apq;
785                    a[q * n + p] = apq;
786
787                    // Standard update: for all rows/cols
788                    // First update rows p and q for all columns
789                    for r in 0..n {
790                        if r == p || r == q {
791                            continue;
792                        }
793                        let arp = a[r * n + p];
794                        let arq = a[r * n + q];
795                        a[r * n + p] = cos * arp - sin * arq;
796                        a[r * n + q] = sin * arp + cos * arq;
797                        a[p * n + r] = a[r * n + p]; // symmetric
798                        a[q * n + r] = a[r * n + q]; // symmetric
799                    }
800
801                    // Update diagonal and off-diagonal (p,q)
802                    let new_pp = cos * cos * app - 2.0 * sin * cos * apq + sin * sin * aqq;
803                    let new_qq = sin * sin * app + 2.0 * sin * cos * apq + cos * cos * aqq;
804                    a[p * n + p] = new_pp;
805                    a[q * n + q] = new_qq;
806                    a[p * n + q] = 0.0;
807                    a[q * n + p] = 0.0;
808
809                    // Accumulate eigenvectors: V = V * J
810                    for r in 0..n {
811                        let vp = v[r * n + p];
812                        let vq = v[r * n + q];
813                        v[r * n + p] = cos * vp - sin * vq;
814                        v[r * n + q] = sin * vp + cos * vq;
815                    }
816                }
817            }
818        }
819
820        // Extract eigenvalues from diagonal
821        let mut eigen_pairs: Vec<(f64, Vec<f64>)> = (0..n)
822            .map(|i| {
823                let eigenvalue = a[i * n + i];
824                let eigenvector: Vec<f64> = (0..n).map(|r| v[r * n + i]).collect();
825                (eigenvalue, eigenvector)
826            })
827            .collect();
828
829        // Sort by eigenvalue descending
830        eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
831
832        let eigenvalues: Vec<f64> = eigen_pairs.iter().map(|(val, _)| *val).collect();
833        let mut eigvec_data = vec![0.0; n * n];
834        for (col, (_, vec)) in eigen_pairs.iter().enumerate() {
835            for (row, &val) in vec.iter().enumerate() {
836                eigvec_data[row * n + col] = val;
837            }
838        }
839        let eigenvectors = Matrix {
840            data: eigvec_data,
841            rows: n,
842            cols: n,
843        };
844
845        Ok((eigenvalues, eigenvectors))
846    }
847}
848
849/// Solves L·x = b where L is lower-triangular (forward substitution).
850fn solve_lower_triangular(l: &Matrix, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
851    let n = l.rows();
852    let mut x = vec![0.0; n];
853    for i in 0..n {
854        let mut sum = 0.0;
855        for (j, &xj) in x[..i].iter().enumerate() {
856            sum += l.get(i, j) * xj;
857        }
858        let diag = l.get(i, i);
859        if diag.abs() < 1e-300 {
860            return Err(MatrixError::Singular);
861        }
862        x[i] = (b[i] - sum) / diag;
863    }
864    Ok(x)
865}
866
867/// Solves U·x = b where U is upper-triangular (backward substitution).
868fn solve_upper_triangular(u: &Matrix, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
869    let n = u.rows();
870    let mut x = vec![0.0; n];
871    for i in (0..n).rev() {
872        let mut sum = 0.0;
873        for (off, &xj) in x[i + 1..].iter().enumerate() {
874            sum += u.get(i, i + 1 + off) * xj;
875        }
876        let diag = u.get(i, i);
877        if diag.abs() < 1e-300 {
878            return Err(MatrixError::Singular);
879        }
880        x[i] = (b[i] - sum) / diag;
881    }
882    Ok(x)
883}
884
885// ============================================================================
886// Display
887// ============================================================================
888
889impl std::fmt::Display for Matrix {
890    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
891        for i in 0..self.rows {
892            write!(f, "[")?;
893            for j in 0..self.cols {
894                if j > 0 {
895                    write!(f, ", ")?;
896                }
897                write!(f, "{:>10.4}", self.get(i, j))?;
898            }
899            writeln!(f, "]")?;
900        }
901        Ok(())
902    }
903}
904
905// ============================================================================
906// Tests
907// ============================================================================
908
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails unless `actual` is within `tol` of `expected`.
    fn assert_near(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "got {actual}, expected {expected} (tol {tol})"
        );
    }

    /// Fails unless `a` and `b` have the same shape and agree entrywise within `tol`.
    fn assert_matrix_near(a: &Matrix, b: &Matrix, tol: f64) {
        assert_eq!((a.rows(), a.cols()), (b.rows(), b.cols()));
        for i in 0..a.rows() {
            for j in 0..a.cols() {
                assert!(
                    (a.get(i, j) - b.get(i, j)).abs() < tol,
                    "entry ({i},{j}): {} vs {}",
                    a.get(i, j),
                    b.get(i, j)
                );
            }
        }
    }

    // --- Construction ---

    #[test]
    fn test_new_valid() {
        let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!((m.rows(), m.cols()), (2, 3));
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 6.0);
    }

    #[test]
    fn test_new_invalid_length() {
        // A 2×3 matrix needs 6 entries; only 2 supplied.
        assert!(Matrix::new(2, 3, vec![1.0, 2.0]).is_err());
    }

    #[test]
    fn test_from_rows() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 1), 4.0);
    }

    #[test]
    fn test_zeros() {
        let z = Matrix::zeros(3, 4);
        assert_eq!((z.rows(), z.cols()), (3, 4));
        assert_eq!(z.get(2, 3), 0.0);
    }

    #[test]
    fn test_identity() {
        let eye = Matrix::identity(3);
        // Spot-check ones on the diagonal and zeros off it.
        for &(i, j, want) in &[(0, 0, 1.0), (1, 1, 1.0), (2, 2, 1.0), (0, 1, 0.0), (1, 2, 0.0)] {
            assert_eq!(eye.get(i, j), want);
        }
    }

    #[test]
    fn test_diag() {
        let m = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        assert_eq!(m.diag(), vec![1.0, 5.0, 9.0]);
    }

    // --- Basic operations ---

    #[test]
    fn test_transpose() {
        let t = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]).transpose();
        assert_eq!((t.rows(), t.cols()), (3, 2));
        assert_eq!(t.get(0, 0), 1.0);
        assert_eq!(t.get(2, 1), 6.0);
    }

    #[test]
    fn test_transpose_twice() {
        // Transposition is an involution: (Mᵀ)ᵀ = M.
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]);
        assert_eq!(m.transpose().transpose(), m);
    }

    #[test]
    fn test_add() {
        let sum = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]])
            .add(&Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]))
            .unwrap();
        assert_eq!(sum.get(0, 0), 6.0);
        assert_eq!(sum.get(1, 1), 12.0);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        // 2×3 + 3×2 is undefined.
        assert!(Matrix::zeros(2, 3).add(&Matrix::zeros(3, 2)).is_err());
    }

    #[test]
    fn test_sub() {
        let diff = Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]])
            .sub(&Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]))
            .unwrap();
        assert_eq!(diff.get(0, 0), 4.0);
        assert_eq!(diff.get(1, 1), 4.0);
    }

    #[test]
    fn test_scale() {
        let doubled = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]).scale(2.0);
        assert_eq!(doubled.get(0, 0), 2.0);
        assert_eq!(doubled.get(1, 1), 8.0);
    }

    // --- Multiplication ---

    #[test]
    fn test_mul_identity() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        assert_eq!(a.mul_mat(&Matrix::identity(3)).unwrap(), a);
    }

    #[test]
    fn test_mul_2x2() {
        let prod = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]])
            .mul_mat(&Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]))
            .unwrap();
        // Row 0: [1·5+2·7, 1·6+2·8] = [19, 22]; row 1: [3·5+4·7, 3·6+4·8] = [43, 50].
        let expected = [[19.0, 22.0], [43.0, 50.0]];
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(prod.get(i, j), expected[i][j]);
            }
        }
    }

    #[test]
    fn test_mul_nonsquare() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let b = Matrix::from_rows(&[&[7.0, 8.0], &[9.0, 10.0], &[11.0, 12.0]]);
        let c = a.mul_mat(&b).unwrap();
        assert_eq!((c.rows(), c.cols()), (2, 2));
        // Row 0: [1·7+2·9+3·11, 1·8+2·10+3·12] = [58, 64].
        assert_eq!(c.get(0, 0), 58.0);
        assert_eq!(c.get(0, 1), 64.0);
    }

    #[test]
    fn test_mul_vec() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        // [1·5+2·6, 3·5+4·6] = [17, 39]
        assert_eq!(a.mul_vec(&[5.0, 6.0]).unwrap(), vec![17.0, 39.0]);
    }

    #[test]
    fn test_mul_dimension_mismatch() {
        // Inner dimensions (3 vs 2) disagree.
        assert!(Matrix::zeros(2, 3).mul_mat(&Matrix::zeros(2, 3)).is_err());
    }

    // --- Determinant ---

    #[test]
    fn test_det_2x2() {
        // det = 2·4 − 3·1 = 5.
        let m = Matrix::from_rows(&[&[2.0, 3.0], &[1.0, 4.0]]);
        assert_near(m.determinant().unwrap(), 5.0, 1e-10);
    }

    #[test]
    fn test_det_3x3() {
        // Cofactor expansion along the first row:
        //   6(−2·7 − 5·8) − 1(4·7 − 5·2) + 1(4·8 − (−2)·2)
        // = 6·(−54) − 18 + 36 = −306.
        let m = Matrix::from_rows(&[&[6.0, 1.0, 1.0], &[4.0, -2.0, 5.0], &[2.0, 8.0, 7.0]]);
        assert_near(m.determinant().unwrap(), -306.0, 1e-8);
    }

    #[test]
    fn test_det_identity() {
        assert_near(Matrix::identity(4).determinant().unwrap(), 1.0, 1e-10);
    }

    #[test]
    fn test_det_singular() {
        // Second row is twice the first, so the determinant vanishes.
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 4.0]]);
        assert_near(m.determinant().unwrap(), 0.0, 1e-10);
    }

    #[test]
    fn test_det_not_square() {
        assert!(Matrix::zeros(2, 3).determinant().is_err());
    }

    // --- Inverse ---

    #[test]
    fn test_inverse_2x2() {
        let a = Matrix::from_rows(&[&[4.0, 7.0], &[2.0, 6.0]]);
        let product = a.mul_mat(&a.inverse().unwrap()).unwrap();
        // A·A⁻¹ must be the identity.
        assert_matrix_near(&product, &Matrix::identity(2), 1e-10);
    }

    #[test]
    fn test_inverse_3x3() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[0.0, 1.0, 4.0], &[5.0, 6.0, 0.0]]);
        let product = a.mul_mat(&a.inverse().unwrap()).unwrap();
        assert_matrix_near(&product, &Matrix::identity(3), 1e-10);
    }

    #[test]
    fn test_inverse_identity() {
        // The identity is its own inverse.
        let eye = Matrix::identity(4);
        assert_eq!(eye.inverse().unwrap(), eye);
    }

    #[test]
    fn test_inverse_singular() {
        assert!(Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 4.0]]).inverse().is_err());
    }

    // --- Cholesky ---

    #[test]
    fn test_cholesky_2x2() {
        let a = Matrix::from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);
        let l = a.cholesky().unwrap();
        // The factor must be lower triangular…
        assert!(l.get(0, 1).abs() < 1e-15);
        // …and reproduce A as L·Lᵀ.
        assert_matrix_near(&l.mul_mat(&l.transpose()).unwrap(), &a, 1e-10);
    }

    #[test]
    fn test_cholesky_3x3() {
        let a = Matrix::from_rows(&[&[25.0, 15.0, -5.0], &[15.0, 18.0, 0.0], &[-5.0, 0.0, 11.0]]);
        let l = a.cholesky().unwrap();
        assert_matrix_near(&l.mul_mat(&l.transpose()).unwrap(), &a, 1e-10);
    }

    #[test]
    fn test_cholesky_identity() {
        // I = I·Iᵀ, so the identity is its own Cholesky factor.
        let eye = Matrix::identity(3);
        assert_eq!(eye.cholesky().unwrap(), eye);
    }

    #[test]
    fn test_cholesky_not_positive_definite() {
        // Symmetric but indefinite (eigenvalues 3 and −1).
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 1.0]]);
        assert!(matches!(a.cholesky(), Err(MatrixError::NotPositiveDefinite)));
    }

    #[test]
    fn test_cholesky_not_symmetric() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!(matches!(a.cholesky(), Err(MatrixError::NotSymmetric)));
    }

    // --- Cholesky solve ---

    #[test]
    fn test_cholesky_solve() {
        // A = [[4, 2], [2, 3]], b = [1, 2] ⇒ x = [−0.125, 0.75].
        // Verify via the residual: A·x should reproduce b.
        let a = Matrix::from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);
        let b = [1.0, 2.0];
        let ax = a.mul_vec(&a.cholesky_solve(&b).unwrap()).unwrap();
        for (lhs, rhs) in ax.iter().zip(&b) {
            assert_near(*lhs, *rhs, 1e-10);
        }
    }

    #[test]
    fn test_cholesky_solve_3x3() {
        let a = Matrix::from_rows(&[&[25.0, 15.0, -5.0], &[15.0, 18.0, 0.0], &[-5.0, 0.0, 11.0]]);
        let b = [35.0, 33.0, 6.0];
        let ax = a.mul_vec(&a.cholesky_solve(&b).unwrap()).unwrap();
        for (lhs, rhs) in ax.iter().zip(&b) {
            assert_near(*lhs, *rhs, 1e-10);
        }
    }

    // --- Frobenius norm ---

    #[test]
    fn test_frobenius_norm() {
        // ‖M‖_F = √(1 + 4 + 9 + 16) = √30 ≈ 5.477.
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert_near(m.frobenius_norm(), 30.0_f64.sqrt(), 1e-10);
    }

    // --- is_symmetric ---

    #[test]
    fn test_is_symmetric() {
        let sym = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[2.0, 5.0, 6.0], &[3.0, 6.0, 9.0]]);
        assert!(sym.is_symmetric(1e-10));
        assert!(!Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]).is_symmetric(1e-10));
    }
}
1260
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy producing an arbitrary n×n matrix with entries in (−10, 10).
    fn square_matrix(n: usize) -> impl Strategy<Value = Matrix> {
        proptest::collection::vec(-10.0_f64..10.0, n * n)
            .prop_map(move |entries| Matrix::new(n, n, entries).expect("valid dimensions"))
    }

    /// Strategy producing a symmetric positive-definite n×n matrix.
    ///
    /// For any A, AᵀA is positive semi-definite; adding n·I shifts every
    /// eigenvalue strictly above zero, guaranteeing SPD.
    fn spd_matrix(n: usize) -> impl Strategy<Value = Matrix> {
        proptest::collection::vec(-5.0_f64..5.0, n * n).prop_map(move |entries| {
            let a = Matrix::new(n, n, entries).expect("valid dimensions");
            let gram = a.transpose().mul_mat(&a).expect("compatible");
            let shift = Matrix::identity(n).scale(n as f64);
            gram.add(&shift).expect("compatible")
        })
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn transpose_involution(m in square_matrix(3)) {
            // (Mᵀ)ᵀ = M, entrywise.
            let back = m.transpose().transpose();
            for r in 0..3 {
                for c in 0..3 {
                    prop_assert!((m.get(r, c) - back.get(r, c)).abs() < 1e-14);
                }
            }
        }

        #[test]
        fn mul_identity_is_identity(m in square_matrix(3)) {
            // Multiplying by I on either side leaves M unchanged.
            let eye = Matrix::identity(3);
            let right = m.mul_mat(&eye).unwrap();
            let left = eye.mul_mat(&m).unwrap();
            for r in 0..3 {
                for c in 0..3 {
                    prop_assert!((right.get(r, c) - m.get(r, c)).abs() < 1e-10);
                    prop_assert!((left.get(r, c) - m.get(r, c)).abs() < 1e-10);
                }
            }
        }

        #[test]
        fn det_of_product(a in square_matrix(3), b in square_matrix(3)) {
            // The determinant is multiplicative: det(A·B) = det(A)·det(B).
            let det_ab = a.mul_mat(&b).unwrap().determinant().unwrap();
            let expected = a.determinant().unwrap() * b.determinant().unwrap();
            // Relative tolerance, since determinants can grow large.
            let tol = 1e-6 * expected.abs().max(det_ab.abs()).max(1.0);
            prop_assert!(
                (det_ab - expected).abs() < tol,
                "det(AB)={det_ab}, det(A)*det(B)={expected}"
            );
        }

        #[test]
        fn cholesky_roundtrip(a in spd_matrix(3)) {
            // L·Lᵀ must reconstruct A entrywise.
            let l = a.cholesky().expect("SPD should decompose");
            let rebuilt = l.mul_mat(&l.transpose()).expect("compatible");
            for r in 0..3 {
                for c in 0..3 {
                    let tol = 1e-8 * a.get(r, c).abs().max(1.0);
                    prop_assert!(
                        (rebuilt.get(r, c) - a.get(r, c)).abs() < tol,
                        "LLᵀ[{r},{c}]={}, A[{r},{c}]={}",
                        rebuilt.get(r, c), a.get(r, c)
                    );
                }
            }
        }

        #[test]
        fn cholesky_solve_roundtrip(a in spd_matrix(3), b in proptest::collection::vec(-10.0_f64..10.0, 3)) {
            // Solving A·x = b and re-multiplying must give back b.
            let x = a.cholesky_solve(&b).expect("SPD solve should work");
            let ax = a.mul_vec(&x).expect("compatible");
            for r in 0..3 {
                let tol = 1e-8 * b[r].abs().max(1.0);
                prop_assert!(
                    (ax[r] - b[r]).abs() < tol,
                    "Ax[{r}]={}, b[{r}]={}",
                    ax[r], b[r]
                );
            }
        }

        #[test]
        fn inverse_roundtrip(a in spd_matrix(3)) {
            // SPD matrices are always invertible, so A·A⁻¹ ≈ I.
            let inv = a.inverse().expect("SPD invertible");
            let product = a.mul_mat(&inv).expect("compatible");
            for r in 0..3 {
                for c in 0..3 {
                    let want = if r == c { 1.0 } else { 0.0 };
                    prop_assert!(
                        (product.get(r, c) - want).abs() < 1e-6,
                        "A·A⁻¹[{r},{c}]={}, expected {want}",
                        product.get(r, c)
                    );
                }
            }
        }
    }
}