thales/
matrix.rs

1//! Matrix expression type with basic linear algebra operations.
2//!
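//! This module provides a symbolic matrix type whose elements are mathematical expressions.
//! It supports element-wise addition and subtraction, scalar and matrix multiplication,
//! transpose, trace, determinants, minors and cofactors, adjugates, inverses, and
//! characteristic polynomials. Arithmetic results are simplified via `Expression::simplify`.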
6//!
7//! # Examples
8//!
9//! ```
10//! use thales::matrix::MatrixExpr;
11//! use thales::ast::Expression;
12//!
13//! // Create a 2x2 identity matrix
14//! let identity = MatrixExpr::identity(2);
15//!
16//! // Create a matrix from expressions
17//! let a = Expression::Integer(1);
18//! let b = Expression::Integer(2);
19//! let c = Expression::Integer(3);
20//! let d = Expression::Integer(4);
21//! let m = MatrixExpr::from_elements(vec![
22//!     vec![a, b],
23//!     vec![c, d],
24//! ]).unwrap();
25//!
26//! // Transpose
27//! let mt = m.transpose();
28//! ```
29
30use crate::ast::{Expression, Variable};
31use std::fmt;
32
/// Errors produced by the fallible [`MatrixExpr`] operations.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum MatrixError {
    /// The operands' dimensions are incompatible for the requested operation.
    DimensionMismatch {
        /// Name of the failing operation, e.g. `"Matrix addition"`.
        operation: String,
        /// Dimensions the operation expected; `Display` prints them as `AxB`.
        expected: (usize, usize),
        /// Dimensions actually supplied; `Display` prints them as `AxB`.
        got: (usize, usize),
    },
    /// The input had no rows, or its first row had no entries.
    EmptyMatrix,
    /// Rows of the input have differing lengths.
    NonRectangular,
    /// A requested row/column index lies outside the matrix.
    IndexOutOfBounds {
        /// Requested (zero-based) row index.
        row: usize,
        /// Requested (zero-based) column index.
        col: usize,
        /// Row count of the indexed matrix.
        rows: usize,
        /// Column count of the indexed matrix.
        cols: usize,
    },
    /// The operation is not defined for this matrix (e.g. the determinant of a
    /// non-square matrix); the payload explains why.
    InvalidOperation(String),
}

impl fmt::Display for MatrixError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::DimensionMismatch {
                ref operation,
                expected: (expected_rows, expected_cols),
                got: (got_rows, got_cols),
            } => write!(
                f,
                "{}: expected {}x{}, got {}x{}",
                operation, expected_rows, expected_cols, got_rows, got_cols
            ),
            Self::EmptyMatrix => f.write_str("Empty matrix not allowed"),
            Self::NonRectangular => {
                f.write_str("Matrix must be rectangular (all rows same length)")
            }
            Self::IndexOutOfBounds {
                row,
                col,
                rows,
                cols,
            } => write!(
                f,
                "Index ({}, {}) out of bounds for {}x{} matrix",
                row, col, rows, cols
            ),
            Self::InvalidOperation(ref msg) => write!(f, "Invalid operation: {}", msg),
        }
    }
}

impl std::error::Error for MatrixError {}

/// Shorthand for `Result<T, MatrixError>`, returned by fallible matrix operations.
pub type MatrixResult<T> = Result<T, MatrixError>;
97
/// Bracket style for LaTeX output.
///
/// The default is [`BracketStyle::Parentheses`]. The type is `Copy`, `Eq` and
/// `Hash`, so it can be stored in sets or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum BracketStyle {
    /// Parentheses: `\begin{pmatrix}`
    #[default]
    Parentheses,
    /// Square brackets: `\begin{bmatrix}`
    Square,
    /// Curly braces: `\begin{Bmatrix}`
    Curly,
    /// Vertical bars (determinant): `\begin{vmatrix}`
    Determinant,
    /// Double vertical bars (norm): `\begin{Vmatrix}`
    Norm,
    /// No brackets
    None,
}
120
/// A matrix of symbolic expressions.
///
/// Each element is an [`Expression`] allowing symbolic computation on matrices.
/// Supports standard matrix operations including addition, multiplication,
/// transpose, and trace.
///
/// Entries are stored row-major. [`MatrixExpr::from_elements`] rejects empty and
/// ragged input. The [`MatrixExpr::identity`], [`MatrixExpr::zero`] and
/// [`MatrixExpr::diagonal`] constructors do not have that check and can produce
/// matrices with a zero dimension (for example `MatrixExpr::identity(0)`).
///
/// # Examples
///
/// ```
/// use thales::matrix::MatrixExpr;
/// use thales::ast::{Expression, Variable};
///
/// // Create a 2x2 matrix with symbolic entries
/// let x = Expression::Variable(Variable::new("x"));
/// let one = Expression::Integer(1);
/// let two = Expression::Integer(2);
/// let three = Expression::Integer(3);
///
/// let m = MatrixExpr::from_elements(vec![
///     vec![x, one],
///     vec![two, three],
/// ]).unwrap();
///
/// assert_eq!(m.rows(), 2);
/// assert_eq!(m.cols(), 2);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct MatrixExpr {
    /// Number of rows. Invariant: equals `elements.len()`.
    rows: usize,
    /// Number of columns. Invariant: every inner `Vec` of `elements` has this length.
    cols: usize,
    /// Row-major entries: `elements[i][j]` is the entry in row `i`, column `j`.
    elements: Vec<Vec<Expression>>,
}
153
154impl MatrixExpr {
155    /// Create a matrix from a 2D vector of expressions.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if:
160    /// - The input is empty
161    /// - Any row is empty
162    /// - Rows have different lengths (non-rectangular)
163    ///
164    /// # Examples
165    ///
166    /// ```
167    /// use thales::matrix::MatrixExpr;
168    /// use thales::ast::Expression;
169    ///
170    /// let m = MatrixExpr::from_elements(vec![
171    ///     vec![Expression::Integer(1), Expression::Integer(2)],
172    ///     vec![Expression::Integer(3), Expression::Integer(4)],
173    /// ]).unwrap();
174    /// ```
175    pub fn from_elements(elements: Vec<Vec<Expression>>) -> MatrixResult<Self> {
176        if elements.is_empty() || elements[0].is_empty() {
177            return Err(MatrixError::EmptyMatrix);
178        }
179
180        let cols = elements[0].len();
181        for row in &elements {
182            if row.len() != cols {
183                return Err(MatrixError::NonRectangular);
184            }
185        }
186
187        let rows = elements.len();
188        Ok(Self {
189            rows,
190            cols,
191            elements,
192        })
193    }
194
195    /// Create an identity matrix of size n x n.
196    ///
197    /// # Examples
198    ///
199    /// ```
200    /// use thales::matrix::MatrixExpr;
201    ///
202    /// let i3 = MatrixExpr::identity(3);
203    /// assert_eq!(i3.rows(), 3);
204    /// assert_eq!(i3.cols(), 3);
205    /// ```
206    pub fn identity(n: usize) -> Self {
207        let elements: Vec<Vec<Expression>> = (0..n)
208            .map(|i| {
209                (0..n)
210                    .map(|j| {
211                        if i == j {
212                            Expression::Integer(1)
213                        } else {
214                            Expression::Integer(0)
215                        }
216                    })
217                    .collect()
218            })
219            .collect();
220        Self {
221            rows: n,
222            cols: n,
223            elements,
224        }
225    }
226
227    /// Create a zero matrix of size rows x cols.
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// use thales::matrix::MatrixExpr;
233    ///
234    /// let z = MatrixExpr::zero(2, 3);
235    /// assert_eq!(z.rows(), 2);
236    /// assert_eq!(z.cols(), 3);
237    /// ```
238    pub fn zero(rows: usize, cols: usize) -> Self {
239        let elements: Vec<Vec<Expression>> = (0..rows)
240            .map(|_| (0..cols).map(|_| Expression::Integer(0)).collect())
241            .collect();
242        Self {
243            rows,
244            cols,
245            elements,
246        }
247    }
248
249    /// Create a diagonal matrix from a vector of expressions.
250    ///
251    /// # Examples
252    ///
253    /// ```
254    /// use thales::matrix::MatrixExpr;
255    /// use thales::ast::Expression;
256    ///
257    /// let diag = MatrixExpr::diagonal(vec![
258    ///     Expression::Integer(1),
259    ///     Expression::Integer(2),
260    ///     Expression::Integer(3),
261    /// ]);
262    /// assert_eq!(diag.rows(), 3);
263    /// assert_eq!(diag.cols(), 3);
264    /// ```
265    pub fn diagonal(diag: Vec<Expression>) -> Self {
266        let n = diag.len();
267        let elements: Vec<Vec<Expression>> = (0..n)
268            .map(|i| {
269                (0..n)
270                    .map(|j| {
271                        if i == j {
272                            diag[i].clone()
273                        } else {
274                            Expression::Integer(0)
275                        }
276                    })
277                    .collect()
278            })
279            .collect();
280        Self {
281            rows: n,
282            cols: n,
283            elements,
284        }
285    }
286
287    /// Get the number of rows.
288    pub fn rows(&self) -> usize {
289        self.rows
290    }
291
292    /// Get the number of columns.
293    pub fn cols(&self) -> usize {
294        self.cols
295    }
296
297    /// Get the dimensions as (rows, cols).
298    pub fn dimensions(&self) -> (usize, usize) {
299        (self.rows, self.cols)
300    }
301
302    /// Check if the matrix is square.
303    pub fn is_square(&self) -> bool {
304        self.rows == self.cols
305    }
306
307    /// Get a reference to an element at (row, col).
308    ///
309    /// # Errors
310    ///
311    /// Returns an error if indices are out of bounds.
312    pub fn get(&self, row: usize, col: usize) -> MatrixResult<&Expression> {
313        if row >= self.rows || col >= self.cols {
314            return Err(MatrixError::IndexOutOfBounds {
315                row,
316                col,
317                rows: self.rows,
318                cols: self.cols,
319            });
320        }
321        Ok(&self.elements[row][col])
322    }
323
324    /// Set an element at (row, col).
325    ///
326    /// # Errors
327    ///
328    /// Returns an error if indices are out of bounds.
329    pub fn set(&mut self, row: usize, col: usize, value: Expression) -> MatrixResult<()> {
330        if row >= self.rows || col >= self.cols {
331            return Err(MatrixError::IndexOutOfBounds {
332                row,
333                col,
334                rows: self.rows,
335                cols: self.cols,
336            });
337        }
338        self.elements[row][col] = value;
339        Ok(())
340    }
341
342    /// Get a row as a vector of expressions.
343    pub fn row(&self, index: usize) -> MatrixResult<&Vec<Expression>> {
344        if index >= self.rows {
345            return Err(MatrixError::IndexOutOfBounds {
346                row: index,
347                col: 0,
348                rows: self.rows,
349                cols: self.cols,
350            });
351        }
352        Ok(&self.elements[index])
353    }
354
355    /// Get a column as a vector of expressions.
356    pub fn col(&self, index: usize) -> MatrixResult<Vec<&Expression>> {
357        if index >= self.cols {
358            return Err(MatrixError::IndexOutOfBounds {
359                row: 0,
360                col: index,
361                rows: self.rows,
362                cols: self.cols,
363            });
364        }
365        Ok(self.elements.iter().map(|row| &row[index]).collect())
366    }
367
368    /// Compute the transpose of this matrix.
369    ///
370    /// # Examples
371    ///
372    /// ```
373    /// use thales::matrix::MatrixExpr;
374    /// use thales::ast::Expression;
375    ///
376    /// let m = MatrixExpr::from_elements(vec![
377    ///     vec![Expression::Integer(1), Expression::Integer(2), Expression::Integer(3)],
378    ///     vec![Expression::Integer(4), Expression::Integer(5), Expression::Integer(6)],
379    /// ]).unwrap();
380    ///
381    /// let mt = m.transpose();
382    /// assert_eq!(mt.rows(), 3);
383    /// assert_eq!(mt.cols(), 2);
384    /// ```
385    pub fn transpose(&self) -> Self {
386        let elements: Vec<Vec<Expression>> = (0..self.cols)
387            .map(|j| {
388                (0..self.rows)
389                    .map(|i| self.elements[i][j].clone())
390                    .collect()
391            })
392            .collect();
393        Self {
394            rows: self.cols,
395            cols: self.rows,
396            elements,
397        }
398    }
399
400    /// Compute the trace (sum of diagonal elements).
401    ///
402    /// # Errors
403    ///
404    /// Returns an error if the matrix is not square.
405    ///
406    /// # Examples
407    ///
408    /// ```
409    /// use thales::matrix::MatrixExpr;
410    /// use thales::ast::Expression;
411    /// use std::collections::HashMap;
412    ///
413    /// let m = MatrixExpr::from_elements(vec![
414    ///     vec![Expression::Integer(1), Expression::Integer(2)],
415    ///     vec![Expression::Integer(3), Expression::Integer(4)],
416    /// ]).unwrap();
417    ///
418    /// let trace = m.trace().unwrap();
419    /// // trace = 1 + 4 = 5
420    /// assert_eq!(trace.evaluate(&HashMap::new()), Some(5.0));
421    /// ```
422    pub fn trace(&self) -> MatrixResult<Expression> {
423        if !self.is_square() {
424            return Err(MatrixError::InvalidOperation(
425                "Trace requires a square matrix".to_string(),
426            ));
427        }
428
429        let mut trace = self.elements[0][0].clone();
430        for i in 1..self.rows {
431            trace = Expression::Binary(
432                crate::ast::BinaryOp::Add,
433                Box::new(trace),
434                Box::new(self.elements[i][i].clone()),
435            );
436        }
437        Ok(trace.simplify())
438    }
439
440    /// Add two matrices element-wise.
441    ///
442    /// # Errors
443    ///
444    /// Returns an error if dimensions don't match.
445    ///
446    /// # Examples
447    ///
448    /// ```
449    /// use thales::matrix::MatrixExpr;
450    /// use thales::ast::Expression;
451    ///
452    /// let a = MatrixExpr::from_elements(vec![
453    ///     vec![Expression::Integer(1), Expression::Integer(2)],
454    ///     vec![Expression::Integer(3), Expression::Integer(4)],
455    /// ]).unwrap();
456    ///
457    /// let b = MatrixExpr::from_elements(vec![
458    ///     vec![Expression::Integer(5), Expression::Integer(6)],
459    ///     vec![Expression::Integer(7), Expression::Integer(8)],
460    /// ]).unwrap();
461    ///
462    /// let sum = a.add(&b).unwrap();
463    /// ```
464    pub fn add(&self, other: &MatrixExpr) -> MatrixResult<MatrixExpr> {
465        if self.rows != other.rows || self.cols != other.cols {
466            return Err(MatrixError::DimensionMismatch {
467                operation: "Matrix addition".to_string(),
468                expected: (self.rows, self.cols),
469                got: (other.rows, other.cols),
470            });
471        }
472
473        let elements: Vec<Vec<Expression>> = (0..self.rows)
474            .map(|i| {
475                (0..self.cols)
476                    .map(|j| {
477                        Expression::Binary(
478                            crate::ast::BinaryOp::Add,
479                            Box::new(self.elements[i][j].clone()),
480                            Box::new(other.elements[i][j].clone()),
481                        )
482                        .simplify()
483                    })
484                    .collect()
485            })
486            .collect();
487
488        Ok(MatrixExpr {
489            rows: self.rows,
490            cols: self.cols,
491            elements,
492        })
493    }
494
495    /// Subtract another matrix element-wise.
496    ///
497    /// # Errors
498    ///
499    /// Returns an error if dimensions don't match.
500    pub fn sub(&self, other: &MatrixExpr) -> MatrixResult<MatrixExpr> {
501        if self.rows != other.rows || self.cols != other.cols {
502            return Err(MatrixError::DimensionMismatch {
503                operation: "Matrix subtraction".to_string(),
504                expected: (self.rows, self.cols),
505                got: (other.rows, other.cols),
506            });
507        }
508
509        let elements: Vec<Vec<Expression>> = (0..self.rows)
510            .map(|i| {
511                (0..self.cols)
512                    .map(|j| {
513                        Expression::Binary(
514                            crate::ast::BinaryOp::Sub,
515                            Box::new(self.elements[i][j].clone()),
516                            Box::new(other.elements[i][j].clone()),
517                        )
518                        .simplify()
519                    })
520                    .collect()
521            })
522            .collect();
523
524        Ok(MatrixExpr {
525            rows: self.rows,
526            cols: self.cols,
527            elements,
528        })
529    }
530
531    /// Multiply by a scalar expression.
532    ///
533    /// # Examples
534    ///
535    /// ```
536    /// use thales::matrix::MatrixExpr;
537    /// use thales::ast::Expression;
538    ///
539    /// let m = MatrixExpr::identity(2);
540    /// let scaled = m.scalar_mul(&Expression::Integer(3));
541    /// ```
542    pub fn scalar_mul(&self, scalar: &Expression) -> MatrixExpr {
543        let elements: Vec<Vec<Expression>> = self
544            .elements
545            .iter()
546            .map(|row| {
547                row.iter()
548                    .map(|elem| {
549                        Expression::Binary(
550                            crate::ast::BinaryOp::Mul,
551                            Box::new(scalar.clone()),
552                            Box::new(elem.clone()),
553                        )
554                        .simplify()
555                    })
556                    .collect()
557            })
558            .collect();
559
560        MatrixExpr {
561            rows: self.rows,
562            cols: self.cols,
563            elements,
564        }
565    }
566
567    /// Multiply two matrices.
568    ///
569    /// Computes self * other where self is m×n and other is n×p, resulting in m×p.
570    ///
571    /// # Errors
572    ///
573    /// Returns an error if the inner dimensions don't match (self.cols != other.rows).
574    ///
575    /// # Examples
576    ///
577    /// ```
578    /// use thales::matrix::MatrixExpr;
579    /// use thales::ast::Expression;
580    ///
581    /// // 2x3 matrix
582    /// let a = MatrixExpr::from_elements(vec![
583    ///     vec![Expression::Integer(1), Expression::Integer(2), Expression::Integer(3)],
584    ///     vec![Expression::Integer(4), Expression::Integer(5), Expression::Integer(6)],
585    /// ]).unwrap();
586    ///
587    /// // 3x2 matrix
588    /// let b = MatrixExpr::from_elements(vec![
589    ///     vec![Expression::Integer(7), Expression::Integer(8)],
590    ///     vec![Expression::Integer(9), Expression::Integer(10)],
591    ///     vec![Expression::Integer(11), Expression::Integer(12)],
592    /// ]).unwrap();
593    ///
594    /// // Result is 2x2
595    /// let c = a.mul(&b).unwrap();
596    /// assert_eq!(c.rows(), 2);
597    /// assert_eq!(c.cols(), 2);
598    /// ```
599    pub fn mul(&self, other: &MatrixExpr) -> MatrixResult<MatrixExpr> {
600        if self.cols != other.rows {
601            return Err(MatrixError::DimensionMismatch {
602                operation: format!(
603                    "Matrix multiplication ({}x{} * {}x{})",
604                    self.rows, self.cols, other.rows, other.cols
605                ),
606                expected: (self.cols, other.rows),
607                got: (self.cols, other.rows),
608            });
609        }
610
611        let elements: Vec<Vec<Expression>> = (0..self.rows)
612            .map(|i| {
613                (0..other.cols)
614                    .map(|j| {
615                        // C[i][j] = sum(A[i][k] * B[k][j] for k in 0..n)
616                        let mut sum = Expression::Binary(
617                            crate::ast::BinaryOp::Mul,
618                            Box::new(self.elements[i][0].clone()),
619                            Box::new(other.elements[0][j].clone()),
620                        );
621                        for k in 1..self.cols {
622                            let product = Expression::Binary(
623                                crate::ast::BinaryOp::Mul,
624                                Box::new(self.elements[i][k].clone()),
625                                Box::new(other.elements[k][j].clone()),
626                            );
627                            sum = Expression::Binary(
628                                crate::ast::BinaryOp::Add,
629                                Box::new(sum),
630                                Box::new(product),
631                            );
632                        }
633                        sum.simplify()
634                    })
635                    .collect()
636            })
637            .collect();
638
639        Ok(MatrixExpr {
640            rows: self.rows,
641            cols: other.cols,
642            elements,
643        })
644    }
645
646    /// Simplify all elements in the matrix.
647    pub fn simplify(&self) -> MatrixExpr {
648        let elements: Vec<Vec<Expression>> = self
649            .elements
650            .iter()
651            .map(|row| row.iter().map(|elem| elem.simplify()).collect())
652            .collect();
653
654        MatrixExpr {
655            rows: self.rows,
656            cols: self.cols,
657            elements,
658        }
659    }
660
661    /// Get the submatrix by removing row `row_idx` and column `col_idx`.
662    ///
663    /// This is used for computing minors and cofactors.
664    ///
665    /// # Errors
666    ///
667    /// Returns an error if the matrix is 1x1 or smaller.
668    pub fn submatrix(&self, row_idx: usize, col_idx: usize) -> MatrixResult<MatrixExpr> {
669        if self.rows <= 1 || self.cols <= 1 {
670            return Err(MatrixError::InvalidOperation(
671                "Cannot compute submatrix of 1x1 or smaller matrix".to_string(),
672            ));
673        }
674
675        let elements: Vec<Vec<Expression>> = self
676            .elements
677            .iter()
678            .enumerate()
679            .filter(|(i, _)| *i != row_idx)
680            .map(|(_, row)| {
681                row.iter()
682                    .enumerate()
683                    .filter(|(j, _)| *j != col_idx)
684                    .map(|(_, elem)| elem.clone())
685                    .collect()
686            })
687            .collect();
688
689        MatrixExpr::from_elements(elements)
690    }
691
692    /// Compute the minor M(i, j) - the determinant of the submatrix excluding row i and column j.
693    ///
694    /// # Errors
695    ///
696    /// Returns an error if the matrix is not square or is 1x1.
697    pub fn minor(&self, row: usize, col: usize) -> MatrixResult<Expression> {
698        if !self.is_square() {
699            return Err(MatrixError::InvalidOperation(
700                "Minor requires a square matrix".to_string(),
701            ));
702        }
703        let sub = self.submatrix(row, col)?;
704        sub.determinant()
705    }
706
707    /// Compute the cofactor C(i, j) = (-1)^(i+j) * M(i, j).
708    ///
709    /// # Errors
710    ///
711    /// Returns an error if the matrix is not square or is 1x1.
712    pub fn cofactor(&self, row: usize, col: usize) -> MatrixResult<Expression> {
713        let minor = self.minor(row, col)?;
714        if (row + col) % 2 == 0 {
715            Ok(minor)
716        } else {
717            Ok(Expression::Unary(crate::ast::UnaryOp::Neg, Box::new(minor)).simplify())
718        }
719    }
720
721    /// Compute the determinant of the matrix.
722    ///
723    /// Uses the following algorithms:
724    /// - 1x1: Returns the single element
725    /// - 2x2: Uses ad - bc formula
726    /// - NxN: Uses cofactor expansion along the first row
727    ///
728    /// # Errors
729    ///
730    /// Returns an error if the matrix is not square.
731    ///
732    /// # Examples
733    ///
734    /// ```
735    /// use thales::matrix::MatrixExpr;
736    /// use thales::ast::Expression;
737    /// use std::collections::HashMap;
738    ///
739    /// // 2x2 matrix: [[1, 2], [3, 4]]
740    /// let m = MatrixExpr::from_elements(vec![
741    ///     vec![Expression::Integer(1), Expression::Integer(2)],
742    ///     vec![Expression::Integer(3), Expression::Integer(4)],
743    /// ]).unwrap();
744    ///
745    /// let det = m.determinant().unwrap();
746    /// // det = 1*4 - 2*3 = -2
747    /// assert_eq!(det.evaluate(&HashMap::new()), Some(-2.0));
748    /// ```
749    pub fn determinant(&self) -> MatrixResult<Expression> {
750        if !self.is_square() {
751            return Err(MatrixError::InvalidOperation(
752                "Determinant requires a square matrix".to_string(),
753            ));
754        }
755
756        match self.rows {
757            1 => Ok(self.elements[0][0].clone()),
758            2 => {
759                // det = a*d - b*c for [[a, b], [c, d]]
760                let a = &self.elements[0][0];
761                let b = &self.elements[0][1];
762                let c = &self.elements[1][0];
763                let d = &self.elements[1][1];
764
765                let ad = Expression::Binary(
766                    crate::ast::BinaryOp::Mul,
767                    Box::new(a.clone()),
768                    Box::new(d.clone()),
769                );
770                let bc = Expression::Binary(
771                    crate::ast::BinaryOp::Mul,
772                    Box::new(b.clone()),
773                    Box::new(c.clone()),
774                );
775                Ok(
776                    Expression::Binary(crate::ast::BinaryOp::Sub, Box::new(ad), Box::new(bc))
777                        .simplify(),
778                )
779            }
780            _ => {
781                // Cofactor expansion along first row
782                let mut det = Expression::Integer(0);
783                for j in 0..self.cols {
784                    let cofactor = self.cofactor(0, j)?;
785                    let term = Expression::Binary(
786                        crate::ast::BinaryOp::Mul,
787                        Box::new(self.elements[0][j].clone()),
788                        Box::new(cofactor),
789                    );
790                    det = Expression::Binary(
791                        crate::ast::BinaryOp::Add,
792                        Box::new(det),
793                        Box::new(term),
794                    );
795                }
796                Ok(det.simplify())
797            }
798        }
799    }
800
801    /// Compute the cofactor matrix (matrix of all cofactors).
802    ///
803    /// # Errors
804    ///
805    /// Returns an error if the matrix is not square or is 1x1.
806    pub fn cofactor_matrix(&self) -> MatrixResult<MatrixExpr> {
807        if !self.is_square() {
808            return Err(MatrixError::InvalidOperation(
809                "Cofactor matrix requires a square matrix".to_string(),
810            ));
811        }
812        if self.rows == 1 {
813            return Err(MatrixError::InvalidOperation(
814                "Cofactor matrix not defined for 1x1 matrix".to_string(),
815            ));
816        }
817
818        let mut elements = Vec::with_capacity(self.rows);
819        for i in 0..self.rows {
820            let mut row = Vec::with_capacity(self.cols);
821            for j in 0..self.cols {
822                row.push(self.cofactor(i, j)?);
823            }
824            elements.push(row);
825        }
826
827        MatrixExpr::from_elements(elements)
828    }
829
830    /// Compute the adjugate (classical adjoint) matrix.
831    ///
832    /// The adjugate is the transpose of the cofactor matrix.
833    ///
834    /// # Errors
835    ///
836    /// Returns an error if the matrix is not square.
837    ///
838    /// # Examples
839    ///
840    /// ```
841    /// use thales::matrix::MatrixExpr;
842    /// use thales::ast::Expression;
843    ///
844    /// let m = MatrixExpr::from_elements(vec![
845    ///     vec![Expression::Integer(1), Expression::Integer(2)],
846    ///     vec![Expression::Integer(3), Expression::Integer(4)],
847    /// ]).unwrap();
848    ///
849    /// let adj = m.adjugate().unwrap();
850    /// // adj = [[4, -2], [-3, 1]]
851    /// ```
852    pub fn adjugate(&self) -> MatrixResult<MatrixExpr> {
853        if !self.is_square() {
854            return Err(MatrixError::InvalidOperation(
855                "Adjugate requires a square matrix".to_string(),
856            ));
857        }
858
859        // Special case for 1x1 matrix
860        if self.rows == 1 {
861            return Ok(MatrixExpr::from_elements(vec![vec![Expression::Integer(1)]]).unwrap());
862        }
863
864        let cofactor_mat = self.cofactor_matrix()?;
865        Ok(cofactor_mat.transpose())
866    }
867
868    /// Compute the inverse of the matrix.
869    ///
870    /// Uses the formula: A^(-1) = adj(A) / det(A)
871    ///
872    /// # Errors
873    ///
874    /// Returns an error if:
875    /// - The matrix is not square
876    /// - The matrix is singular (determinant is zero)
877    ///
878    /// # Examples
879    ///
880    /// ```
881    /// use thales::matrix::MatrixExpr;
882    /// use thales::ast::Expression;
883    /// use std::collections::HashMap;
884    ///
885    /// let m = MatrixExpr::from_elements(vec![
886    ///     vec![Expression::Integer(4), Expression::Integer(7)],
887    ///     vec![Expression::Integer(2), Expression::Integer(6)],
888    /// ]).unwrap();
889    ///
890    /// let inv = m.inverse().unwrap();
891    /// // Verify A * A^(-1) = I
892    /// let product = m.mul(&inv).unwrap();
893    /// let vars = HashMap::new();
894    /// let result = product.evaluate(&vars).unwrap();
895    /// assert!((result[0][0] - 1.0).abs() < 1e-10);
896    /// assert!((result[1][1] - 1.0).abs() < 1e-10);
897    /// ```
898    pub fn inverse(&self) -> MatrixResult<MatrixExpr> {
899        if !self.is_square() {
900            return Err(MatrixError::InvalidOperation(
901                "Inverse requires a square matrix".to_string(),
902            ));
903        }
904
905        let det = self.determinant()?;
906
907        // Check if determinant is zero (symbolically or numerically)
908        let is_zero = match &det {
909            Expression::Integer(0) => true,
910            Expression::Float(f) if f.abs() < 1e-10 => true,
911            _ => {
912                // Try numerical evaluation for expressions that simplify to zero
913                let empty = std::collections::HashMap::new();
914                det.evaluate(&empty).map_or(false, |v| v.abs() < 1e-10)
915            }
916        };
917
918        if is_zero {
919            return Err(MatrixError::InvalidOperation(
920                "Matrix is singular (determinant is zero)".to_string(),
921            ));
922        }
923
924        // For 1x1 matrix
925        if self.rows == 1 {
926            let inv_element = Expression::Binary(
927                crate::ast::BinaryOp::Div,
928                Box::new(Expression::Integer(1)),
929                Box::new(self.elements[0][0].clone()),
930            )
931            .simplify();
932            return MatrixExpr::from_elements(vec![vec![inv_element]]);
933        }
934
935        let adj = self.adjugate()?;
936
937        // Multiply adjugate by 1/det
938        let inv_det = Expression::Binary(
939            crate::ast::BinaryOp::Div,
940            Box::new(Expression::Integer(1)),
941            Box::new(det),
942        );
943
944        Ok(adj.scalar_mul(&inv_det).simplify())
945    }
946
947    /// Check if the matrix is singular (determinant is zero when evaluated numerically).
948    ///
949    /// Returns `None` if the determinant cannot be evaluated numerically.
950    pub fn is_singular(&self, vars: &std::collections::HashMap<String, f64>) -> Option<bool> {
951        let det = self.determinant().ok()?;
952        let det_value = det.evaluate(vars)?;
953        Some(det_value.abs() < 1e-10)
954    }
955
956    /// Compute the characteristic polynomial det(A - λI).
957    ///
958    /// Returns a polynomial expression in the given variable (typically "lambda").
959    ///
960    /// # Errors
961    ///
962    /// Returns an error if the matrix is not square.
963    ///
964    /// # Examples
965    ///
966    /// ```
967    /// use thales::matrix::MatrixExpr;
968    /// use thales::ast::Expression;
969    /// use std::collections::HashMap;
970    ///
971    /// let m = MatrixExpr::from_elements(vec![
972    ///     vec![Expression::Integer(2), Expression::Integer(1)],
973    ///     vec![Expression::Integer(1), Expression::Integer(2)],
974    /// ]).unwrap();
975    ///
976    /// let char_poly = m.characteristic_polynomial("lambda").unwrap();
977    /// // For this matrix, eigenvalues are 1 and 3
978    /// // So char poly = (λ - 1)(λ - 3) = λ² - 4λ + 3
979    /// ```
980    pub fn characteristic_polynomial(&self, lambda_var: &str) -> MatrixResult<Expression> {
981        if !self.is_square() {
982            return Err(MatrixError::InvalidOperation(
983                "Characteristic polynomial requires a square matrix".to_string(),
984            ));
985        }
986
987        // Compute A - λI
988        let lambda = Expression::Variable(Variable::new(lambda_var));
989        let lambda_i = MatrixExpr::identity(self.rows).scalar_mul(&lambda);
990        let a_minus_lambda_i = self.sub(&lambda_i)?;
991
992        // Compute det(A - λI)
993        a_minus_lambda_i.determinant()
994    }
995
996    /// Compute eigenvalues of the matrix numerically.
997    ///
998    /// For 2x2 matrices, uses the quadratic formula.
999    /// For larger matrices, uses numerical methods (power iteration or similar).
1000    ///
1001    /// # Errors
1002    ///
1003    /// Returns an error if the matrix is not square.
1004    ///
1005    /// # Examples
1006    ///
1007    /// ```
1008    /// use thales::matrix::MatrixExpr;
1009    /// use thales::ast::Expression;
1010    ///
1011    /// let m = MatrixExpr::from_elements(vec![
1012    ///     vec![Expression::Integer(2), Expression::Integer(1)],
1013    ///     vec![Expression::Integer(1), Expression::Integer(2)],
1014    /// ]).unwrap();
1015    ///
1016    /// let eigenvalues = m.eigenvalues_numeric().unwrap();
1017    /// // Eigenvalues should be 1 and 3
1018    /// ```
1019    pub fn eigenvalues_numeric(&self) -> MatrixResult<Vec<f64>> {
1020        if !self.is_square() {
1021            return Err(MatrixError::InvalidOperation(
1022                "Eigenvalues require a square matrix".to_string(),
1023            ));
1024        }
1025
1026        let empty = std::collections::HashMap::new();
1027        let elements = self.evaluate(&empty).ok_or_else(|| {
1028            MatrixError::InvalidOperation("Cannot evaluate matrix numerically".to_string())
1029        })?;
1030
1031        match self.rows {
1032            1 => Ok(vec![elements[0][0]]),
1033            2 => self.eigenvalues_2x2(&elements),
1034            3 => self.eigenvalues_3x3(&elements),
1035            _ => self.eigenvalues_qr(&elements),
1036        }
1037    }
1038
1039    /// Compute eigenvalues for a 2x2 matrix using the quadratic formula.
1040    fn eigenvalues_2x2(&self, elements: &[Vec<f64>]) -> MatrixResult<Vec<f64>> {
1041        let a = elements[0][0];
1042        let b = elements[0][1];
1043        let c = elements[1][0];
1044        let d = elements[1][1];
1045
1046        // Characteristic equation: λ² - (a+d)λ + (ad - bc) = 0
1047        // Using quadratic formula: λ = ((a+d) ± sqrt((a+d)² - 4(ad-bc))) / 2
1048        let trace = a + d;
1049        let det = a * d - b * c;
1050        let discriminant = trace * trace - 4.0 * det;
1051
1052        if discriminant < 0.0 {
1053            // Complex eigenvalues - return just the real parts for now
1054            // A full implementation would return Complex numbers
1055            let real_part = trace / 2.0;
1056            Ok(vec![real_part, real_part])
1057        } else {
1058            let sqrt_disc = discriminant.sqrt();
1059            let lambda1 = (trace + sqrt_disc) / 2.0;
1060            let lambda2 = (trace - sqrt_disc) / 2.0;
1061            Ok(vec![lambda1, lambda2])
1062        }
1063    }
1064
1065    /// Compute eigenvalues for a 3x3 matrix using Cardano's formula.
1066    fn eigenvalues_3x3(&self, elements: &[Vec<f64>]) -> MatrixResult<Vec<f64>> {
1067        // For 3x3, we solve the cubic characteristic equation
1068        // det(A - λI) = -λ³ + tr(A)λ² - (sum of 2x2 principal minors)λ + det(A)
1069        let a11 = elements[0][0];
1070        let a12 = elements[0][1];
1071        let a13 = elements[0][2];
1072        let a21 = elements[1][0];
1073        let a22 = elements[1][1];
1074        let a23 = elements[1][2];
1075        let a31 = elements[2][0];
1076        let a32 = elements[2][1];
1077        let a33 = elements[2][2];
1078
1079        // Coefficients of λ³ + p*λ² + q*λ + r = 0
1080        let trace = a11 + a22 + a33;
1081        let p = -trace;
1082
1083        // Sum of 2x2 principal minors
1084        let minor12 = a11 * a22 - a12 * a21;
1085        let minor13 = a11 * a33 - a13 * a31;
1086        let minor23 = a22 * a33 - a23 * a32;
1087        let q = minor12 + minor13 + minor23;
1088
1089        // Determinant
1090        let det = a11 * (a22 * a33 - a23 * a32) - a12 * (a21 * a33 - a23 * a31)
1091            + a13 * (a21 * a32 - a22 * a31);
1092        let r = -det;
1093
1094        // Solve cubic using Cardano's formula or numerical method
1095        solve_cubic(p, q, r)
1096    }
1097
1098    /// Compute eigenvalues using QR algorithm for larger matrices.
1099    fn eigenvalues_qr(&self, elements: &[Vec<f64>]) -> MatrixResult<Vec<f64>> {
1100        // Simple QR iteration
1101        let n = elements.len();
1102        let mut a = elements.to_vec();
1103
1104        // Maximum iterations
1105        const MAX_ITER: usize = 100;
1106        const TOL: f64 = 1e-10;
1107
1108        for _ in 0..MAX_ITER {
1109            // QR decomposition
1110            let (q, r) = qr_decomposition(&a);
1111
1112            // A = R * Q
1113            a = matrix_multiply(&r, &q);
1114
1115            // Check for convergence (off-diagonal elements small)
1116            let mut converged = true;
1117            for i in 0..n {
1118                for j in 0..i {
1119                    if a[i][j].abs() > TOL {
1120                        converged = false;
1121                        break;
1122                    }
1123                }
1124                if !converged {
1125                    break;
1126                }
1127            }
1128
1129            if converged {
1130                break;
1131            }
1132        }
1133
1134        // Extract eigenvalues from diagonal
1135        Ok((0..n).map(|i| a[i][i]).collect())
1136    }
1137
1138    /// Compute eigenvector for a given eigenvalue numerically.
1139    ///
1140    /// Returns the eigenvector as a column matrix.
1141    ///
1142    /// # Errors
1143    ///
1144    /// Returns an error if the matrix is not square.
1145    pub fn eigenvector_numeric(&self, eigenvalue: f64) -> MatrixResult<Vec<f64>> {
1146        if !self.is_square() {
1147            return Err(MatrixError::InvalidOperation(
1148                "Eigenvector requires a square matrix".to_string(),
1149            ));
1150        }
1151
1152        let empty = std::collections::HashMap::new();
1153        let elements = self.evaluate(&empty).ok_or_else(|| {
1154            MatrixError::InvalidOperation("Cannot evaluate matrix numerically".to_string())
1155        })?;
1156
1157        let n = self.rows;
1158
1159        // Compute A - λI
1160        let mut a_minus_lambda: Vec<Vec<f64>> = elements.clone();
1161        for i in 0..n {
1162            a_minus_lambda[i][i] -= eigenvalue;
1163        }
1164
1165        // Use inverse iteration to find eigenvector
1166        // Start with a random vector
1167        let mut v: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
1168
1169        // Normalize
1170        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
1171        for x in &mut v {
1172            *x /= norm;
1173        }
1174
1175        // Inverse iteration: solve (A - λI)w = v, then v = w/||w||
1176        // Since A - λI is singular (or near-singular), we perturb slightly
1177        const MAX_ITER: usize = 50;
1178        const TOL: f64 = 1e-8;
1179
1180        for _ in 0..MAX_ITER {
1181            // Solve (A - λI + εI)w = v using Gaussian elimination
1182            let mut augmented = a_minus_lambda.clone();
1183            for i in 0..n {
1184                augmented[i][i] += 1e-10; // Small perturbation
1185            }
1186
1187            // Solve using Gaussian elimination
1188            let w = solve_linear_system(&augmented, &v);
1189
1190            // Normalize
1191            let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
1192            if norm < 1e-14 {
1193                break;
1194            }
1195
1196            let w_normalized: Vec<f64> = w.iter().map(|x| x / norm).collect();
1197
1198            // Check convergence
1199            let diff: f64 = v
1200                .iter()
1201                .zip(w_normalized.iter())
1202                .map(|(a, b)| (a - b).abs())
1203                .sum();
1204
1205            v = w_normalized;
1206
1207            if diff < TOL {
1208                break;
1209            }
1210        }
1211
1212        Ok(v)
1213    }
1214
1215    /// Compute all eigenpairs (eigenvalue, eigenvector) numerically.
1216    ///
1217    /// # Errors
1218    ///
1219    /// Returns an error if the matrix is not square.
1220    pub fn eigenpairs_numeric(&self) -> MatrixResult<Vec<(f64, Vec<f64>)>> {
1221        let eigenvalues = self.eigenvalues_numeric()?;
1222        let mut pairs = Vec::with_capacity(eigenvalues.len());
1223
1224        for eigenvalue in eigenvalues {
1225            let eigenvector = self.eigenvector_numeric(eigenvalue)?;
1226            pairs.push((eigenvalue, eigenvector));
1227        }
1228
1229        Ok(pairs)
1230    }
1231
1232    /// Check if the matrix is diagonalizable.
1233    ///
1234    /// A matrix is diagonalizable if it has n linearly independent eigenvectors.
1235    pub fn is_diagonalizable(&self) -> MatrixResult<bool> {
1236        if !self.is_square() {
1237            return Err(MatrixError::InvalidOperation(
1238                "Diagonalizability check requires a square matrix".to_string(),
1239            ));
1240        }
1241
1242        // A simple check: symmetric matrices are always diagonalizable
1243        let transpose = self.transpose();
1244        let empty = std::collections::HashMap::new();
1245
1246        if let (Some(a), Some(at)) = (self.evaluate(&empty), transpose.evaluate(&empty)) {
1247            let is_symmetric = a.iter().zip(at.iter()).all(|(row_a, row_at)| {
1248                row_a
1249                    .iter()
1250                    .zip(row_at.iter())
1251                    .all(|(x, y)| (x - y).abs() < 1e-10)
1252            });
1253
1254            if is_symmetric {
1255                return Ok(true);
1256            }
1257        }
1258
1259        // For non-symmetric matrices, we would need to check algebraic vs geometric multiplicity
1260        // This is a simplified check - return true if we can compute distinct eigenvalues
1261        let eigenvalues = self.eigenvalues_numeric()?;
1262
1263        // Check if all eigenvalues are distinct (sufficient condition)
1264        for (i, &ev1) in eigenvalues.iter().enumerate() {
1265            for (j, &ev2) in eigenvalues.iter().enumerate() {
1266                if i != j && (ev1 - ev2).abs() < 1e-10 {
1267                    // Repeated eigenvalue - would need to check geometric multiplicity
1268                    // For simplicity, assume diagonalizable
1269                    return Ok(true);
1270                }
1271            }
1272        }
1273
1274        Ok(true)
1275    }
1276
1277    /// Render the matrix as LaTeX.
1278    ///
1279    /// # Examples
1280    ///
1281    /// ```
1282    /// use thales::matrix::{MatrixExpr, BracketStyle};
1283    /// use thales::ast::Expression;
1284    ///
1285    /// let m = MatrixExpr::from_elements(vec![
1286    ///     vec![Expression::Integer(1), Expression::Integer(2)],
1287    ///     vec![Expression::Integer(3), Expression::Integer(4)],
1288    /// ]).unwrap();
1289    ///
1290    /// let latex = m.to_latex(BracketStyle::Parentheses);
1291    /// assert!(latex.contains("pmatrix"));
1292    /// ```
1293    pub fn to_latex(&self, style: BracketStyle) -> String {
1294        let env = match style {
1295            BracketStyle::Parentheses => "pmatrix",
1296            BracketStyle::Square => "bmatrix",
1297            BracketStyle::Curly => "Bmatrix",
1298            BracketStyle::Determinant => "vmatrix",
1299            BracketStyle::Norm => "Vmatrix",
1300            BracketStyle::None => "matrix",
1301        };
1302
1303        let mut result = format!("\\begin{{{}}}\n", env);
1304        for (i, row) in self.elements.iter().enumerate() {
1305            let row_str: Vec<String> = row.iter().map(|e| e.to_latex()).collect();
1306            result.push_str(&row_str.join(" & "));
1307            if i < self.rows - 1 {
1308                result.push_str(" \\\\\n");
1309            } else {
1310                result.push('\n');
1311            }
1312        }
1313        result.push_str(&format!("\\end{{{}}}", env));
1314        result
1315    }
1316
1317    /// Render the matrix as LaTeX with default parentheses style.
1318    pub fn to_latex_default(&self) -> String {
1319        self.to_latex(BracketStyle::default())
1320    }
1321
1322    /// Evaluate all elements numerically.
1323    ///
1324    /// Returns None if any element cannot be evaluated.
1325    pub fn evaluate(&self, vars: &std::collections::HashMap<String, f64>) -> Option<Vec<Vec<f64>>> {
1326        self.elements
1327            .iter()
1328            .map(|row| {
1329                row.iter()
1330                    .map(|elem| elem.evaluate(vars))
1331                    .collect::<Option<Vec<f64>>>()
1332            })
1333            .collect()
1334    }
1335}
1336
1337impl fmt::Display for MatrixExpr {
1338    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1339        write!(f, "[")?;
1340        for (i, row) in self.elements.iter().enumerate() {
1341            if i > 0 {
1342                write!(f, "; ")?;
1343            }
1344            write!(f, "[")?;
1345            for (j, elem) in row.iter().enumerate() {
1346                if j > 0 {
1347                    write!(f, ", ")?;
1348                }
1349                write!(f, "{}", elem)?;
1350            }
1351            write!(f, "]")?;
1352        }
1353        write!(f, "]")
1354    }
1355}
1356
1357// =============================================================================
1358// Helper functions for eigenvalue computation
1359// =============================================================================
1360
1361/// Solve cubic equation x³ + p*x² + q*x + r = 0 using Cardano's formula.
1362fn solve_cubic(p: f64, q: f64, r: f64) -> MatrixResult<Vec<f64>> {
1363    // Depress the cubic: substitute x = t - p/3
1364    // t³ + at + b = 0 where:
1365    // a = q - p²/3
1366    // b = r - pq/3 + 2p³/27
1367    let a = q - p * p / 3.0;
1368    let b = r - p * q / 3.0 + 2.0 * p * p * p / 27.0;
1369
1370    // Discriminant
1371    let discriminant = -4.0 * a * a * a - 27.0 * b * b;
1372
1373    let offset = -p / 3.0;
1374
1375    if discriminant > 0.0 {
1376        // Three distinct real roots
1377        let theta = (-b / 2.0 / ((-a / 3.0).powi(3).sqrt())).acos();
1378        let r_cubed = (-a / 3.0).sqrt();
1379
1380        let t1 = 2.0 * r_cubed * (theta / 3.0).cos();
1381        let t2 = 2.0 * r_cubed * ((theta + 2.0 * std::f64::consts::PI) / 3.0).cos();
1382        let t3 = 2.0 * r_cubed * ((theta + 4.0 * std::f64::consts::PI) / 3.0).cos();
1383
1384        Ok(vec![t1 + offset, t2 + offset, t3 + offset])
1385    } else if discriminant.abs() < 1e-10 {
1386        // Multiple roots
1387        if b.abs() < 1e-10 {
1388            // Triple root
1389            Ok(vec![offset, offset, offset])
1390        } else {
1391            // Double root
1392            let double_root = 3.0 * b / a;
1393            let simple_root = -3.0 * b / (2.0 * a);
1394            Ok(vec![
1395                double_root + offset,
1396                simple_root + offset,
1397                simple_root + offset,
1398            ])
1399        }
1400    } else {
1401        // One real root, two complex (return real root 3 times for now)
1402        let sqrt_disc = (b * b / 4.0 + a * a * a / 27.0).sqrt();
1403        let u = (-b / 2.0 + sqrt_disc).cbrt();
1404        let v = (-b / 2.0 - sqrt_disc).cbrt();
1405        let real_root = u + v + offset;
1406
1407        // Return real root; complex roots have same real part
1408        let complex_real = -(u + v) / 2.0 + offset;
1409        Ok(vec![real_root, complex_real, complex_real])
1410    }
1411}
1412
1413/// QR decomposition using Gram-Schmidt process.
1414fn qr_decomposition(a: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
1415    let n = a.len();
1416    let mut q = vec![vec![0.0; n]; n];
1417    let mut r = vec![vec![0.0; n]; n];
1418
1419    for j in 0..n {
1420        // Start with column j of A
1421        let mut v: Vec<f64> = (0..n).map(|i| a[i][j]).collect();
1422
1423        // Subtract projections onto previous q vectors
1424        for i in 0..j {
1425            let q_i: Vec<f64> = (0..n).map(|k| q[k][i]).collect();
1426            r[i][j] = dot_product(&q_i, &v);
1427            for k in 0..n {
1428                v[k] -= r[i][j] * q_i[k];
1429            }
1430        }
1431
1432        // Compute norm and normalize
1433        r[j][j] = v.iter().map(|x| x * x).sum::<f64>().sqrt();
1434        if r[j][j] > 1e-14 {
1435            for k in 0..n {
1436                q[k][j] = v[k] / r[j][j];
1437            }
1438        }
1439    }
1440
1441    (q, r)
1442}
1443
/// Euclidean inner product `Σ aᵢ·bᵢ`; extra elements of the longer slice are ignored.
fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| x * y)
        .sum::<f64>()
}
1448
/// Multiply two `n x n` matrices of `f64`, where `n = a.len()`.
///
/// Each entry is accumulated left to right over `k`, starting from `0.0`.
fn matrix_multiply(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let size = a.len();
    (0..size)
        .map(|i| {
            (0..size)
                .map(|j| (0..size).fold(0.0, |acc, k| acc + a[i][k] * b[k][j]))
                .collect()
        })
        .collect()
}
1464
/// Solve `A·x = b` by Gaussian elimination with partial pivoting.
///
/// Columns whose pivot magnitude is not above `1e-14` are skipped during
/// elimination. The corresponding unknowns are left at `0.0` during back
/// substitution, so a singular system yields one particular (not unique)
/// solution instead of an error.
fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    const PIVOT_EPS: f64 = 1e-14;

    let size = a.len();
    let mut m = a.to_vec();
    let mut y = b.to_vec();

    for col in 0..size {
        // First row (from `col` down) with the strictly largest |entry| in this column.
        let mut pivot = col;
        for row in (col + 1)..size {
            if m[row][col].abs() > m[pivot][col].abs() {
                pivot = row;
            }
        }
        if pivot != col {
            m.swap(col, pivot);
            y.swap(col, pivot);
        }

        if m[col][col].abs() > PIVOT_EPS {
            let (upper, lower) = m.split_at_mut(col + 1);
            let pivot_row = &upper[col];
            for (offset, row) in lower.iter_mut().enumerate() {
                let factor = row[col] / pivot_row[col];
                for j in col..size {
                    row[j] -= factor * pivot_row[j];
                }
                y[col + 1 + offset] -= factor * y[col];
            }
        }
    }

    let mut x = vec![0.0; size];
    for i in (0..size).rev() {
        let diag = m[i][i];
        if diag.abs() > PIVOT_EPS {
            let reduced = ((i + 1)..size).fold(y[i], |acc, j| acc - m[i][j] * x[j]);
            x[i] = reduced / diag;
        }
    }
    x
}
1517
1518#[cfg(test)]
1519mod tests {
1520    use super::*;
1521    use crate::ast::{Expression, Variable};
1522    use std::collections::HashMap;
1523
1524    fn int(n: i64) -> Expression {
1525        Expression::Integer(n)
1526    }
1527
1528    fn var(name: &str) -> Expression {
1529        Expression::Variable(Variable::new(name))
1530    }
1531
1532    #[test]
1533    fn test_matrix_creation() {
1534        let m =
1535            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1536
1537        assert_eq!(m.rows(), 2);
1538        assert_eq!(m.cols(), 2);
1539        assert!(m.is_square());
1540    }
1541
1542    #[test]
1543    fn test_identity_matrix() {
1544        let i3 = MatrixExpr::identity(3);
1545        assert_eq!(i3.rows(), 3);
1546        assert_eq!(i3.cols(), 3);
1547
1548        // Check diagonal elements are 1
1549        assert_eq!(i3.get(0, 0).unwrap(), &int(1));
1550        assert_eq!(i3.get(1, 1).unwrap(), &int(1));
1551        assert_eq!(i3.get(2, 2).unwrap(), &int(1));
1552
1553        // Check off-diagonal elements are 0
1554        assert_eq!(i3.get(0, 1).unwrap(), &int(0));
1555        assert_eq!(i3.get(1, 2).unwrap(), &int(0));
1556    }
1557
1558    #[test]
1559    fn test_zero_matrix() {
1560        let z = MatrixExpr::zero(2, 3);
1561        assert_eq!(z.rows(), 2);
1562        assert_eq!(z.cols(), 3);
1563
1564        for i in 0..2 {
1565            for j in 0..3 {
1566                assert_eq!(z.get(i, j).unwrap(), &int(0));
1567            }
1568        }
1569    }
1570
1571    #[test]
1572    fn test_diagonal_matrix() {
1573        let d = MatrixExpr::diagonal(vec![int(1), int(2), int(3)]);
1574        assert_eq!(d.rows(), 3);
1575        assert_eq!(d.cols(), 3);
1576
1577        assert_eq!(d.get(0, 0).unwrap(), &int(1));
1578        assert_eq!(d.get(1, 1).unwrap(), &int(2));
1579        assert_eq!(d.get(2, 2).unwrap(), &int(3));
1580        assert_eq!(d.get(0, 1).unwrap(), &int(0));
1581    }
1582
1583    #[test]
1584    fn test_transpose() {
1585        let m = MatrixExpr::from_elements(vec![
1586            vec![int(1), int(2), int(3)],
1587            vec![int(4), int(5), int(6)],
1588        ])
1589        .unwrap();
1590
1591        let mt = m.transpose();
1592        assert_eq!(mt.rows(), 3);
1593        assert_eq!(mt.cols(), 2);
1594
1595        assert_eq!(mt.get(0, 0).unwrap(), &int(1));
1596        assert_eq!(mt.get(0, 1).unwrap(), &int(4));
1597        assert_eq!(mt.get(1, 0).unwrap(), &int(2));
1598        assert_eq!(mt.get(2, 1).unwrap(), &int(6));
1599    }
1600
1601    #[test]
1602    fn test_double_transpose() {
1603        let m =
1604            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1605
1606        let mtt = m.transpose().transpose();
1607        assert_eq!(mtt.elements, m.elements);
1608    }
1609
1610    #[test]
1611    fn test_trace() {
1612        let m =
1613            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1614
1615        let trace = m.trace().unwrap();
1616        let vars = HashMap::new();
1617        assert_eq!(trace.evaluate(&vars), Some(5.0));
1618    }
1619
1620    #[test]
1621    fn test_addition() {
1622        let a =
1623            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1624
1625        let b =
1626            MatrixExpr::from_elements(vec![vec![int(5), int(6)], vec![int(7), int(8)]]).unwrap();
1627
1628        let sum = a.add(&b).unwrap();
1629        let vars = HashMap::new();
1630
1631        assert_eq!(sum.get(0, 0).unwrap().evaluate(&vars), Some(6.0));
1632        assert_eq!(sum.get(0, 1).unwrap().evaluate(&vars), Some(8.0));
1633        assert_eq!(sum.get(1, 0).unwrap().evaluate(&vars), Some(10.0));
1634        assert_eq!(sum.get(1, 1).unwrap().evaluate(&vars), Some(12.0));
1635    }
1636
1637    #[test]
1638    fn test_addition_dimension_check() {
1639        let a = MatrixExpr::from_elements(vec![vec![int(1), int(2)]]).unwrap();
1640
1641        let b = MatrixExpr::from_elements(vec![vec![int(1)], vec![int(2)]]).unwrap();
1642
1643        let result = a.add(&b);
1644        assert!(result.is_err());
1645    }
1646
1647    #[test]
1648    fn test_matrix_multiplication() {
1649        // 2x3 * 3x2 = 2x2
1650        let a = MatrixExpr::from_elements(vec![
1651            vec![int(1), int(2), int(3)],
1652            vec![int(4), int(5), int(6)],
1653        ])
1654        .unwrap();
1655
1656        let b = MatrixExpr::from_elements(vec![
1657            vec![int(7), int(8)],
1658            vec![int(9), int(10)],
1659            vec![int(11), int(12)],
1660        ])
1661        .unwrap();
1662
1663        let c = a.mul(&b).unwrap();
1664        assert_eq!(c.rows(), 2);
1665        assert_eq!(c.cols(), 2);
1666
1667        let vars = HashMap::new();
1668        // C[0][0] = 1*7 + 2*9 + 3*11 = 7 + 18 + 33 = 58
1669        assert_eq!(c.get(0, 0).unwrap().evaluate(&vars), Some(58.0));
1670        // C[0][1] = 1*8 + 2*10 + 3*12 = 8 + 20 + 36 = 64
1671        assert_eq!(c.get(0, 1).unwrap().evaluate(&vars), Some(64.0));
1672        // C[1][0] = 4*7 + 5*9 + 6*11 = 28 + 45 + 66 = 139
1673        assert_eq!(c.get(1, 0).unwrap().evaluate(&vars), Some(139.0));
1674        // C[1][1] = 4*8 + 5*10 + 6*12 = 32 + 50 + 72 = 154
1675        assert_eq!(c.get(1, 1).unwrap().evaluate(&vars), Some(154.0));
1676    }
1677
1678    #[test]
1679    fn test_scalar_multiplication() {
1680        let m = MatrixExpr::identity(2);
1681        let scaled = m.scalar_mul(&int(3));
1682
1683        let vars = HashMap::new();
1684        assert_eq!(scaled.get(0, 0).unwrap().evaluate(&vars), Some(3.0));
1685        assert_eq!(scaled.get(1, 1).unwrap().evaluate(&vars), Some(3.0));
1686        assert_eq!(scaled.get(0, 1).unwrap().evaluate(&vars), Some(0.0));
1687    }
1688
1689    #[test]
1690    fn test_symbolic_matrix() {
1691        let m = MatrixExpr::from_elements(vec![vec![var("a"), var("b")], vec![var("c"), var("d")]])
1692            .unwrap();
1693
1694        let mut vars = HashMap::new();
1695        vars.insert("a".to_string(), 1.0);
1696        vars.insert("b".to_string(), 2.0);
1697        vars.insert("c".to_string(), 3.0);
1698        vars.insert("d".to_string(), 4.0);
1699
1700        let result = m.evaluate(&vars).unwrap();
1701        assert_eq!(result[0][0], 1.0);
1702        assert_eq!(result[0][1], 2.0);
1703        assert_eq!(result[1][0], 3.0);
1704        assert_eq!(result[1][1], 4.0);
1705    }
1706
1707    #[test]
1708    fn test_latex_output() {
1709        let m =
1710            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1711
1712        let latex = m.to_latex(BracketStyle::Parentheses);
1713        assert!(latex.contains("\\begin{pmatrix}"));
1714        assert!(latex.contains("\\end{pmatrix}"));
1715        assert!(latex.contains("1 & 2"));
1716        assert!(latex.contains("3 & 4"));
1717    }
1718
1719    #[test]
1720    fn test_transpose_multiplication_property() {
1721        // (AB)^T = B^T A^T
1722        let a =
1723            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1724
1725        let b =
1726            MatrixExpr::from_elements(vec![vec![int(5), int(6)], vec![int(7), int(8)]]).unwrap();
1727
1728        let ab = a.mul(&b).unwrap();
1729        let ab_t = ab.transpose();
1730
1731        let bt_at = b.transpose().mul(&a.transpose()).unwrap();
1732
1733        let vars = HashMap::new();
1734        for i in 0..2 {
1735            for j in 0..2 {
1736                assert_eq!(
1737                    ab_t.get(i, j).unwrap().evaluate(&vars),
1738                    bt_at.get(i, j).unwrap().evaluate(&vars)
1739                );
1740            }
1741        }
1742    }
1743
1744    #[test]
1745    fn test_determinant_2x2() {
1746        // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2
1747        let m =
1748            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1749
1750        let det = m.determinant().unwrap();
1751        let vars = HashMap::new();
1752        assert_eq!(det.evaluate(&vars), Some(-2.0));
1753    }
1754
1755    #[test]
1756    fn test_determinant_3x3() {
1757        // det([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) = 0 (rows are linearly dependent)
1758        let m = MatrixExpr::from_elements(vec![
1759            vec![int(1), int(2), int(3)],
1760            vec![int(4), int(5), int(6)],
1761            vec![int(7), int(8), int(9)],
1762        ])
1763        .unwrap();
1764
1765        let det = m.determinant().unwrap();
1766        let vars = HashMap::new();
1767        assert_eq!(det.evaluate(&vars), Some(0.0));
1768    }
1769
1770    #[test]
1771    fn test_determinant_3x3_nonzero() {
1772        // det([[1, 2, 3], [0, 1, 4], [5, 6, 0]]) = 1
1773        let m = MatrixExpr::from_elements(vec![
1774            vec![int(1), int(2), int(3)],
1775            vec![int(0), int(1), int(4)],
1776            vec![int(5), int(6), int(0)],
1777        ])
1778        .unwrap();
1779
1780        let det = m.determinant().unwrap();
1781        let vars = HashMap::new();
1782        assert_eq!(det.evaluate(&vars), Some(1.0));
1783    }
1784
1785    #[test]
1786    fn test_determinant_identity() {
1787        // det(I) = 1
1788        let i3 = MatrixExpr::identity(3);
1789        let det = i3.determinant().unwrap();
1790        let vars = HashMap::new();
1791        assert_eq!(det.evaluate(&vars), Some(1.0));
1792    }
1793
1794    #[test]
1795    fn test_determinant_non_square() {
1796        let m = MatrixExpr::from_elements(vec![
1797            vec![int(1), int(2), int(3)],
1798            vec![int(4), int(5), int(6)],
1799        ])
1800        .unwrap();
1801
1802        let result = m.determinant();
1803        assert!(result.is_err());
1804    }
1805
1806    #[test]
1807    fn test_inverse_2x2() {
1808        // A = [[4, 7], [2, 6]], det(A) = 24 - 14 = 10
1809        // A^(-1) = (1/10) * [[6, -7], [-2, 4]] = [[0.6, -0.7], [-0.2, 0.4]]
1810        let m =
1811            MatrixExpr::from_elements(vec![vec![int(4), int(7)], vec![int(2), int(6)]]).unwrap();
1812
1813        let inv = m.inverse().unwrap();
1814        let vars = HashMap::new();
1815
1816        // Verify A * A^(-1) = I
1817        let product = m.mul(&inv).unwrap();
1818        let result = product.evaluate(&vars).unwrap();
1819
1820        assert!((result[0][0] - 1.0).abs() < 1e-10);
1821        assert!((result[0][1] - 0.0).abs() < 1e-10);
1822        assert!((result[1][0] - 0.0).abs() < 1e-10);
1823        assert!((result[1][1] - 1.0).abs() < 1e-10);
1824    }
1825
1826    #[test]
1827    fn test_inverse_3x3() {
1828        // A = [[1, 2, 3], [0, 1, 4], [5, 6, 0]]
1829        let m = MatrixExpr::from_elements(vec![
1830            vec![int(1), int(2), int(3)],
1831            vec![int(0), int(1), int(4)],
1832            vec![int(5), int(6), int(0)],
1833        ])
1834        .unwrap();
1835
1836        let inv = m.inverse().unwrap();
1837        let vars = HashMap::new();
1838
1839        // Verify A * A^(-1) = I
1840        let product = m.mul(&inv).unwrap();
1841        let result = product.evaluate(&vars).unwrap();
1842
1843        for i in 0..3 {
1844            for j in 0..3 {
1845                let expected = if i == j { 1.0 } else { 0.0 };
1846                assert!(
1847                    (result[i][j] - expected).abs() < 1e-10,
1848                    "Expected {} at ({}, {}), got {}",
1849                    expected,
1850                    i,
1851                    j,
1852                    result[i][j]
1853                );
1854            }
1855        }
1856    }
1857
1858    #[test]
1859    fn test_inverse_singular_matrix() {
1860        // Singular matrix (det = 0)
1861        let m =
1862            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(2), int(4)]]).unwrap();
1863
1864        let result = m.inverse();
1865        assert!(result.is_err());
1866    }
1867
1868    #[test]
1869    fn test_determinant_symbolic() {
1870        // det([[a, b], [c, d]]) = ad - bc
1871        let m = MatrixExpr::from_elements(vec![vec![var("a"), var("b")], vec![var("c"), var("d")]])
1872            .unwrap();
1873
1874        let det = m.determinant().unwrap();
1875
1876        let mut vars = HashMap::new();
1877        vars.insert("a".to_string(), 2.0);
1878        vars.insert("b".to_string(), 3.0);
1879        vars.insert("c".to_string(), 4.0);
1880        vars.insert("d".to_string(), 5.0);
1881
1882        // det = 2*5 - 3*4 = 10 - 12 = -2
1883        assert_eq!(det.evaluate(&vars), Some(-2.0));
1884    }
1885
1886    #[test]
1887    fn test_submatrix() {
1888        let m = MatrixExpr::from_elements(vec![
1889            vec![int(1), int(2), int(3)],
1890            vec![int(4), int(5), int(6)],
1891            vec![int(7), int(8), int(9)],
1892        ])
1893        .unwrap();
1894
1895        // Remove row 1, col 1 -> [[1, 3], [7, 9]]
1896        let sub = m.submatrix(1, 1).unwrap();
1897        let vars = HashMap::new();
1898
1899        assert_eq!(sub.rows(), 2);
1900        assert_eq!(sub.cols(), 2);
1901        assert_eq!(sub.get(0, 0).unwrap().evaluate(&vars), Some(1.0));
1902        assert_eq!(sub.get(0, 1).unwrap().evaluate(&vars), Some(3.0));
1903        assert_eq!(sub.get(1, 0).unwrap().evaluate(&vars), Some(7.0));
1904        assert_eq!(sub.get(1, 1).unwrap().evaluate(&vars), Some(9.0));
1905    }
1906
1907    #[test]
1908    fn test_adjugate_2x2() {
1909        // adj([[a, b], [c, d]]) = [[d, -b], [-c, a]]
1910        let m =
1911            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1912
1913        let adj = m.adjugate().unwrap();
1914        let vars = HashMap::new();
1915
1916        assert_eq!(adj.get(0, 0).unwrap().evaluate(&vars), Some(4.0));
1917        assert_eq!(adj.get(0, 1).unwrap().evaluate(&vars), Some(-2.0));
1918        assert_eq!(adj.get(1, 0).unwrap().evaluate(&vars), Some(-3.0));
1919        assert_eq!(adj.get(1, 1).unwrap().evaluate(&vars), Some(1.0));
1920    }
1921
1922    #[test]
1923    fn test_is_singular() {
1924        let singular =
1925            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(2), int(4)]]).unwrap();
1926
1927        let non_singular =
1928            MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1929
1930        let vars = HashMap::new();
1931        assert_eq!(singular.is_singular(&vars), Some(true));
1932        assert_eq!(non_singular.is_singular(&vars), Some(false));
1933    }
1934
1935    #[test]
1936    fn test_inverse_identity() {
1937        // I^(-1) = I
1938        let i3 = MatrixExpr::identity(3);
1939        let inv = i3.inverse().unwrap();
1940        let vars = HashMap::new();
1941
1942        for i in 0..3 {
1943            for j in 0..3 {
1944                let expected = if i == j { 1.0 } else { 0.0 };
1945                assert_eq!(inv.get(i, j).unwrap().evaluate(&vars), Some(expected));
1946            }
1947        }
1948    }
1949
1950    // =========================================================================
1951    // Eigenvalue and Eigenvector Tests
1952    // =========================================================================
1953
1954    #[test]
1955    fn test_characteristic_polynomial_2x2() {
1956        // A = [[2, 1], [1, 2]], eigenvalues are 1 and 3
1957        // char poly = (λ - 1)(λ - 3) = λ² - 4λ + 3
1958        let m =
1959            MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
1960
1961        let char_poly = m.characteristic_polynomial("lambda").unwrap();
1962
1963        // Evaluate at λ = 1 (should be 0)
1964        let mut vars = HashMap::new();
1965        vars.insert("lambda".to_string(), 1.0);
1966        let at_1 = char_poly.evaluate(&vars).unwrap();
1967        assert!(
1968            at_1.abs() < 1e-10,
1969            "char poly at λ=1 should be 0, got {}",
1970            at_1
1971        );
1972
1973        // Evaluate at λ = 3 (should be 0)
1974        vars.insert("lambda".to_string(), 3.0);
1975        let at_3 = char_poly.evaluate(&vars).unwrap();
1976        assert!(
1977            at_3.abs() < 1e-10,
1978            "char poly at λ=3 should be 0, got {}",
1979            at_3
1980        );
1981    }
1982
1983    #[test]
1984    fn test_eigenvalues_2x2_symmetric() {
1985        // A = [[2, 1], [1, 2]], eigenvalues are 1 and 3
1986        let m =
1987            MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
1988
1989        let eigenvalues = m.eigenvalues_numeric().unwrap();
1990        assert_eq!(eigenvalues.len(), 2);
1991
1992        // Sort eigenvalues for consistent comparison
1993        let mut sorted = eigenvalues.clone();
1994        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1995
1996        assert!(
1997            (sorted[0] - 1.0).abs() < 1e-10,
1998            "Expected 1, got {}",
1999            sorted[0]
2000        );
2001        assert!(
2002            (sorted[1] - 3.0).abs() < 1e-10,
2003            "Expected 3, got {}",
2004            sorted[1]
2005        );
2006    }
2007
2008    #[test]
2009    fn test_eigenvalues_diagonal() {
2010        // Diagonal matrix: eigenvalues are the diagonal elements
2011        let m =
2012            MatrixExpr::from_elements(vec![vec![int(5), int(0)], vec![int(0), int(3)]]).unwrap();
2013
2014        let eigenvalues = m.eigenvalues_numeric().unwrap();
2015        let mut sorted = eigenvalues.clone();
2016        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
2017
2018        assert!((sorted[0] - 3.0).abs() < 1e-10);
2019        assert!((sorted[1] - 5.0).abs() < 1e-10);
2020    }
2021
2022    #[test]
2023    fn test_eigenvalues_identity() {
2024        // Identity matrix: all eigenvalues are 1
2025        let m = MatrixExpr::identity(3);
2026
2027        let eigenvalues = m.eigenvalues_numeric().unwrap();
2028        assert_eq!(eigenvalues.len(), 3);
2029
2030        for ev in eigenvalues {
2031            assert!((ev - 1.0).abs() < 1e-10);
2032        }
2033    }
2034
2035    #[test]
2036    fn test_eigenvector_2x2() {
2037        // A = [[2, 1], [1, 2]], eigenvalue 3 has eigenvector [1, 1]
2038        let m =
2039            MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
2040
2041        let eigenvector = m.eigenvector_numeric(3.0).unwrap();
2042        assert_eq!(eigenvector.len(), 2);
2043
2044        // Check Av = λv (up to normalization)
2045        // v should be proportional to [1, 1]
2046        let ratio = eigenvector[0] / eigenvector[1];
2047        assert!(
2048            (ratio - 1.0).abs() < 1e-5,
2049            "Expected ratio 1, got {}",
2050            ratio
2051        );
2052    }
2053
2054    #[test]
2055    fn test_eigenpairs() {
2056        let m =
2057            MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
2058
2059        let pairs = m.eigenpairs_numeric().unwrap();
2060        assert_eq!(pairs.len(), 2);
2061
2062        for (eigenvalue, eigenvector) in pairs {
2063            // Verify Av = λv
2064            let empty = HashMap::new();
2065            let a = m.evaluate(&empty).unwrap();
2066
2067            // Compute Av
2068            let av: Vec<f64> = (0..2)
2069                .map(|i| {
2070                    a[i].iter()
2071                        .zip(eigenvector.iter())
2072                        .map(|(a, v)| a * v)
2073                        .sum()
2074                })
2075                .collect();
2076
2077            // Compute λv
2078            let lambda_v: Vec<f64> = eigenvector.iter().map(|v| eigenvalue * v).collect();
2079
2080            // Check Av ≈ λv
2081            for i in 0..2 {
2082                assert!(
2083                    (av[i] - lambda_v[i]).abs() < 1e-5,
2084                    "Av[{}] = {}, λv[{}] = {}, eigenvalue = {}",
2085                    i,
2086                    av[i],
2087                    i,
2088                    lambda_v[i],
2089                    eigenvalue
2090                );
2091            }
2092        }
2093    }
2094
2095    #[test]
2096    fn test_eigenvalues_3x3() {
2097        // A simple 3x3 matrix with known eigenvalues
2098        // A = [[1, 0, 0], [0, 2, 0], [0, 0, 3]] has eigenvalues 1, 2, 3
2099        let m = MatrixExpr::from_elements(vec![
2100            vec![int(1), int(0), int(0)],
2101            vec![int(0), int(2), int(0)],
2102            vec![int(0), int(0), int(3)],
2103        ])
2104        .unwrap();
2105
2106        let eigenvalues = m.eigenvalues_numeric().unwrap();
2107        let mut sorted = eigenvalues.clone();
2108        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
2109
2110        assert!((sorted[0] - 1.0).abs() < 1e-10);
2111        assert!((sorted[1] - 2.0).abs() < 1e-10);
2112        assert!((sorted[2] - 3.0).abs() < 1e-10);
2113    }
2114
2115    #[test]
2116    fn test_is_diagonalizable_symmetric() {
2117        // Symmetric matrices are always diagonalizable
2118        let m =
2119            MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
2120
2121        assert!(m.is_diagonalizable().unwrap());
2122    }
2123
2124    #[test]
2125    fn test_is_diagonalizable_identity() {
2126        let m = MatrixExpr::identity(3);
2127        assert!(m.is_diagonalizable().unwrap());
2128    }
2129
2130    #[test]
2131    fn test_eigenvalues_non_square() {
2132        let m = MatrixExpr::from_elements(vec![
2133            vec![int(1), int(2), int(3)],
2134            vec![int(4), int(5), int(6)],
2135        ])
2136        .unwrap();
2137
2138        let result = m.eigenvalues_numeric();
2139        assert!(result.is_err());
2140    }
2141
2142    #[test]
2143    fn test_characteristic_polynomial_non_square() {
2144        let m = MatrixExpr::from_elements(vec![
2145            vec![int(1), int(2), int(3)],
2146            vec![int(4), int(5), int(6)],
2147        ])
2148        .unwrap();
2149
2150        let result = m.characteristic_polynomial("lambda");
2151        assert!(result.is_err());
2152    }
2153}