/// Failure modes reported by the fallible `Matrix` operations in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum MatrixError {
    /// Two shapes that had to agree did not; both fields are `(rows, cols)`.
    DimensionMismatch {
        /// The shape the operation was checking against.
        expected: (usize, usize),
        /// The shape that was actually supplied.
        got: (usize, usize),
    },
    /// The operation requires a square matrix but received `rows × cols`.
    NotSquare {
        /// Row count of the offending matrix.
        rows: usize,
        /// Column count of the offending matrix.
        cols: usize,
    },
    /// The matrix is (numerically) singular.
    Singular,
    /// The matrix failed a symmetry check.
    NotSymmetric,
    /// A non-positive pivot appeared during Cholesky factorisation.
    NotPositiveDefinite,
    /// A flat data buffer did not contain exactly `rows * cols` elements.
    InvalidData {
        /// Number of elements required.
        expected: usize,
        /// Number of elements supplied.
        got: usize,
    },
}
48
49impl std::fmt::Display for MatrixError {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 MatrixError::DimensionMismatch { expected, got } => {
53 write!(
54 f,
55 "dimension mismatch: expected {}×{}, got {}×{}",
56 expected.0, expected.1, got.0, got.1
57 )
58 }
59 MatrixError::NotSquare { rows, cols } => {
60 write!(f, "matrix must be square, got {rows}×{cols}")
61 }
62 MatrixError::Singular => write!(f, "matrix is singular"),
63 MatrixError::NotSymmetric => write!(f, "matrix is not symmetric"),
64 MatrixError::NotPositiveDefinite => write!(f, "matrix is not positive-definite"),
65 MatrixError::InvalidData { expected, got } => {
66 write!(f, "data length mismatch: expected {expected}, got {got}")
67 }
68 }
69 }
70}
71
72impl std::error::Error for MatrixError {}
73
/// A dense matrix of `f64` values stored contiguously in row-major order.
///
/// Element `(i, j)` lives at `data[i * cols + j]`; the constructors in this
/// module build `data` with exactly `rows * cols` elements.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    /// Row-major element storage.
    data: Vec<f64>,
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
85
86impl Matrix {
87 pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Result<Self, MatrixError> {
100 if data.len() != rows * cols {
101 return Err(MatrixError::InvalidData {
102 expected: rows * cols,
103 got: data.len(),
104 });
105 }
106 Ok(Self { data, rows, cols })
107 }
108
109 pub fn from_rows(rows: &[&[f64]]) -> Self {
121 assert!(!rows.is_empty(), "must have at least one row");
122 let ncols = rows[0].len();
123 assert!(ncols > 0, "must have at least one column");
124 let nrows = rows.len();
125 let mut data = Vec::with_capacity(nrows * ncols);
126 for (i, row) in rows.iter().enumerate() {
127 assert_eq!(
128 row.len(),
129 ncols,
130 "row {i} has {} columns, expected {ncols}",
131 row.len()
132 );
133 data.extend_from_slice(row);
134 }
135 Self {
136 data,
137 rows: nrows,
138 cols: ncols,
139 }
140 }
141
142 pub fn zeros(rows: usize, cols: usize) -> Self {
144 Self {
145 data: vec![0.0; rows * cols],
146 rows,
147 cols,
148 }
149 }
150
151 pub fn identity(n: usize) -> Self {
162 let mut m = Self::zeros(n, n);
163 for i in 0..n {
164 m.data[i * n + i] = 1.0;
165 }
166 m
167 }
168
169 pub fn from_col(data: &[f64]) -> Self {
171 Self {
172 data: data.to_vec(),
173 rows: data.len(),
174 cols: 1,
175 }
176 }
177
178 #[inline]
180 pub fn rows(&self) -> usize {
181 self.rows
182 }
183
184 #[inline]
186 pub fn cols(&self) -> usize {
187 self.cols
188 }
189
190 #[inline]
195 pub fn get(&self, row: usize, col: usize) -> f64 {
196 self.data[row * self.cols + col]
197 }
198
199 #[inline]
204 pub fn set(&mut self, row: usize, col: usize, value: f64) {
205 self.data[row * self.cols + col] = value;
206 }
207
208 pub fn data(&self) -> &[f64] {
210 &self.data
211 }
212
213 #[inline]
215 pub fn row(&self, row: usize) -> &[f64] {
216 let start = row * self.cols;
217 &self.data[start..start + self.cols]
218 }
219
220 pub fn diag(&self) -> Vec<f64> {
222 let n = self.rows.min(self.cols);
223 (0..n).map(|i| self.get(i, i)).collect()
224 }
225
226 pub fn is_square(&self) -> bool {
228 self.rows == self.cols
229 }
230
231 pub fn transpose(&self) -> Self {
247 let mut result = Self::zeros(self.cols, self.rows);
248 for i in 0..self.rows {
249 for j in 0..self.cols {
250 result.data[j * self.rows + i] = self.data[i * self.cols + j];
251 }
252 }
253 result
254 }
255
256 pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
261 if self.rows != other.rows || self.cols != other.cols {
262 return Err(MatrixError::DimensionMismatch {
263 expected: (self.rows, self.cols),
264 got: (other.rows, other.cols),
265 });
266 }
267 let data: Vec<f64> = self
268 .data
269 .iter()
270 .zip(&other.data)
271 .map(|(a, b)| a + b)
272 .collect();
273 Ok(Self {
274 data,
275 rows: self.rows,
276 cols: self.cols,
277 })
278 }
279
280 pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
285 if self.rows != other.rows || self.cols != other.cols {
286 return Err(MatrixError::DimensionMismatch {
287 expected: (self.rows, self.cols),
288 got: (other.rows, other.cols),
289 });
290 }
291 let data: Vec<f64> = self
292 .data
293 .iter()
294 .zip(&other.data)
295 .map(|(a, b)| a - b)
296 .collect();
297 Ok(Self {
298 data,
299 rows: self.rows,
300 cols: self.cols,
301 })
302 }
303
304 pub fn scale(&self, c: f64) -> Self {
306 let data: Vec<f64> = self.data.iter().map(|x| c * x).collect();
307 Self {
308 data,
309 rows: self.rows,
310 cols: self.cols,
311 }
312 }
313
314 pub fn mul_mat(&self, other: &Self) -> Result<Self, MatrixError> {
333 if self.cols != other.rows {
334 return Err(MatrixError::DimensionMismatch {
335 expected: (self.rows, self.cols),
336 got: (other.rows, other.cols),
337 });
338 }
339 let mut result = Self::zeros(self.rows, other.cols);
340 for i in 0..self.rows {
342 for k in 0..self.cols {
343 let a_ik = self.data[i * self.cols + k];
344 let row_start = i * other.cols;
345 let other_row_start = k * other.cols;
346 for j in 0..other.cols {
347 result.data[row_start + j] += a_ik * other.data[other_row_start + j];
348 }
349 }
350 }
351 Ok(result)
352 }
353
354 pub fn mul_vec(&self, v: &[f64]) -> Result<Vec<f64>, MatrixError> {
359 if self.cols != v.len() {
360 return Err(MatrixError::DimensionMismatch {
361 expected: (self.rows, self.cols),
362 got: (v.len(), 1),
363 });
364 }
365 let mut result = vec![0.0; self.rows];
366 for (i, res) in result.iter_mut().enumerate() {
367 let row_start = i * self.cols;
368 *res = self.data[row_start..row_start + self.cols]
369 .iter()
370 .zip(v.iter())
371 .map(|(&a, &b)| a * b)
372 .sum();
373 }
374 Ok(result)
375 }
376
377 pub fn frobenius_norm(&self) -> f64 {
379 self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
380 }
381
382 pub fn is_symmetric(&self, tol: f64) -> bool {
384 if self.rows != self.cols {
385 return false;
386 }
387 for i in 0..self.rows {
388 for j in (i + 1)..self.cols {
389 if (self.get(i, j) - self.get(j, i)).abs() > tol {
390 return false;
391 }
392 }
393 }
394 true
395 }
396
397 fn swap_rows(&mut self, a: usize, b: usize) {
398 if a == b {
399 return;
400 }
401 let cols = self.cols;
402 for j in 0..cols {
403 self.data.swap(a * cols + j, b * cols + j);
404 }
405 }
406
407 pub fn determinant(&self) -> Result<f64, MatrixError> {
427 if !self.is_square() {
428 return Err(MatrixError::NotSquare {
429 rows: self.rows,
430 cols: self.cols,
431 });
432 }
433 let n = self.rows;
434 if n == 0 {
435 return Ok(1.0);
436 }
437 if n == 1 {
438 return Ok(self.data[0]);
439 }
440
441 let mut work = self.clone();
442 let mut sign = 1.0_f64;
443 let pivot_tol = 1e-15 * self.frobenius_norm().max(1e-300);
444
445 for k in 0..n {
446 let mut max_val = work.get(k, k).abs();
448 let mut max_row = k;
449 for i in (k + 1)..n {
450 let v = work.get(i, k).abs();
451 if v > max_val {
452 max_val = v;
453 max_row = i;
454 }
455 }
456 if max_val <= pivot_tol {
457 return Ok(0.0); }
459 if max_row != k {
460 work.swap_rows(k, max_row);
461 sign = -sign;
462 }
463
464 let pivot = work.get(k, k);
465 for i in (k + 1)..n {
466 let factor = work.get(i, k) / pivot;
467 for j in (k + 1)..n {
468 let val = work.get(i, j) - factor * work.get(k, j);
469 work.set(i, j, val);
470 }
471 }
472 }
473
474 let mut det = sign;
475 for i in 0..n {
476 det *= work.get(i, i);
477 }
478 Ok(det)
479 }
480
481 pub fn inverse(&self) -> Result<Self, MatrixError> {
501 if !self.is_square() {
502 return Err(MatrixError::NotSquare {
503 rows: self.rows,
504 cols: self.cols,
505 });
506 }
507 let n = self.rows;
508 if n == 0 {
509 return Ok(Self::zeros(0, 0));
510 }
511
512 let n2 = 2 * n;
514 let mut aug = Self::zeros(n, n2);
515 for i in 0..n {
516 for j in 0..n {
517 aug.set(i, j, self.get(i, j));
518 }
519 aug.set(i, n + i, 1.0);
520 }
521
522 let pivot_tol = 1e-14 * self.frobenius_norm().max(1e-300);
523
524 for k in 0..n {
525 let mut max_val = aug.get(k, k).abs();
527 let mut max_row = k;
528 for i in (k + 1)..n {
529 let v = aug.get(i, k).abs();
530 if v > max_val {
531 max_val = v;
532 max_row = i;
533 }
534 }
535 if max_val <= pivot_tol {
536 return Err(MatrixError::Singular);
537 }
538 if max_row != k {
539 aug.swap_rows(k, max_row);
540 }
541
542 let pivot = aug.get(k, k);
544 for j in 0..n2 {
545 aug.set(k, j, aug.get(k, j) / pivot);
546 }
547
548 for i in 0..n {
550 if i != k {
551 let factor = aug.get(i, k);
552 for j in 0..n2 {
553 let val = aug.get(i, j) - factor * aug.get(k, j);
554 aug.set(i, j, val);
555 }
556 }
557 }
558 }
559
560 let mut inv = Self::zeros(n, n);
562 for i in 0..n {
563 for j in 0..n {
564 inv.set(i, j, aug.get(i, n + j));
565 }
566 }
567 Ok(inv)
568 }
569
570 pub fn cholesky(&self) -> Result<Self, MatrixError> {
596 if !self.is_square() {
597 return Err(MatrixError::NotSquare {
598 rows: self.rows,
599 cols: self.cols,
600 });
601 }
602 let n = self.rows;
603 let sym_tol = 1e-10 * self.frobenius_norm().max(1e-300);
604 if !self.is_symmetric(sym_tol) {
605 return Err(MatrixError::NotSymmetric);
606 }
607
608 let mut l = Self::zeros(n, n);
609
610 for j in 0..n {
611 let mut sum = 0.0;
613 for k in 0..j {
614 let ljk = l.get(j, k);
615 sum += ljk * ljk;
616 }
617 let diag = self.get(j, j) - sum;
618 if diag <= 0.0 {
619 return Err(MatrixError::NotPositiveDefinite);
620 }
621 l.set(j, j, diag.sqrt());
622
623 let ljj = l.get(j, j);
625 for i in (j + 1)..n {
626 let mut sum = 0.0;
627 for k in 0..j {
628 sum += l.get(i, k) * l.get(j, k);
629 }
630 l.set(i, j, (self.get(i, j) - sum) / ljj);
631 }
632 }
633
634 Ok(l)
635 }
636
637 pub fn cholesky_solve(&self, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
649 if b.len() != self.rows {
650 return Err(MatrixError::DimensionMismatch {
651 expected: (self.rows, 1),
652 got: (b.len(), 1),
653 });
654 }
655 let l = self.cholesky()?;
656 let y = solve_lower_triangular(&l, b)?;
657 let lt = l.transpose();
658 solve_upper_triangular(<, &y)
659 }
660
661 pub fn eigen_symmetric(&self) -> Result<(Vec<f64>, Matrix), MatrixError> {
704 let n = self.rows;
705 if !self.is_square() {
706 return Err(MatrixError::NotSquare {
707 rows: self.rows,
708 cols: self.cols,
709 });
710 }
711 let sym_tol = 1e-10 * self.frobenius_norm();
713 if !self.is_symmetric(sym_tol) {
714 return Err(MatrixError::NotSymmetric);
715 }
716
717 let mut a = self.data.clone();
719 let mut v = vec![0.0; n * n];
721 for i in 0..n {
722 v[i * n + i] = 1.0;
723 }
724
725 let max_sweeps = 100;
726 let tol = 1e-15;
727
728 for _ in 0..max_sweeps {
729 let mut off_norm = 0.0;
731 for i in 0..n {
732 for j in (i + 1)..n {
733 off_norm += 2.0 * a[i * n + j] * a[i * n + j];
734 }
735 }
736 off_norm = off_norm.sqrt();
737
738 if off_norm < tol {
739 break;
740 }
741
742 for p in 0..n {
744 for q in (p + 1)..n {
745 let apq = a[p * n + q];
746 if apq.abs() < tol * 0.01 {
747 continue;
748 }
749
750 let app = a[p * n + p];
751 let aqq = a[q * n + q];
752 let diff = aqq - app;
753
754 let (cos, sin) = if diff.abs() < 1e-300 {
756 let s = std::f64::consts::FRAC_1_SQRT_2;
758 (s, if apq > 0.0 { s } else { -s })
759 } else {
760 let tau = diff / (2.0 * apq);
761 let t = if tau >= 0.0 {
763 1.0 / (tau + (1.0 + tau * tau).sqrt())
764 } else {
765 -1.0 / (-tau + (1.0 + tau * tau).sqrt())
766 };
767 let c = 1.0 / (1.0 + t * t).sqrt();
768 let s = t * c;
769 (c, s)
770 };
771
772 a[p * n + p] -=
774 2.0 * sin * cos * apq + sin * sin * (a[q * n + q] - a[p * n + p]);
775 a[q * n + q] += 2.0 * sin * cos * apq + sin * sin * (aqq - app); a[p * n + q] = 0.0;
777 a[q * n + p] = 0.0;
778
779 a[p * n + p] = app;
783 a[q * n + q] = aqq;
784 a[p * n + q] = apq;
785 a[q * n + p] = apq;
786
787 for r in 0..n {
790 if r == p || r == q {
791 continue;
792 }
793 let arp = a[r * n + p];
794 let arq = a[r * n + q];
795 a[r * n + p] = cos * arp - sin * arq;
796 a[r * n + q] = sin * arp + cos * arq;
797 a[p * n + r] = a[r * n + p]; a[q * n + r] = a[r * n + q]; }
800
801 let new_pp = cos * cos * app - 2.0 * sin * cos * apq + sin * sin * aqq;
803 let new_qq = sin * sin * app + 2.0 * sin * cos * apq + cos * cos * aqq;
804 a[p * n + p] = new_pp;
805 a[q * n + q] = new_qq;
806 a[p * n + q] = 0.0;
807 a[q * n + p] = 0.0;
808
809 for r in 0..n {
811 let vp = v[r * n + p];
812 let vq = v[r * n + q];
813 v[r * n + p] = cos * vp - sin * vq;
814 v[r * n + q] = sin * vp + cos * vq;
815 }
816 }
817 }
818 }
819
820 let mut eigen_pairs: Vec<(f64, Vec<f64>)> = (0..n)
822 .map(|i| {
823 let eigenvalue = a[i * n + i];
824 let eigenvector: Vec<f64> = (0..n).map(|r| v[r * n + i]).collect();
825 (eigenvalue, eigenvector)
826 })
827 .collect();
828
829 eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
831
832 let eigenvalues: Vec<f64> = eigen_pairs.iter().map(|(val, _)| *val).collect();
833 let mut eigvec_data = vec![0.0; n * n];
834 for (col, (_, vec)) in eigen_pairs.iter().enumerate() {
835 for (row, &val) in vec.iter().enumerate() {
836 eigvec_data[row * n + col] = val;
837 }
838 }
839 let eigenvectors = Matrix {
840 data: eigvec_data,
841 rows: n,
842 cols: n,
843 };
844
845 Ok((eigenvalues, eigenvectors))
846 }
847}
848
849fn solve_lower_triangular(l: &Matrix, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
851 let n = l.rows();
852 let mut x = vec![0.0; n];
853 for i in 0..n {
854 let mut sum = 0.0;
855 for (j, &xj) in x[..i].iter().enumerate() {
856 sum += l.get(i, j) * xj;
857 }
858 let diag = l.get(i, i);
859 if diag.abs() < 1e-300 {
860 return Err(MatrixError::Singular);
861 }
862 x[i] = (b[i] - sum) / diag;
863 }
864 Ok(x)
865}
866
867fn solve_upper_triangular(u: &Matrix, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
869 let n = u.rows();
870 let mut x = vec![0.0; n];
871 for i in (0..n).rev() {
872 let mut sum = 0.0;
873 for (off, &xj) in x[i + 1..].iter().enumerate() {
874 sum += u.get(i, i + 1 + off) * xj;
875 }
876 let diag = u.get(i, i);
877 if diag.abs() < 1e-300 {
878 return Err(MatrixError::Singular);
879 }
880 x[i] = (b[i] - sum) / diag;
881 }
882 Ok(x)
883}
884
885impl std::fmt::Display for Matrix {
890 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
891 for i in 0..self.rows {
892 write!(f, "[")?;
893 for j in 0..self.cols {
894 if j > 0 {
895 write!(f, ", ")?;
896 }
897 write!(f, "{:>10.4}", self.get(i, j))?;
898 }
899 writeln!(f, "]")?;
900 }
901 Ok(())
902 }
903}
904
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that column `k` of `vecs` is an eigenvector of `a` for
    /// `vals[k]` (`‖A·v − λ·v‖∞ < tol`) and that the columns are orthonormal.
    fn assert_eigen_decomposition(a: &Matrix, vals: &[f64], vecs: &Matrix, tol: f64) {
        let n = a.rows();
        assert_eq!(vals.len(), n);
        assert_eq!((vecs.rows(), vecs.cols()), (n, n));
        for (k, &lambda) in vals.iter().enumerate() {
            let v: Vec<f64> = (0..n).map(|r| vecs.get(r, k)).collect();
            let av = a.mul_vec(&v).unwrap();
            for r in 0..n {
                assert!(
                    (av[r] - lambda * v[r]).abs() < tol,
                    "(Av)[{r}] = {}, λv[{r}] = {} for λ = {lambda}",
                    av[r],
                    lambda * v[r]
                );
            }
        }
        let vtv = vecs.transpose().mul_mat(vecs).unwrap();
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (vtv.get(i, j) - expected).abs() < tol,
                    "VᵀV[{i},{j}] = {}",
                    vtv.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_new_valid() {
        let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 6.0);
    }

    #[test]
    fn test_new_invalid_length() {
        assert!(Matrix::new(2, 3, vec![1.0, 2.0]).is_err());
    }

    #[test]
    fn test_from_rows() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 1), 4.0);
    }

    #[test]
    fn test_zeros() {
        let m = Matrix::zeros(3, 4);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 4);
        assert_eq!(m.get(2, 3), 0.0);
    }

    #[test]
    fn test_identity() {
        let eye = Matrix::identity(3);
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);
        assert_eq!(eye.get(0, 1), 0.0);
        assert_eq!(eye.get(1, 2), 0.0);
    }

    #[test]
    fn test_diag() {
        let m = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        assert_eq!(m.diag(), vec![1.0, 5.0, 9.0]);
    }

    #[test]
    fn test_transpose() {
        let m = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let t = m.transpose();
        assert_eq!(t.rows(), 3);
        assert_eq!(t.cols(), 2);
        assert_eq!(t.get(0, 0), 1.0);
        assert_eq!(t.get(2, 1), 6.0);
    }

    #[test]
    fn test_transpose_twice() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]);
        let tt = m.transpose().transpose();
        assert_eq!(m, tt);
    }

    #[test]
    fn test_add() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b = Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let c = a.add(&b).unwrap();
        assert_eq!(c.get(0, 0), 6.0);
        assert_eq!(c.get(1, 1), 12.0);
    }

    #[test]
    fn test_add_dimension_mismatch() {
        let a = Matrix::zeros(2, 3);
        let b = Matrix::zeros(3, 2);
        assert!(a.add(&b).is_err());
    }

    #[test]
    fn test_sub() {
        let a = Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let b = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let c = a.sub(&b).unwrap();
        assert_eq!(c.get(0, 0), 4.0);
        assert_eq!(c.get(1, 1), 4.0);
    }

    #[test]
    fn test_scale() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let s = m.scale(2.0);
        assert_eq!(s.get(0, 0), 2.0);
        assert_eq!(s.get(1, 1), 8.0);
    }

    #[test]
    fn test_mul_identity() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        let eye = Matrix::identity(3);
        let result = a.mul_mat(&eye).unwrap();
        assert_eq!(a, result);
    }

    #[test]
    fn test_mul_2x2() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b = Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let c = a.mul_mat(&b).unwrap();
        assert_eq!(c.get(0, 0), 19.0);
        assert_eq!(c.get(0, 1), 22.0);
        assert_eq!(c.get(1, 0), 43.0);
        assert_eq!(c.get(1, 1), 50.0);
    }

    #[test]
    fn test_mul_nonsquare() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let b = Matrix::from_rows(&[&[7.0, 8.0], &[9.0, 10.0], &[11.0, 12.0]]);
        let c = a.mul_mat(&b).unwrap();
        assert_eq!(c.rows(), 2);
        assert_eq!(c.cols(), 2);
        assert_eq!(c.get(0, 0), 58.0);
        assert_eq!(c.get(0, 1), 64.0);
    }

    #[test]
    fn test_mul_vec() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let v = vec![5.0, 6.0];
        let result = a.mul_vec(&v).unwrap();
        assert_eq!(result, vec![17.0, 39.0]);
    }

    #[test]
    fn test_mul_dimension_mismatch() {
        let a = Matrix::zeros(2, 3);
        let b = Matrix::zeros(2, 3);
        assert!(a.mul_mat(&b).is_err());
    }

    #[test]
    fn test_det_2x2() {
        let m = Matrix::from_rows(&[&[2.0, 3.0], &[1.0, 4.0]]);
        assert!((m.determinant().unwrap() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_det_3x3() {
        let m = Matrix::from_rows(&[&[6.0, 1.0, 1.0], &[4.0, -2.0, 5.0], &[2.0, 8.0, 7.0]]);
        assert!((m.determinant().unwrap() - (-306.0)).abs() < 1e-8);
    }

    #[test]
    fn test_det_identity() {
        let eye = Matrix::identity(4);
        assert!((eye.determinant().unwrap() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_det_singular() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 4.0]]);
        assert!(m.determinant().unwrap().abs() < 1e-10);
    }

    #[test]
    fn test_det_not_square() {
        let m = Matrix::zeros(2, 3);
        assert!(m.determinant().is_err());
    }

    #[test]
    fn test_det_empty() {
        assert_eq!(Matrix::zeros(0, 0).determinant().unwrap(), 1.0);
    }

    #[test]
    fn test_det_row_swap_flips_sign() {
        let m = Matrix::from_rows(&[&[0.0, 1.0], &[1.0, 0.0]]);
        assert!((m.determinant().unwrap() + 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_inverse_2x2() {
        let a = Matrix::from_rows(&[&[4.0, 7.0], &[2.0, 6.0]]);
        let inv = a.inverse().unwrap();
        let eye = a.mul_mat(&inv).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (eye.get(i, j) - expected).abs() < 1e-10,
                    "A·A⁻¹[{i},{j}] = {}, expected {expected}",
                    eye.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_inverse_3x3() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[0.0, 1.0, 4.0], &[5.0, 6.0, 0.0]]);
        let inv = a.inverse().unwrap();
        let eye = a.mul_mat(&inv).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (eye.get(i, j) - expected).abs() < 1e-10,
                    "A·A⁻¹[{i},{j}] = {}",
                    eye.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_inverse_identity() {
        let eye = Matrix::identity(4);
        let inv = eye.inverse().unwrap();
        assert_eq!(eye, inv);
    }

    #[test]
    fn test_inverse_needs_pivoting() {
        // A zero leading entry forces a row swap before elimination.
        let perm = Matrix::from_rows(&[&[0.0, 1.0], &[1.0, 0.0]]);
        assert_eq!(perm.inverse().unwrap(), perm);
    }

    #[test]
    fn test_inverse_singular() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 4.0]]);
        assert!(m.inverse().is_err());
    }

    #[test]
    fn test_cholesky_2x2() {
        let a = Matrix::from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);
        let l = a.cholesky().unwrap();
        assert!(l.get(0, 1).abs() < 1e-15);
        let llt = l.mul_mat(&l.transpose()).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (llt.get(i, j) - a.get(i, j)).abs() < 1e-10,
                    "LLᵀ[{i},{j}] = {}, expected {}",
                    llt.get(i, j),
                    a.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_cholesky_3x3() {
        let a = Matrix::from_rows(&[&[25.0, 15.0, -5.0], &[15.0, 18.0, 0.0], &[-5.0, 0.0, 11.0]]);
        let l = a.cholesky().unwrap();
        let llt = l.mul_mat(&l.transpose()).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (llt.get(i, j) - a.get(i, j)).abs() < 1e-10,
                    "LLᵀ[{i},{j}] = {}, A[{i},{j}] = {}",
                    llt.get(i, j),
                    a.get(i, j)
                );
            }
        }
    }

    #[test]
    fn test_cholesky_identity() {
        let eye = Matrix::identity(3);
        let l = eye.cholesky().unwrap();
        assert_eq!(l, eye);
    }

    #[test]
    fn test_cholesky_not_positive_definite() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 1.0]]);
        assert!(matches!(
            a.cholesky(),
            Err(MatrixError::NotPositiveDefinite)
        ));
    }

    #[test]
    fn test_cholesky_not_symmetric() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!(matches!(a.cholesky(), Err(MatrixError::NotSymmetric)));
    }

    #[test]
    fn test_cholesky_solve() {
        let a = Matrix::from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);
        let b = vec![1.0, 2.0];
        let x = a.cholesky_solve(&b).unwrap();
        let ax = a.mul_vec(&x).unwrap();
        for i in 0..2 {
            assert!(
                (ax[i] - b[i]).abs() < 1e-10,
                "Ax[{i}] = {}, b[{i}] = {}",
                ax[i],
                b[i]
            );
        }
    }

    #[test]
    fn test_cholesky_solve_3x3() {
        let a = Matrix::from_rows(&[&[25.0, 15.0, -5.0], &[15.0, 18.0, 0.0], &[-5.0, 0.0, 11.0]]);
        let b = vec![35.0, 33.0, 6.0];
        let x = a.cholesky_solve(&b).unwrap();
        let ax = a.mul_vec(&x).unwrap();
        for i in 0..3 {
            assert!((ax[i] - b[i]).abs() < 1e-10);
        }
    }

    #[test]
    fn test_frobenius_norm() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!((m.frobenius_norm() - 30.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_is_symmetric() {
        let sym = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[2.0, 5.0, 6.0], &[3.0, 6.0, 9.0]]);
        assert!(sym.is_symmetric(1e-10));

        let asym = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!(!asym.is_symmetric(1e-10));
    }

    #[test]
    fn test_eigen_2x2() {
        let a = Matrix::from_rows(&[&[2.0, 1.0], &[1.0, 2.0]]);
        let (vals, vecs) = a.eigen_symmetric().unwrap();
        assert!((vals[0] - 3.0).abs() < 1e-12, "{vals:?}");
        assert!((vals[1] - 1.0).abs() < 1e-12, "{vals:?}");
        assert_eigen_decomposition(&a, &vals, &vecs, 1e-12);
    }

    #[test]
    fn test_eigen_3x3() {
        // Tridiagonal matrix with eigenvalues 3 + √3, 3, 3 − √3.
        let a = Matrix::from_rows(&[&[4.0, 1.0, 0.0], &[1.0, 3.0, 1.0], &[0.0, 1.0, 2.0]]);
        let (vals, vecs) = a.eigen_symmetric().unwrap();
        let sqrt3 = 3.0_f64.sqrt();
        for (got, want) in vals.iter().zip([3.0 + sqrt3, 3.0, 3.0 - sqrt3]) {
            assert!((got - want).abs() < 1e-10, "{vals:?}");
        }
        assert_eigen_decomposition(&a, &vals, &vecs, 1e-10);
    }

    #[test]
    fn test_eigen_diagonal_is_sorted() {
        let a = Matrix::from_rows(&[&[1.0, 0.0, 0.0], &[0.0, 5.0, 0.0], &[0.0, 0.0, 3.0]]);
        let (vals, vecs) = a.eigen_symmetric().unwrap();
        assert_eq!(vals, vec![5.0, 3.0, 1.0]);
        // The largest eigenvalue belongs to the second basis vector.
        assert_eq!(vecs.get(1, 0), 1.0);
        assert_eigen_decomposition(&a, &vals, &vecs, 1e-14);
    }

    #[test]
    fn test_eigen_not_square() {
        assert!(matches!(
            Matrix::zeros(2, 3).eigen_symmetric(),
            Err(MatrixError::NotSquare { rows: 2, cols: 3 })
        ));
    }

    #[test]
    fn test_eigen_not_symmetric() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!(matches!(a.eigen_symmetric(), Err(MatrixError::NotSymmetric)));
    }
}
1260
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Random `n × n` matrices with entries drawn uniformly from [-10, 10).
    fn square_matrix(n: usize) -> impl Strategy<Value = Matrix> {
        proptest::collection::vec(-10.0_f64..10.0, n * n)
            .prop_map(move |data| Matrix::new(n, n, data).expect("valid dimensions"))
    }

    /// Random symmetric positive-definite matrices `AᵀA + n·I`; the shift
    /// keeps every eigenvalue at least `n`.
    fn spd_matrix(n: usize) -> impl Strategy<Value = Matrix> {
        proptest::collection::vec(-5.0_f64..5.0, n * n).prop_map(move |data| {
            let a = Matrix::new(n, n, data).expect("valid dimensions");
            let ata = a.transpose().mul_mat(&a).expect("compatible");
            let eye_scaled = Matrix::identity(n).scale(n as f64);
            ata.add(&eye_scaled).expect("compatible")
        })
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn transpose_involution(m in square_matrix(3)) {
            let m_tt = m.transpose().transpose();
            for i in 0..3 {
                for j in 0..3 {
                    prop_assert!((m.get(i, j) - m_tt.get(i, j)).abs() < 1e-14);
                }
            }
        }

        #[test]
        fn mul_identity_is_identity(m in square_matrix(3)) {
            let eye = Matrix::identity(3);
            let me = m.mul_mat(&eye).unwrap();
            let em = eye.mul_mat(&m).unwrap();
            for i in 0..3 {
                for j in 0..3 {
                    prop_assert!((me.get(i, j) - m.get(i, j)).abs() < 1e-10);
                    prop_assert!((em.get(i, j) - m.get(i, j)).abs() < 1e-10);
                }
            }
        }

        #[test]
        fn det_of_product(a in square_matrix(3), b in square_matrix(3)) {
            let det_a = a.determinant().unwrap();
            let det_b = b.determinant().unwrap();
            let ab = a.mul_mat(&b).unwrap();
            let det_ab = ab.determinant().unwrap();
            let expected = det_a * det_b;
            let tol = 1e-6 * expected.abs().max(det_ab.abs()).max(1.0);
            prop_assert!(
                (det_ab - expected).abs() < tol,
                "det(AB)={det_ab}, det(A)*det(B)={expected}"
            );
        }

        #[test]
        fn cholesky_roundtrip(a in spd_matrix(3)) {
            let l = a.cholesky().expect("SPD should decompose");
            let llt = l.mul_mat(&l.transpose()).expect("compatible");
            for i in 0..3 {
                for j in 0..3 {
                    let diff = (llt.get(i, j) - a.get(i, j)).abs();
                    let tol = 1e-8 * a.get(i, j).abs().max(1.0);
                    prop_assert!(
                        diff < tol,
                        "LLᵀ[{i},{j}]={}, A[{i},{j}]={}",
                        llt.get(i, j), a.get(i, j)
                    );
                }
            }
        }

        #[test]
        fn cholesky_solve_roundtrip(a in spd_matrix(3), b in proptest::collection::vec(-10.0_f64..10.0, 3)) {
            let x = a.cholesky_solve(&b).expect("SPD solve should work");
            let ax = a.mul_vec(&x).expect("compatible");
            for i in 0..3 {
                let tol = 1e-8 * b[i].abs().max(1.0);
                prop_assert!(
                    (ax[i] - b[i]).abs() < tol,
                    "Ax[{i}]={}, b[{i}]={}",
                    ax[i], b[i]
                );
            }
        }

        #[test]
        fn inverse_roundtrip(a in spd_matrix(3)) {
            let inv = a.inverse().expect("SPD invertible");
            let eye = a.mul_mat(&inv).expect("compatible");
            for i in 0..3 {
                for j in 0..3 {
                    let expected = if i == j { 1.0 } else { 0.0 };
                    let diff = (eye.get(i, j) - expected).abs();
                    prop_assert!(
                        diff < 1e-6,
                        "A·A⁻¹[{i},{j}]={}, expected {expected}",
                        eye.get(i, j)
                    );
                }
            }
        }

        #[test]
        fn eigen_symmetric_decomposes(a in spd_matrix(3)) {
            let (vals, vecs) = a.eigen_symmetric().expect("SPD is symmetric");
            // The eigenvalues sum to the trace.
            let trace: f64 = a.diag().iter().sum();
            let total: f64 = vals.iter().sum();
            prop_assert!(
                (trace - total).abs() < 1e-9 * trace.abs().max(1.0),
                "trace={trace}, Σλ={total}"
            );
            // Sorted in descending order.
            for w in vals.windows(2) {
                prop_assert!(w[0] >= w[1], "not descending: {vals:?}");
            }
            for (k, &lambda) in vals.iter().enumerate() {
                // AᵀA + 3I has every eigenvalue ≥ 3.
                prop_assert!(lambda > 0.0, "non-positive eigenvalue {lambda}");
                let v: Vec<f64> = (0..3).map(|r| vecs.get(r, k)).collect();
                let av = a.mul_vec(&v).expect("compatible");
                let tol = 1e-8 * lambda.abs().max(1.0);
                for r in 0..3 {
                    prop_assert!(
                        (av[r] - lambda * v[r]).abs() < tol,
                        "(Av)[{r}]={}, λv[{r}]={}",
                        av[r], lambda * v[r]
                    );
                }
            }
        }
    }
}