1use crate::ast::{Expression, Variable};
31use std::fmt;
32
/// Errors reported by matrix construction and matrix operations.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum MatrixError {
    /// The operands' dimensions are incompatible for `operation`.
    DimensionMismatch {
        /// Human-readable name of the operation that failed.
        operation: String,
        /// Dimensions the operation expected, as `(rows, cols)`.
        expected: (usize, usize),
        /// Dimensions that were actually supplied, as `(rows, cols)`.
        got: (usize, usize),
    },
    /// A matrix with no rows (or an empty first row) was supplied.
    EmptyMatrix,
    /// The supplied rows do not all have the same length.
    NonRectangular,
    /// The index `(row, col)` lies outside a `rows` x `cols` matrix.
    IndexOutOfBounds {
        /// Requested row index.
        row: usize,
        /// Requested column index.
        col: usize,
        /// Row count of the matrix that was indexed.
        rows: usize,
        /// Column count of the matrix that was indexed.
        cols: usize,
    },
    /// The operation is not defined for this matrix; the payload explains why.
    InvalidOperation(String),
}
57
58impl fmt::Display for MatrixError {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 match self {
61 MatrixError::DimensionMismatch {
62 operation,
63 expected,
64 got,
65 } => {
66 write!(
67 f,
68 "{}: expected {}x{}, got {}x{}",
69 operation, expected.0, expected.1, got.0, got.1
70 )
71 }
72 MatrixError::EmptyMatrix => write!(f, "Empty matrix not allowed"),
73 MatrixError::NonRectangular => {
74 write!(f, "Matrix must be rectangular (all rows same length)")
75 }
76 MatrixError::IndexOutOfBounds {
77 row,
78 col,
79 rows,
80 cols,
81 } => {
82 write!(
83 f,
84 "Index ({}, {}) out of bounds for {}x{} matrix",
85 row, col, rows, cols
86 )
87 }
88 MatrixError::InvalidOperation(msg) => write!(f, "Invalid operation: {}", msg),
89 }
90 }
91}
92
// `MatrixError` derives `Debug` and implements `Display`, so the default
// `std::error::Error` methods are sufficient (there is no underlying source error).
impl std::error::Error for MatrixError {}

/// Result alias used by every fallible matrix operation in this module.
pub type MatrixResult<T> = Result<T, MatrixError>;
97
/// Delimiter style used when rendering a matrix to LaTeX.
///
/// Each variant selects one LaTeX matrix environment (`pmatrix`, `bmatrix`, `Bmatrix`,
/// `vmatrix`, `Vmatrix`, or plain `matrix`). The default is [`BracketStyle::Parentheses`].
// `Default` is derived via `#[default]` rather than hand-written; `Eq`/`Hash` are free
// for a fieldless `Copy` enum and let callers use styles as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum BracketStyle {
    /// `( ... )` -- LaTeX `pmatrix`.
    #[default]
    Parentheses,
    /// `[ ... ]` -- LaTeX `bmatrix`.
    Square,
    /// `{ ... }` -- LaTeX `Bmatrix`.
    Curly,
    /// `| ... |` -- LaTeX `vmatrix`, conventionally used for determinants.
    Determinant,
    /// `|| ... ||` -- LaTeX `Vmatrix`, conventionally used for norms.
    Norm,
    /// No delimiters -- LaTeX `matrix`.
    None,
}
120
121#[derive(Debug, Clone, PartialEq)]
148pub struct MatrixExpr {
149 rows: usize,
150 cols: usize,
151 elements: Vec<Vec<Expression>>,
152}
153
154impl MatrixExpr {
155 pub fn from_elements(elements: Vec<Vec<Expression>>) -> MatrixResult<Self> {
176 if elements.is_empty() || elements[0].is_empty() {
177 return Err(MatrixError::EmptyMatrix);
178 }
179
180 let cols = elements[0].len();
181 for row in &elements {
182 if row.len() != cols {
183 return Err(MatrixError::NonRectangular);
184 }
185 }
186
187 let rows = elements.len();
188 Ok(Self {
189 rows,
190 cols,
191 elements,
192 })
193 }
194
195 pub fn identity(n: usize) -> Self {
207 let elements: Vec<Vec<Expression>> = (0..n)
208 .map(|i| {
209 (0..n)
210 .map(|j| {
211 if i == j {
212 Expression::Integer(1)
213 } else {
214 Expression::Integer(0)
215 }
216 })
217 .collect()
218 })
219 .collect();
220 Self {
221 rows: n,
222 cols: n,
223 elements,
224 }
225 }
226
227 pub fn zero(rows: usize, cols: usize) -> Self {
239 let elements: Vec<Vec<Expression>> = (0..rows)
240 .map(|_| (0..cols).map(|_| Expression::Integer(0)).collect())
241 .collect();
242 Self {
243 rows,
244 cols,
245 elements,
246 }
247 }
248
249 pub fn diagonal(diag: Vec<Expression>) -> Self {
266 let n = diag.len();
267 let elements: Vec<Vec<Expression>> = (0..n)
268 .map(|i| {
269 (0..n)
270 .map(|j| {
271 if i == j {
272 diag[i].clone()
273 } else {
274 Expression::Integer(0)
275 }
276 })
277 .collect()
278 })
279 .collect();
280 Self {
281 rows: n,
282 cols: n,
283 elements,
284 }
285 }
286
287 pub fn rows(&self) -> usize {
289 self.rows
290 }
291
292 pub fn cols(&self) -> usize {
294 self.cols
295 }
296
297 pub fn dimensions(&self) -> (usize, usize) {
299 (self.rows, self.cols)
300 }
301
302 pub fn is_square(&self) -> bool {
304 self.rows == self.cols
305 }
306
307 pub fn get(&self, row: usize, col: usize) -> MatrixResult<&Expression> {
313 if row >= self.rows || col >= self.cols {
314 return Err(MatrixError::IndexOutOfBounds {
315 row,
316 col,
317 rows: self.rows,
318 cols: self.cols,
319 });
320 }
321 Ok(&self.elements[row][col])
322 }
323
324 pub fn set(&mut self, row: usize, col: usize, value: Expression) -> MatrixResult<()> {
330 if row >= self.rows || col >= self.cols {
331 return Err(MatrixError::IndexOutOfBounds {
332 row,
333 col,
334 rows: self.rows,
335 cols: self.cols,
336 });
337 }
338 self.elements[row][col] = value;
339 Ok(())
340 }
341
342 pub fn row(&self, index: usize) -> MatrixResult<&Vec<Expression>> {
344 if index >= self.rows {
345 return Err(MatrixError::IndexOutOfBounds {
346 row: index,
347 col: 0,
348 rows: self.rows,
349 cols: self.cols,
350 });
351 }
352 Ok(&self.elements[index])
353 }
354
355 pub fn col(&self, index: usize) -> MatrixResult<Vec<&Expression>> {
357 if index >= self.cols {
358 return Err(MatrixError::IndexOutOfBounds {
359 row: 0,
360 col: index,
361 rows: self.rows,
362 cols: self.cols,
363 });
364 }
365 Ok(self.elements.iter().map(|row| &row[index]).collect())
366 }
367
368 pub fn transpose(&self) -> Self {
386 let elements: Vec<Vec<Expression>> = (0..self.cols)
387 .map(|j| {
388 (0..self.rows)
389 .map(|i| self.elements[i][j].clone())
390 .collect()
391 })
392 .collect();
393 Self {
394 rows: self.cols,
395 cols: self.rows,
396 elements,
397 }
398 }
399
400 pub fn trace(&self) -> MatrixResult<Expression> {
423 if !self.is_square() {
424 return Err(MatrixError::InvalidOperation(
425 "Trace requires a square matrix".to_string(),
426 ));
427 }
428
429 let mut trace = self.elements[0][0].clone();
430 for i in 1..self.rows {
431 trace = Expression::Binary(
432 crate::ast::BinaryOp::Add,
433 Box::new(trace),
434 Box::new(self.elements[i][i].clone()),
435 );
436 }
437 Ok(trace.simplify())
438 }
439
440 pub fn add(&self, other: &MatrixExpr) -> MatrixResult<MatrixExpr> {
465 if self.rows != other.rows || self.cols != other.cols {
466 return Err(MatrixError::DimensionMismatch {
467 operation: "Matrix addition".to_string(),
468 expected: (self.rows, self.cols),
469 got: (other.rows, other.cols),
470 });
471 }
472
473 let elements: Vec<Vec<Expression>> = (0..self.rows)
474 .map(|i| {
475 (0..self.cols)
476 .map(|j| {
477 Expression::Binary(
478 crate::ast::BinaryOp::Add,
479 Box::new(self.elements[i][j].clone()),
480 Box::new(other.elements[i][j].clone()),
481 )
482 .simplify()
483 })
484 .collect()
485 })
486 .collect();
487
488 Ok(MatrixExpr {
489 rows: self.rows,
490 cols: self.cols,
491 elements,
492 })
493 }
494
495 pub fn sub(&self, other: &MatrixExpr) -> MatrixResult<MatrixExpr> {
501 if self.rows != other.rows || self.cols != other.cols {
502 return Err(MatrixError::DimensionMismatch {
503 operation: "Matrix subtraction".to_string(),
504 expected: (self.rows, self.cols),
505 got: (other.rows, other.cols),
506 });
507 }
508
509 let elements: Vec<Vec<Expression>> = (0..self.rows)
510 .map(|i| {
511 (0..self.cols)
512 .map(|j| {
513 Expression::Binary(
514 crate::ast::BinaryOp::Sub,
515 Box::new(self.elements[i][j].clone()),
516 Box::new(other.elements[i][j].clone()),
517 )
518 .simplify()
519 })
520 .collect()
521 })
522 .collect();
523
524 Ok(MatrixExpr {
525 rows: self.rows,
526 cols: self.cols,
527 elements,
528 })
529 }
530
531 pub fn scalar_mul(&self, scalar: &Expression) -> MatrixExpr {
543 let elements: Vec<Vec<Expression>> = self
544 .elements
545 .iter()
546 .map(|row| {
547 row.iter()
548 .map(|elem| {
549 Expression::Binary(
550 crate::ast::BinaryOp::Mul,
551 Box::new(scalar.clone()),
552 Box::new(elem.clone()),
553 )
554 .simplify()
555 })
556 .collect()
557 })
558 .collect();
559
560 MatrixExpr {
561 rows: self.rows,
562 cols: self.cols,
563 elements,
564 }
565 }
566
567 pub fn mul(&self, other: &MatrixExpr) -> MatrixResult<MatrixExpr> {
600 if self.cols != other.rows {
601 return Err(MatrixError::DimensionMismatch {
602 operation: format!(
603 "Matrix multiplication ({}x{} * {}x{})",
604 self.rows, self.cols, other.rows, other.cols
605 ),
606 expected: (self.cols, other.rows),
607 got: (self.cols, other.rows),
608 });
609 }
610
611 let elements: Vec<Vec<Expression>> = (0..self.rows)
612 .map(|i| {
613 (0..other.cols)
614 .map(|j| {
615 let mut sum = Expression::Binary(
617 crate::ast::BinaryOp::Mul,
618 Box::new(self.elements[i][0].clone()),
619 Box::new(other.elements[0][j].clone()),
620 );
621 for k in 1..self.cols {
622 let product = Expression::Binary(
623 crate::ast::BinaryOp::Mul,
624 Box::new(self.elements[i][k].clone()),
625 Box::new(other.elements[k][j].clone()),
626 );
627 sum = Expression::Binary(
628 crate::ast::BinaryOp::Add,
629 Box::new(sum),
630 Box::new(product),
631 );
632 }
633 sum.simplify()
634 })
635 .collect()
636 })
637 .collect();
638
639 Ok(MatrixExpr {
640 rows: self.rows,
641 cols: other.cols,
642 elements,
643 })
644 }
645
646 pub fn simplify(&self) -> MatrixExpr {
648 let elements: Vec<Vec<Expression>> = self
649 .elements
650 .iter()
651 .map(|row| row.iter().map(|elem| elem.simplify()).collect())
652 .collect();
653
654 MatrixExpr {
655 rows: self.rows,
656 cols: self.cols,
657 elements,
658 }
659 }
660
661 pub fn submatrix(&self, row_idx: usize, col_idx: usize) -> MatrixResult<MatrixExpr> {
669 if self.rows <= 1 || self.cols <= 1 {
670 return Err(MatrixError::InvalidOperation(
671 "Cannot compute submatrix of 1x1 or smaller matrix".to_string(),
672 ));
673 }
674
675 let elements: Vec<Vec<Expression>> = self
676 .elements
677 .iter()
678 .enumerate()
679 .filter(|(i, _)| *i != row_idx)
680 .map(|(_, row)| {
681 row.iter()
682 .enumerate()
683 .filter(|(j, _)| *j != col_idx)
684 .map(|(_, elem)| elem.clone())
685 .collect()
686 })
687 .collect();
688
689 MatrixExpr::from_elements(elements)
690 }
691
692 pub fn minor(&self, row: usize, col: usize) -> MatrixResult<Expression> {
698 if !self.is_square() {
699 return Err(MatrixError::InvalidOperation(
700 "Minor requires a square matrix".to_string(),
701 ));
702 }
703 let sub = self.submatrix(row, col)?;
704 sub.determinant()
705 }
706
707 pub fn cofactor(&self, row: usize, col: usize) -> MatrixResult<Expression> {
713 let minor = self.minor(row, col)?;
714 if (row + col) % 2 == 0 {
715 Ok(minor)
716 } else {
717 Ok(Expression::Unary(crate::ast::UnaryOp::Neg, Box::new(minor)).simplify())
718 }
719 }
720
721 pub fn determinant(&self) -> MatrixResult<Expression> {
750 if !self.is_square() {
751 return Err(MatrixError::InvalidOperation(
752 "Determinant requires a square matrix".to_string(),
753 ));
754 }
755
756 match self.rows {
757 1 => Ok(self.elements[0][0].clone()),
758 2 => {
759 let a = &self.elements[0][0];
761 let b = &self.elements[0][1];
762 let c = &self.elements[1][0];
763 let d = &self.elements[1][1];
764
765 let ad = Expression::Binary(
766 crate::ast::BinaryOp::Mul,
767 Box::new(a.clone()),
768 Box::new(d.clone()),
769 );
770 let bc = Expression::Binary(
771 crate::ast::BinaryOp::Mul,
772 Box::new(b.clone()),
773 Box::new(c.clone()),
774 );
775 Ok(
776 Expression::Binary(crate::ast::BinaryOp::Sub, Box::new(ad), Box::new(bc))
777 .simplify(),
778 )
779 }
780 _ => {
781 let mut det = Expression::Integer(0);
783 for j in 0..self.cols {
784 let cofactor = self.cofactor(0, j)?;
785 let term = Expression::Binary(
786 crate::ast::BinaryOp::Mul,
787 Box::new(self.elements[0][j].clone()),
788 Box::new(cofactor),
789 );
790 det = Expression::Binary(
791 crate::ast::BinaryOp::Add,
792 Box::new(det),
793 Box::new(term),
794 );
795 }
796 Ok(det.simplify())
797 }
798 }
799 }
800
801 pub fn cofactor_matrix(&self) -> MatrixResult<MatrixExpr> {
807 if !self.is_square() {
808 return Err(MatrixError::InvalidOperation(
809 "Cofactor matrix requires a square matrix".to_string(),
810 ));
811 }
812 if self.rows == 1 {
813 return Err(MatrixError::InvalidOperation(
814 "Cofactor matrix not defined for 1x1 matrix".to_string(),
815 ));
816 }
817
818 let mut elements = Vec::with_capacity(self.rows);
819 for i in 0..self.rows {
820 let mut row = Vec::with_capacity(self.cols);
821 for j in 0..self.cols {
822 row.push(self.cofactor(i, j)?);
823 }
824 elements.push(row);
825 }
826
827 MatrixExpr::from_elements(elements)
828 }
829
830 pub fn adjugate(&self) -> MatrixResult<MatrixExpr> {
853 if !self.is_square() {
854 return Err(MatrixError::InvalidOperation(
855 "Adjugate requires a square matrix".to_string(),
856 ));
857 }
858
859 if self.rows == 1 {
861 return Ok(MatrixExpr::from_elements(vec![vec![Expression::Integer(1)]]).unwrap());
862 }
863
864 let cofactor_mat = self.cofactor_matrix()?;
865 Ok(cofactor_mat.transpose())
866 }
867
868 pub fn inverse(&self) -> MatrixResult<MatrixExpr> {
899 if !self.is_square() {
900 return Err(MatrixError::InvalidOperation(
901 "Inverse requires a square matrix".to_string(),
902 ));
903 }
904
905 let det = self.determinant()?;
906
907 let is_zero = match &det {
909 Expression::Integer(0) => true,
910 Expression::Float(f) if f.abs() < 1e-10 => true,
911 _ => {
912 let empty = std::collections::HashMap::new();
914 det.evaluate(&empty).map_or(false, |v| v.abs() < 1e-10)
915 }
916 };
917
918 if is_zero {
919 return Err(MatrixError::InvalidOperation(
920 "Matrix is singular (determinant is zero)".to_string(),
921 ));
922 }
923
924 if self.rows == 1 {
926 let inv_element = Expression::Binary(
927 crate::ast::BinaryOp::Div,
928 Box::new(Expression::Integer(1)),
929 Box::new(self.elements[0][0].clone()),
930 )
931 .simplify();
932 return MatrixExpr::from_elements(vec![vec![inv_element]]);
933 }
934
935 let adj = self.adjugate()?;
936
937 let inv_det = Expression::Binary(
939 crate::ast::BinaryOp::Div,
940 Box::new(Expression::Integer(1)),
941 Box::new(det),
942 );
943
944 Ok(adj.scalar_mul(&inv_det).simplify())
945 }
946
947 pub fn is_singular(&self, vars: &std::collections::HashMap<String, f64>) -> Option<bool> {
951 let det = self.determinant().ok()?;
952 let det_value = det.evaluate(vars)?;
953 Some(det_value.abs() < 1e-10)
954 }
955
956 pub fn characteristic_polynomial(&self, lambda_var: &str) -> MatrixResult<Expression> {
981 if !self.is_square() {
982 return Err(MatrixError::InvalidOperation(
983 "Characteristic polynomial requires a square matrix".to_string(),
984 ));
985 }
986
987 let lambda = Expression::Variable(Variable::new(lambda_var));
989 let lambda_i = MatrixExpr::identity(self.rows).scalar_mul(&lambda);
990 let a_minus_lambda_i = self.sub(&lambda_i)?;
991
992 a_minus_lambda_i.determinant()
994 }
995
996 pub fn eigenvalues_numeric(&self) -> MatrixResult<Vec<f64>> {
1020 if !self.is_square() {
1021 return Err(MatrixError::InvalidOperation(
1022 "Eigenvalues require a square matrix".to_string(),
1023 ));
1024 }
1025
1026 let empty = std::collections::HashMap::new();
1027 let elements = self.evaluate(&empty).ok_or_else(|| {
1028 MatrixError::InvalidOperation("Cannot evaluate matrix numerically".to_string())
1029 })?;
1030
1031 match self.rows {
1032 1 => Ok(vec![elements[0][0]]),
1033 2 => self.eigenvalues_2x2(&elements),
1034 3 => self.eigenvalues_3x3(&elements),
1035 _ => self.eigenvalues_qr(&elements),
1036 }
1037 }
1038
1039 fn eigenvalues_2x2(&self, elements: &[Vec<f64>]) -> MatrixResult<Vec<f64>> {
1041 let a = elements[0][0];
1042 let b = elements[0][1];
1043 let c = elements[1][0];
1044 let d = elements[1][1];
1045
1046 let trace = a + d;
1049 let det = a * d - b * c;
1050 let discriminant = trace * trace - 4.0 * det;
1051
1052 if discriminant < 0.0 {
1053 let real_part = trace / 2.0;
1056 Ok(vec![real_part, real_part])
1057 } else {
1058 let sqrt_disc = discriminant.sqrt();
1059 let lambda1 = (trace + sqrt_disc) / 2.0;
1060 let lambda2 = (trace - sqrt_disc) / 2.0;
1061 Ok(vec![lambda1, lambda2])
1062 }
1063 }
1064
1065 fn eigenvalues_3x3(&self, elements: &[Vec<f64>]) -> MatrixResult<Vec<f64>> {
1067 let a11 = elements[0][0];
1070 let a12 = elements[0][1];
1071 let a13 = elements[0][2];
1072 let a21 = elements[1][0];
1073 let a22 = elements[1][1];
1074 let a23 = elements[1][2];
1075 let a31 = elements[2][0];
1076 let a32 = elements[2][1];
1077 let a33 = elements[2][2];
1078
1079 let trace = a11 + a22 + a33;
1081 let p = -trace;
1082
1083 let minor12 = a11 * a22 - a12 * a21;
1085 let minor13 = a11 * a33 - a13 * a31;
1086 let minor23 = a22 * a33 - a23 * a32;
1087 let q = minor12 + minor13 + minor23;
1088
1089 let det = a11 * (a22 * a33 - a23 * a32) - a12 * (a21 * a33 - a23 * a31)
1091 + a13 * (a21 * a32 - a22 * a31);
1092 let r = -det;
1093
1094 solve_cubic(p, q, r)
1096 }
1097
1098 fn eigenvalues_qr(&self, elements: &[Vec<f64>]) -> MatrixResult<Vec<f64>> {
1100 let n = elements.len();
1102 let mut a = elements.to_vec();
1103
1104 const MAX_ITER: usize = 100;
1106 const TOL: f64 = 1e-10;
1107
1108 for _ in 0..MAX_ITER {
1109 let (q, r) = qr_decomposition(&a);
1111
1112 a = matrix_multiply(&r, &q);
1114
1115 let mut converged = true;
1117 for i in 0..n {
1118 for j in 0..i {
1119 if a[i][j].abs() > TOL {
1120 converged = false;
1121 break;
1122 }
1123 }
1124 if !converged {
1125 break;
1126 }
1127 }
1128
1129 if converged {
1130 break;
1131 }
1132 }
1133
1134 Ok((0..n).map(|i| a[i][i]).collect())
1136 }
1137
1138 pub fn eigenvector_numeric(&self, eigenvalue: f64) -> MatrixResult<Vec<f64>> {
1146 if !self.is_square() {
1147 return Err(MatrixError::InvalidOperation(
1148 "Eigenvector requires a square matrix".to_string(),
1149 ));
1150 }
1151
1152 let empty = std::collections::HashMap::new();
1153 let elements = self.evaluate(&empty).ok_or_else(|| {
1154 MatrixError::InvalidOperation("Cannot evaluate matrix numerically".to_string())
1155 })?;
1156
1157 let n = self.rows;
1158
1159 let mut a_minus_lambda: Vec<Vec<f64>> = elements.clone();
1161 for i in 0..n {
1162 a_minus_lambda[i][i] -= eigenvalue;
1163 }
1164
1165 let mut v: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
1168
1169 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
1171 for x in &mut v {
1172 *x /= norm;
1173 }
1174
1175 const MAX_ITER: usize = 50;
1178 const TOL: f64 = 1e-8;
1179
1180 for _ in 0..MAX_ITER {
1181 let mut augmented = a_minus_lambda.clone();
1183 for i in 0..n {
1184 augmented[i][i] += 1e-10; }
1186
1187 let w = solve_linear_system(&augmented, &v);
1189
1190 let norm: f64 = w.iter().map(|x| x * x).sum::<f64>().sqrt();
1192 if norm < 1e-14 {
1193 break;
1194 }
1195
1196 let w_normalized: Vec<f64> = w.iter().map(|x| x / norm).collect();
1197
1198 let diff: f64 = v
1200 .iter()
1201 .zip(w_normalized.iter())
1202 .map(|(a, b)| (a - b).abs())
1203 .sum();
1204
1205 v = w_normalized;
1206
1207 if diff < TOL {
1208 break;
1209 }
1210 }
1211
1212 Ok(v)
1213 }
1214
1215 pub fn eigenpairs_numeric(&self) -> MatrixResult<Vec<(f64, Vec<f64>)>> {
1221 let eigenvalues = self.eigenvalues_numeric()?;
1222 let mut pairs = Vec::with_capacity(eigenvalues.len());
1223
1224 for eigenvalue in eigenvalues {
1225 let eigenvector = self.eigenvector_numeric(eigenvalue)?;
1226 pairs.push((eigenvalue, eigenvector));
1227 }
1228
1229 Ok(pairs)
1230 }
1231
1232 pub fn is_diagonalizable(&self) -> MatrixResult<bool> {
1236 if !self.is_square() {
1237 return Err(MatrixError::InvalidOperation(
1238 "Diagonalizability check requires a square matrix".to_string(),
1239 ));
1240 }
1241
1242 let transpose = self.transpose();
1244 let empty = std::collections::HashMap::new();
1245
1246 if let (Some(a), Some(at)) = (self.evaluate(&empty), transpose.evaluate(&empty)) {
1247 let is_symmetric = a.iter().zip(at.iter()).all(|(row_a, row_at)| {
1248 row_a
1249 .iter()
1250 .zip(row_at.iter())
1251 .all(|(x, y)| (x - y).abs() < 1e-10)
1252 });
1253
1254 if is_symmetric {
1255 return Ok(true);
1256 }
1257 }
1258
1259 let eigenvalues = self.eigenvalues_numeric()?;
1262
1263 for (i, &ev1) in eigenvalues.iter().enumerate() {
1265 for (j, &ev2) in eigenvalues.iter().enumerate() {
1266 if i != j && (ev1 - ev2).abs() < 1e-10 {
1267 return Ok(true);
1270 }
1271 }
1272 }
1273
1274 Ok(true)
1275 }
1276
1277 pub fn to_latex(&self, style: BracketStyle) -> String {
1294 let env = match style {
1295 BracketStyle::Parentheses => "pmatrix",
1296 BracketStyle::Square => "bmatrix",
1297 BracketStyle::Curly => "Bmatrix",
1298 BracketStyle::Determinant => "vmatrix",
1299 BracketStyle::Norm => "Vmatrix",
1300 BracketStyle::None => "matrix",
1301 };
1302
1303 let mut result = format!("\\begin{{{}}}\n", env);
1304 for (i, row) in self.elements.iter().enumerate() {
1305 let row_str: Vec<String> = row.iter().map(|e| e.to_latex()).collect();
1306 result.push_str(&row_str.join(" & "));
1307 if i < self.rows - 1 {
1308 result.push_str(" \\\\\n");
1309 } else {
1310 result.push('\n');
1311 }
1312 }
1313 result.push_str(&format!("\\end{{{}}}", env));
1314 result
1315 }
1316
1317 pub fn to_latex_default(&self) -> String {
1319 self.to_latex(BracketStyle::default())
1320 }
1321
1322 pub fn evaluate(&self, vars: &std::collections::HashMap<String, f64>) -> Option<Vec<Vec<f64>>> {
1326 self.elements
1327 .iter()
1328 .map(|row| {
1329 row.iter()
1330 .map(|elem| elem.evaluate(vars))
1331 .collect::<Option<Vec<f64>>>()
1332 })
1333 .collect()
1334 }
1335}
1336
1337impl fmt::Display for MatrixExpr {
1338 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1339 write!(f, "[")?;
1340 for (i, row) in self.elements.iter().enumerate() {
1341 if i > 0 {
1342 write!(f, "; ")?;
1343 }
1344 write!(f, "[")?;
1345 for (j, elem) in row.iter().enumerate() {
1346 if j > 0 {
1347 write!(f, ", ")?;
1348 }
1349 write!(f, "{}", elem)?;
1350 }
1351 write!(f, "]")?;
1352 }
1353 write!(f, "]")
1354 }
1355}
1356
1357fn solve_cubic(p: f64, q: f64, r: f64) -> MatrixResult<Vec<f64>> {
1363 let a = q - p * p / 3.0;
1368 let b = r - p * q / 3.0 + 2.0 * p * p * p / 27.0;
1369
1370 let discriminant = -4.0 * a * a * a - 27.0 * b * b;
1372
1373 let offset = -p / 3.0;
1374
1375 if discriminant > 0.0 {
1376 let theta = (-b / 2.0 / ((-a / 3.0).powi(3).sqrt())).acos();
1378 let r_cubed = (-a / 3.0).sqrt();
1379
1380 let t1 = 2.0 * r_cubed * (theta / 3.0).cos();
1381 let t2 = 2.0 * r_cubed * ((theta + 2.0 * std::f64::consts::PI) / 3.0).cos();
1382 let t3 = 2.0 * r_cubed * ((theta + 4.0 * std::f64::consts::PI) / 3.0).cos();
1383
1384 Ok(vec![t1 + offset, t2 + offset, t3 + offset])
1385 } else if discriminant.abs() < 1e-10 {
1386 if b.abs() < 1e-10 {
1388 Ok(vec![offset, offset, offset])
1390 } else {
1391 let double_root = 3.0 * b / a;
1393 let simple_root = -3.0 * b / (2.0 * a);
1394 Ok(vec![
1395 double_root + offset,
1396 simple_root + offset,
1397 simple_root + offset,
1398 ])
1399 }
1400 } else {
1401 let sqrt_disc = (b * b / 4.0 + a * a * a / 27.0).sqrt();
1403 let u = (-b / 2.0 + sqrt_disc).cbrt();
1404 let v = (-b / 2.0 - sqrt_disc).cbrt();
1405 let real_root = u + v + offset;
1406
1407 let complex_real = -(u + v) / 2.0 + offset;
1409 Ok(vec![real_root, complex_real, complex_real])
1410 }
1411}
1412
1413fn qr_decomposition(a: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
1415 let n = a.len();
1416 let mut q = vec![vec![0.0; n]; n];
1417 let mut r = vec![vec![0.0; n]; n];
1418
1419 for j in 0..n {
1420 let mut v: Vec<f64> = (0..n).map(|i| a[i][j]).collect();
1422
1423 for i in 0..j {
1425 let q_i: Vec<f64> = (0..n).map(|k| q[k][i]).collect();
1426 r[i][j] = dot_product(&q_i, &v);
1427 for k in 0..n {
1428 v[k] -= r[i][j] * q_i[k];
1429 }
1430 }
1431
1432 r[j][j] = v.iter().map(|x| x * x).sum::<f64>().sqrt();
1434 if r[j][j] > 1e-14 {
1435 for k in 0..n {
1436 q[k][j] = v[k] / r[j][j];
1437 }
1438 }
1439 }
1440
1441 (q, r)
1442}
1443
/// Sum of element-wise products; extra elements of the longer slice are ignored.
fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
1448
/// Product of two square `n`x`n` matrices, where `n = a.len()`.
///
/// Each entry accumulates `a[i][k] * b[k][j]` in increasing `k`, starting from `0.0`.
fn matrix_multiply(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| (0..n).fold(0.0, |acc, k| acc + a[i][k] * b[k][j]))
                .collect()
        })
        .collect()
}
1464
/// Solves `a * x = b` by Gaussian elimination with partial pivoting.
///
/// Pivots with magnitude `<= 1e-14` are skipped during elimination. The matching
/// unknowns are left at `0.0` during back substitution, so a singular system yields a
/// best-effort answer instead of an error.
fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    const PIVOT_EPS: f64 = 1e-14;
    let size = a.len();

    let mut m = a.to_vec();
    let mut y = b.to_vec();

    // Forward elimination.
    for col in 0..size {
        // Partial pivoting: the first row holding the largest |entry| in this column.
        let mut pivot = col;
        let mut best = m[col][col].abs();
        for row in (col + 1)..size {
            let candidate = m[row][col].abs();
            if candidate > best {
                best = candidate;
                pivot = row;
            }
        }

        if pivot != col {
            m.swap(col, pivot);
            y.swap(col, pivot);
        }

        let p = m[col][col];
        if p.abs() > PIVOT_EPS {
            for row in (col + 1)..size {
                let factor = m[row][col] / p;
                for c in col..size {
                    m[row][c] -= factor * m[col][c];
                }
                y[row] -= factor * y[col];
            }
        }
    }

    // Back substitution.
    let mut x = vec![0.0; size];
    for row in (0..size).rev() {
        let diag = m[row][row];
        if diag.abs() > PIVOT_EPS {
            let remainder = ((row + 1)..size).fold(y[row], |acc, c| acc - m[row][c] * x[c]);
            x[row] = remainder / diag;
        }
    }

    x
}
1517
1518#[cfg(test)]
1519mod tests {
1520 use super::*;
1521 use crate::ast::{Expression, Variable};
1522 use std::collections::HashMap;
1523
1524 fn int(n: i64) -> Expression {
1525 Expression::Integer(n)
1526 }
1527
1528 fn var(name: &str) -> Expression {
1529 Expression::Variable(Variable::new(name))
1530 }
1531
1532 #[test]
1533 fn test_matrix_creation() {
1534 let m =
1535 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1536
1537 assert_eq!(m.rows(), 2);
1538 assert_eq!(m.cols(), 2);
1539 assert!(m.is_square());
1540 }
1541
1542 #[test]
1543 fn test_identity_matrix() {
1544 let i3 = MatrixExpr::identity(3);
1545 assert_eq!(i3.rows(), 3);
1546 assert_eq!(i3.cols(), 3);
1547
1548 assert_eq!(i3.get(0, 0).unwrap(), &int(1));
1550 assert_eq!(i3.get(1, 1).unwrap(), &int(1));
1551 assert_eq!(i3.get(2, 2).unwrap(), &int(1));
1552
1553 assert_eq!(i3.get(0, 1).unwrap(), &int(0));
1555 assert_eq!(i3.get(1, 2).unwrap(), &int(0));
1556 }
1557
1558 #[test]
1559 fn test_zero_matrix() {
1560 let z = MatrixExpr::zero(2, 3);
1561 assert_eq!(z.rows(), 2);
1562 assert_eq!(z.cols(), 3);
1563
1564 for i in 0..2 {
1565 for j in 0..3 {
1566 assert_eq!(z.get(i, j).unwrap(), &int(0));
1567 }
1568 }
1569 }
1570
1571 #[test]
1572 fn test_diagonal_matrix() {
1573 let d = MatrixExpr::diagonal(vec![int(1), int(2), int(3)]);
1574 assert_eq!(d.rows(), 3);
1575 assert_eq!(d.cols(), 3);
1576
1577 assert_eq!(d.get(0, 0).unwrap(), &int(1));
1578 assert_eq!(d.get(1, 1).unwrap(), &int(2));
1579 assert_eq!(d.get(2, 2).unwrap(), &int(3));
1580 assert_eq!(d.get(0, 1).unwrap(), &int(0));
1581 }
1582
1583 #[test]
1584 fn test_transpose() {
1585 let m = MatrixExpr::from_elements(vec![
1586 vec![int(1), int(2), int(3)],
1587 vec![int(4), int(5), int(6)],
1588 ])
1589 .unwrap();
1590
1591 let mt = m.transpose();
1592 assert_eq!(mt.rows(), 3);
1593 assert_eq!(mt.cols(), 2);
1594
1595 assert_eq!(mt.get(0, 0).unwrap(), &int(1));
1596 assert_eq!(mt.get(0, 1).unwrap(), &int(4));
1597 assert_eq!(mt.get(1, 0).unwrap(), &int(2));
1598 assert_eq!(mt.get(2, 1).unwrap(), &int(6));
1599 }
1600
1601 #[test]
1602 fn test_double_transpose() {
1603 let m =
1604 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1605
1606 let mtt = m.transpose().transpose();
1607 assert_eq!(mtt.elements, m.elements);
1608 }
1609
1610 #[test]
1611 fn test_trace() {
1612 let m =
1613 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1614
1615 let trace = m.trace().unwrap();
1616 let vars = HashMap::new();
1617 assert_eq!(trace.evaluate(&vars), Some(5.0));
1618 }
1619
1620 #[test]
1621 fn test_addition() {
1622 let a =
1623 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1624
1625 let b =
1626 MatrixExpr::from_elements(vec![vec![int(5), int(6)], vec![int(7), int(8)]]).unwrap();
1627
1628 let sum = a.add(&b).unwrap();
1629 let vars = HashMap::new();
1630
1631 assert_eq!(sum.get(0, 0).unwrap().evaluate(&vars), Some(6.0));
1632 assert_eq!(sum.get(0, 1).unwrap().evaluate(&vars), Some(8.0));
1633 assert_eq!(sum.get(1, 0).unwrap().evaluate(&vars), Some(10.0));
1634 assert_eq!(sum.get(1, 1).unwrap().evaluate(&vars), Some(12.0));
1635 }
1636
1637 #[test]
1638 fn test_addition_dimension_check() {
1639 let a = MatrixExpr::from_elements(vec![vec![int(1), int(2)]]).unwrap();
1640
1641 let b = MatrixExpr::from_elements(vec![vec![int(1)], vec![int(2)]]).unwrap();
1642
1643 let result = a.add(&b);
1644 assert!(result.is_err());
1645 }
1646
1647 #[test]
1648 fn test_matrix_multiplication() {
1649 let a = MatrixExpr::from_elements(vec![
1651 vec![int(1), int(2), int(3)],
1652 vec![int(4), int(5), int(6)],
1653 ])
1654 .unwrap();
1655
1656 let b = MatrixExpr::from_elements(vec![
1657 vec![int(7), int(8)],
1658 vec![int(9), int(10)],
1659 vec![int(11), int(12)],
1660 ])
1661 .unwrap();
1662
1663 let c = a.mul(&b).unwrap();
1664 assert_eq!(c.rows(), 2);
1665 assert_eq!(c.cols(), 2);
1666
1667 let vars = HashMap::new();
1668 assert_eq!(c.get(0, 0).unwrap().evaluate(&vars), Some(58.0));
1670 assert_eq!(c.get(0, 1).unwrap().evaluate(&vars), Some(64.0));
1672 assert_eq!(c.get(1, 0).unwrap().evaluate(&vars), Some(139.0));
1674 assert_eq!(c.get(1, 1).unwrap().evaluate(&vars), Some(154.0));
1676 }
1677
1678 #[test]
1679 fn test_scalar_multiplication() {
1680 let m = MatrixExpr::identity(2);
1681 let scaled = m.scalar_mul(&int(3));
1682
1683 let vars = HashMap::new();
1684 assert_eq!(scaled.get(0, 0).unwrap().evaluate(&vars), Some(3.0));
1685 assert_eq!(scaled.get(1, 1).unwrap().evaluate(&vars), Some(3.0));
1686 assert_eq!(scaled.get(0, 1).unwrap().evaluate(&vars), Some(0.0));
1687 }
1688
1689 #[test]
1690 fn test_symbolic_matrix() {
1691 let m = MatrixExpr::from_elements(vec![vec![var("a"), var("b")], vec![var("c"), var("d")]])
1692 .unwrap();
1693
1694 let mut vars = HashMap::new();
1695 vars.insert("a".to_string(), 1.0);
1696 vars.insert("b".to_string(), 2.0);
1697 vars.insert("c".to_string(), 3.0);
1698 vars.insert("d".to_string(), 4.0);
1699
1700 let result = m.evaluate(&vars).unwrap();
1701 assert_eq!(result[0][0], 1.0);
1702 assert_eq!(result[0][1], 2.0);
1703 assert_eq!(result[1][0], 3.0);
1704 assert_eq!(result[1][1], 4.0);
1705 }
1706
1707 #[test]
1708 fn test_latex_output() {
1709 let m =
1710 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1711
1712 let latex = m.to_latex(BracketStyle::Parentheses);
1713 assert!(latex.contains("\\begin{pmatrix}"));
1714 assert!(latex.contains("\\end{pmatrix}"));
1715 assert!(latex.contains("1 & 2"));
1716 assert!(latex.contains("3 & 4"));
1717 }
1718
1719 #[test]
1720 fn test_transpose_multiplication_property() {
1721 let a =
1723 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1724
1725 let b =
1726 MatrixExpr::from_elements(vec![vec![int(5), int(6)], vec![int(7), int(8)]]).unwrap();
1727
1728 let ab = a.mul(&b).unwrap();
1729 let ab_t = ab.transpose();
1730
1731 let bt_at = b.transpose().mul(&a.transpose()).unwrap();
1732
1733 let vars = HashMap::new();
1734 for i in 0..2 {
1735 for j in 0..2 {
1736 assert_eq!(
1737 ab_t.get(i, j).unwrap().evaluate(&vars),
1738 bt_at.get(i, j).unwrap().evaluate(&vars)
1739 );
1740 }
1741 }
1742 }
1743
1744 #[test]
1745 fn test_determinant_2x2() {
1746 let m =
1748 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1749
1750 let det = m.determinant().unwrap();
1751 let vars = HashMap::new();
1752 assert_eq!(det.evaluate(&vars), Some(-2.0));
1753 }
1754
1755 #[test]
1756 fn test_determinant_3x3() {
1757 let m = MatrixExpr::from_elements(vec![
1759 vec![int(1), int(2), int(3)],
1760 vec![int(4), int(5), int(6)],
1761 vec![int(7), int(8), int(9)],
1762 ])
1763 .unwrap();
1764
1765 let det = m.determinant().unwrap();
1766 let vars = HashMap::new();
1767 assert_eq!(det.evaluate(&vars), Some(0.0));
1768 }
1769
1770 #[test]
1771 fn test_determinant_3x3_nonzero() {
1772 let m = MatrixExpr::from_elements(vec![
1774 vec![int(1), int(2), int(3)],
1775 vec![int(0), int(1), int(4)],
1776 vec![int(5), int(6), int(0)],
1777 ])
1778 .unwrap();
1779
1780 let det = m.determinant().unwrap();
1781 let vars = HashMap::new();
1782 assert_eq!(det.evaluate(&vars), Some(1.0));
1783 }
1784
1785 #[test]
1786 fn test_determinant_identity() {
1787 let i3 = MatrixExpr::identity(3);
1789 let det = i3.determinant().unwrap();
1790 let vars = HashMap::new();
1791 assert_eq!(det.evaluate(&vars), Some(1.0));
1792 }
1793
1794 #[test]
1795 fn test_determinant_non_square() {
1796 let m = MatrixExpr::from_elements(vec![
1797 vec![int(1), int(2), int(3)],
1798 vec![int(4), int(5), int(6)],
1799 ])
1800 .unwrap();
1801
1802 let result = m.determinant();
1803 assert!(result.is_err());
1804 }
1805
1806 #[test]
1807 fn test_inverse_2x2() {
1808 let m =
1811 MatrixExpr::from_elements(vec![vec![int(4), int(7)], vec![int(2), int(6)]]).unwrap();
1812
1813 let inv = m.inverse().unwrap();
1814 let vars = HashMap::new();
1815
1816 let product = m.mul(&inv).unwrap();
1818 let result = product.evaluate(&vars).unwrap();
1819
1820 assert!((result[0][0] - 1.0).abs() < 1e-10);
1821 assert!((result[0][1] - 0.0).abs() < 1e-10);
1822 assert!((result[1][0] - 0.0).abs() < 1e-10);
1823 assert!((result[1][1] - 1.0).abs() < 1e-10);
1824 }
1825
1826 #[test]
1827 fn test_inverse_3x3() {
1828 let m = MatrixExpr::from_elements(vec![
1830 vec![int(1), int(2), int(3)],
1831 vec![int(0), int(1), int(4)],
1832 vec![int(5), int(6), int(0)],
1833 ])
1834 .unwrap();
1835
1836 let inv = m.inverse().unwrap();
1837 let vars = HashMap::new();
1838
1839 let product = m.mul(&inv).unwrap();
1841 let result = product.evaluate(&vars).unwrap();
1842
1843 for i in 0..3 {
1844 for j in 0..3 {
1845 let expected = if i == j { 1.0 } else { 0.0 };
1846 assert!(
1847 (result[i][j] - expected).abs() < 1e-10,
1848 "Expected {} at ({}, {}), got {}",
1849 expected,
1850 i,
1851 j,
1852 result[i][j]
1853 );
1854 }
1855 }
1856 }
1857
1858 #[test]
1859 fn test_inverse_singular_matrix() {
1860 let m =
1862 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(2), int(4)]]).unwrap();
1863
1864 let result = m.inverse();
1865 assert!(result.is_err());
1866 }
1867
1868 #[test]
1869 fn test_determinant_symbolic() {
1870 let m = MatrixExpr::from_elements(vec![vec![var("a"), var("b")], vec![var("c"), var("d")]])
1872 .unwrap();
1873
1874 let det = m.determinant().unwrap();
1875
1876 let mut vars = HashMap::new();
1877 vars.insert("a".to_string(), 2.0);
1878 vars.insert("b".to_string(), 3.0);
1879 vars.insert("c".to_string(), 4.0);
1880 vars.insert("d".to_string(), 5.0);
1881
1882 assert_eq!(det.evaluate(&vars), Some(-2.0));
1884 }
1885
1886 #[test]
1887 fn test_submatrix() {
1888 let m = MatrixExpr::from_elements(vec![
1889 vec![int(1), int(2), int(3)],
1890 vec![int(4), int(5), int(6)],
1891 vec![int(7), int(8), int(9)],
1892 ])
1893 .unwrap();
1894
1895 let sub = m.submatrix(1, 1).unwrap();
1897 let vars = HashMap::new();
1898
1899 assert_eq!(sub.rows(), 2);
1900 assert_eq!(sub.cols(), 2);
1901 assert_eq!(sub.get(0, 0).unwrap().evaluate(&vars), Some(1.0));
1902 assert_eq!(sub.get(0, 1).unwrap().evaluate(&vars), Some(3.0));
1903 assert_eq!(sub.get(1, 0).unwrap().evaluate(&vars), Some(7.0));
1904 assert_eq!(sub.get(1, 1).unwrap().evaluate(&vars), Some(9.0));
1905 }
1906
1907 #[test]
1908 fn test_adjugate_2x2() {
1909 let m =
1911 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1912
1913 let adj = m.adjugate().unwrap();
1914 let vars = HashMap::new();
1915
1916 assert_eq!(adj.get(0, 0).unwrap().evaluate(&vars), Some(4.0));
1917 assert_eq!(adj.get(0, 1).unwrap().evaluate(&vars), Some(-2.0));
1918 assert_eq!(adj.get(1, 0).unwrap().evaluate(&vars), Some(-3.0));
1919 assert_eq!(adj.get(1, 1).unwrap().evaluate(&vars), Some(1.0));
1920 }
1921
1922 #[test]
1923 fn test_is_singular() {
1924 let singular =
1925 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(2), int(4)]]).unwrap();
1926
1927 let non_singular =
1928 MatrixExpr::from_elements(vec![vec![int(1), int(2)], vec![int(3), int(4)]]).unwrap();
1929
1930 let vars = HashMap::new();
1931 assert_eq!(singular.is_singular(&vars), Some(true));
1932 assert_eq!(non_singular.is_singular(&vars), Some(false));
1933 }
1934
1935 #[test]
1936 fn test_inverse_identity() {
1937 let i3 = MatrixExpr::identity(3);
1939 let inv = i3.inverse().unwrap();
1940 let vars = HashMap::new();
1941
1942 for i in 0..3 {
1943 for j in 0..3 {
1944 let expected = if i == j { 1.0 } else { 0.0 };
1945 assert_eq!(inv.get(i, j).unwrap().evaluate(&vars), Some(expected));
1946 }
1947 }
1948 }
1949
1950 #[test]
1955 fn test_characteristic_polynomial_2x2() {
1956 let m =
1959 MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
1960
1961 let char_poly = m.characteristic_polynomial("lambda").unwrap();
1962
1963 let mut vars = HashMap::new();
1965 vars.insert("lambda".to_string(), 1.0);
1966 let at_1 = char_poly.evaluate(&vars).unwrap();
1967 assert!(
1968 at_1.abs() < 1e-10,
1969 "char poly at λ=1 should be 0, got {}",
1970 at_1
1971 );
1972
1973 vars.insert("lambda".to_string(), 3.0);
1975 let at_3 = char_poly.evaluate(&vars).unwrap();
1976 assert!(
1977 at_3.abs() < 1e-10,
1978 "char poly at λ=3 should be 0, got {}",
1979 at_3
1980 );
1981 }
1982
1983 #[test]
1984 fn test_eigenvalues_2x2_symmetric() {
1985 let m =
1987 MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
1988
1989 let eigenvalues = m.eigenvalues_numeric().unwrap();
1990 assert_eq!(eigenvalues.len(), 2);
1991
1992 let mut sorted = eigenvalues.clone();
1994 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1995
1996 assert!(
1997 (sorted[0] - 1.0).abs() < 1e-10,
1998 "Expected 1, got {}",
1999 sorted[0]
2000 );
2001 assert!(
2002 (sorted[1] - 3.0).abs() < 1e-10,
2003 "Expected 3, got {}",
2004 sorted[1]
2005 );
2006 }
2007
2008 #[test]
2009 fn test_eigenvalues_diagonal() {
2010 let m =
2012 MatrixExpr::from_elements(vec![vec![int(5), int(0)], vec![int(0), int(3)]]).unwrap();
2013
2014 let eigenvalues = m.eigenvalues_numeric().unwrap();
2015 let mut sorted = eigenvalues.clone();
2016 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
2017
2018 assert!((sorted[0] - 3.0).abs() < 1e-10);
2019 assert!((sorted[1] - 5.0).abs() < 1e-10);
2020 }
2021
2022 #[test]
2023 fn test_eigenvalues_identity() {
2024 let m = MatrixExpr::identity(3);
2026
2027 let eigenvalues = m.eigenvalues_numeric().unwrap();
2028 assert_eq!(eigenvalues.len(), 3);
2029
2030 for ev in eigenvalues {
2031 assert!((ev - 1.0).abs() < 1e-10);
2032 }
2033 }
2034
2035 #[test]
2036 fn test_eigenvector_2x2() {
2037 let m =
2039 MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
2040
2041 let eigenvector = m.eigenvector_numeric(3.0).unwrap();
2042 assert_eq!(eigenvector.len(), 2);
2043
2044 let ratio = eigenvector[0] / eigenvector[1];
2047 assert!(
2048 (ratio - 1.0).abs() < 1e-5,
2049 "Expected ratio 1, got {}",
2050 ratio
2051 );
2052 }
2053
2054 #[test]
2055 fn test_eigenpairs() {
2056 let m =
2057 MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
2058
2059 let pairs = m.eigenpairs_numeric().unwrap();
2060 assert_eq!(pairs.len(), 2);
2061
2062 for (eigenvalue, eigenvector) in pairs {
2063 let empty = HashMap::new();
2065 let a = m.evaluate(&empty).unwrap();
2066
2067 let av: Vec<f64> = (0..2)
2069 .map(|i| {
2070 a[i].iter()
2071 .zip(eigenvector.iter())
2072 .map(|(a, v)| a * v)
2073 .sum()
2074 })
2075 .collect();
2076
2077 let lambda_v: Vec<f64> = eigenvector.iter().map(|v| eigenvalue * v).collect();
2079
2080 for i in 0..2 {
2082 assert!(
2083 (av[i] - lambda_v[i]).abs() < 1e-5,
2084 "Av[{}] = {}, λv[{}] = {}, eigenvalue = {}",
2085 i,
2086 av[i],
2087 i,
2088 lambda_v[i],
2089 eigenvalue
2090 );
2091 }
2092 }
2093 }
2094
2095 #[test]
2096 fn test_eigenvalues_3x3() {
2097 let m = MatrixExpr::from_elements(vec![
2100 vec![int(1), int(0), int(0)],
2101 vec![int(0), int(2), int(0)],
2102 vec![int(0), int(0), int(3)],
2103 ])
2104 .unwrap();
2105
2106 let eigenvalues = m.eigenvalues_numeric().unwrap();
2107 let mut sorted = eigenvalues.clone();
2108 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
2109
2110 assert!((sorted[0] - 1.0).abs() < 1e-10);
2111 assert!((sorted[1] - 2.0).abs() < 1e-10);
2112 assert!((sorted[2] - 3.0).abs() < 1e-10);
2113 }
2114
2115 #[test]
2116 fn test_is_diagonalizable_symmetric() {
2117 let m =
2119 MatrixExpr::from_elements(vec![vec![int(2), int(1)], vec![int(1), int(2)]]).unwrap();
2120
2121 assert!(m.is_diagonalizable().unwrap());
2122 }
2123
2124 #[test]
2125 fn test_is_diagonalizable_identity() {
2126 let m = MatrixExpr::identity(3);
2127 assert!(m.is_diagonalizable().unwrap());
2128 }
2129
2130 #[test]
2131 fn test_eigenvalues_non_square() {
2132 let m = MatrixExpr::from_elements(vec![
2133 vec![int(1), int(2), int(3)],
2134 vec![int(4), int(5), int(6)],
2135 ])
2136 .unwrap();
2137
2138 let result = m.eigenvalues_numeric();
2139 assert!(result.is_err());
2140 }
2141
2142 #[test]
2143 fn test_characteristic_polynomial_non_square() {
2144 let m = MatrixExpr::from_elements(vec![
2145 vec![int(1), int(2), int(3)],
2146 vec![int(4), int(5), int(6)],
2147 ])
2148 .unwrap();
2149
2150 let result = m.characteristic_polynomial("lambda");
2151 assert!(result.is_err());
2152 }
2153}