1use crate::types::ARdouble;
41use std::ops::Mul;
42
/// Dense matrix of `ARdouble` values stored in row-major order.
///
/// Element `(r, c)` is at `m[r * clm + c]`. `ARMat::new` sizes `m` to hold
/// `row * clm` elements, but the public fields do not enforce that invariant.
// NOTE(review): `#[repr(C)]` only fixes the field order; `Vec` itself has no
// C-compatible layout, so this struct cannot be passed across FFI as-is.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct ARMat {
    /// Row-major element storage.
    pub m: Vec<ARdouble>,
    /// Number of rows.
    pub row: i32,
    /// Number of columns.
    pub clm: i32,
}
51
52impl Default for ARMat {
53 fn default() -> Self {
54 Self {
55 m: Vec::new(),
56 row: 0,
57 clm: 0,
58 }
59 }
60}
61
62impl ARMat {
63 pub fn new(row: i32, clm: i32) -> Self {
65 let size = (row * clm) as usize;
66 Self {
67 m: vec![0.0; size],
68 row,
69 clm,
70 }
71 }
72
73 pub fn transpose(&self) -> ARMat {
75 let mut dest = ARMat::new(self.clm, self.row);
76 let dest_clm = dest.clm as usize;
77 let src_clm = self.clm as usize;
78
79 for r in 0..(dest.row as usize) {
80 for c in 0..(dest.clm as usize) {
81 dest.m[r * dest_clm + c] = self.m[c * src_clm + r];
83 }
84 }
85 dest
86 }
87
88 pub fn det(&self) -> f64 {
90 if self.row != self.clm {
91 return 0.0; }
93
94 let dimen = self.row as usize;
95 let mut ap = self.m.clone(); let mut det = 1.0;
97 let mut is = 0;
98
99 for k in 0..(dimen - 1) {
100 let mut mmax = k;
101 for i in (k + 1)..dimen {
102 if ap[i * dimen + k].abs() > ap[mmax * dimen + k].abs() {
103 mmax = i;
104 }
105 }
106 if mmax != k {
107 for j in k..dimen {
108 let work = ap[k * dimen + j];
109 ap[k * dimen + j] = ap[mmax * dimen + j];
110 ap[mmax * dimen + j] = work;
111 }
112 is += 1;
113 }
114 for i in (k + 1)..dimen {
115 let work = ap[i * dimen + k] / ap[k * dimen + k];
116 for j in (k + 1)..dimen {
117 ap[i * dimen + j] -= work * ap[k * dimen + j];
118 }
119 }
120 }
121
122 for i in 0..dimen {
123 det *= ap[i * dimen + i];
124 }
125 for _ in 0..is {
126 det *= -1.0;
127 }
128
129 det
130 }
131
132 pub fn self_inv(&mut self) -> Result<(), &'static str> {
134 let dimen = self.row as usize;
135 if dimen != self.clm as usize {
136 return Err("Matrix must be square");
137 }
138 if dimen > 500 {
139 return Err("Matrix too large");
140 }
141 if dimen == 0 {
142 return Err("Matrix is empty");
143 }
144 if dimen == 1 {
145 self.m[0] = 1.0 / self.m[0];
146 return Ok(());
147 }
148
149 let mut nos = vec![0; dimen];
150 for n in 0..dimen {
151 nos[n] = n;
152 }
153
154 for n in 0..dimen {
155 let mut p = 0.0;
156 let mut ip = -1isize;
157
158 for i in n..dimen {
159 let pbuf = self.m[i * dimen + 0].abs();
160 if p < pbuf {
161 p = pbuf;
162 ip = i as isize;
163 }
164 }
165
166 if p <= 1.0e-10 || ip == -1 {
167 return Err("Matrix is singular");
168 }
169
170 let ip = ip as usize;
171
172 let nwork = nos[ip];
173 nos[ip] = nos[n];
174 nos[n] = nwork;
175
176 for j in 0..dimen {
177 let work = self.m[ip * dimen + j];
178 self.m[ip * dimen + j] = self.m[n * dimen + j];
179 self.m[n * dimen + j] = work;
180 }
181
182 let work = self.m[n * dimen + 0];
183 for j in 1..dimen {
184 self.m[n * dimen + j - 1] = self.m[n * dimen + j] / work;
185 }
186 self.m[n * dimen + dimen - 1] = 1.0 / work;
187
188 for i in 0..dimen {
189 if i != n {
190 let work = self.m[i * dimen + 0];
191 for j in 1..dimen {
192 self.m[i * dimen + j - 1] = self.m[i * dimen + j] - work * self.m[n * dimen + j - 1];
193 }
194 self.m[i * dimen + dimen - 1] = -work * self.m[n * dimen + dimen - 1];
195 }
196 }
197 }
198
199 for n in 0..dimen {
200 let mut j = n;
201 while j < dimen {
202 if nos[j] == n { break; }
203 j += 1;
204 }
205 nos[j] = nos[n];
206 for i in 0..dimen {
207 let work = self.m[i * dimen + j];
208 self.m[i * dimen + j] = self.m[i * dimen + n];
209 self.m[i * dimen + n] = work;
210 }
211 }
212
213 Ok(())
214 }
215
216 pub fn inv(&self) -> Result<ARMat, &'static str> {
218 let mut dest = self.clone();
219 dest.self_inv()?;
220 Ok(dest)
221 }
222
223 pub fn tridiagonalize(&mut self, d: &mut [ARdouble], e: &mut [ARdouble]) -> Result<(), &'static str> {
225 let dim = self.row as usize;
226 if dim != self.clm as usize || dim != d.len() || dim != e.len() + 1 {
227 return Err("Mismatched dimensions for tridiagonalize");
228 }
229
230 for k in 0..dim.saturating_sub(2) {
231 d[k] = self.m[k * dim + k];
232
233 e[k] = household(&mut self.m[k * dim + k + 1 .. k * dim + dim]);
234 if e[k] == 0.0 { continue; }
235
236 for i in (k + 1)..dim {
237 let mut s = 0.0;
238 for j in (k + 1)..i {
239 s += self.m[j * dim + i] * self.m[k * dim + j];
240 }
241 for j in i..dim {
242 s += self.m[i * dim + j] * self.m[k * dim + j];
243 }
244 d[i] = s;
245 }
246
247 let t = inner_product(&self.m[k * dim + k + 1 .. k * dim + dim], &d[k + 1 .. dim]) / 2.0;
248
249 for i in (k + 1..dim).rev() {
250 let p = self.m[k * dim + i];
251 d[i] -= t * p;
252 let q = d[i];
253 for j in i..dim {
254 self.m[i * dim + j] -= p * d[j] + q * self.m[k * dim + j];
255 }
256 }
257 }
258
259 if dim >= 2 {
260 d[dim - 2] = self.m[(dim - 2) * dim + (dim - 2)];
261 e[dim - 2] = self.m[(dim - 2) * dim + (dim - 1)];
262 }
263
264 if dim >= 1 {
265 d[dim - 1] = self.m[(dim - 1) * dim + (dim - 1)];
266 }
267
268 for k in (0..dim).rev() {
269 if k < dim.saturating_sub(2) {
270 for i in (k + 1)..dim {
271 let mut t = 0.0;
272 for j in (k + 1)..dim {
273 t += self.m[k * dim + j] * self.m[i * dim + j];
274 }
275 for j in (k + 1)..dim {
276 self.m[i * dim + j] -= t * self.m[k * dim + j];
277 }
278 }
279 }
280 for i in 0..dim {
281 self.m[k * dim + i] = 0.0;
282 }
283 self.m[k * dim + k] = 1.0;
284 }
285
286 Ok(())
287 }
288}
289
290pub fn inner_product(x: &[ARdouble], y: &[ARdouble]) -> ARdouble {
292 x.iter().zip(y.iter()).map(|(a, b)| a * b).sum()
293}
294
295pub fn household(x: &mut [ARdouble]) -> ARdouble {
297 let mut s = inner_product(x, x).sqrt();
298 if s != 0.0 {
299 if x[0] < 0.0 { s = -s; }
300 x[0] += s;
301 let t = 1.0 / (x[0] * s).sqrt();
302 for val in x.iter_mut() {
303 *val *= t;
304 }
305 }
306 -s
307}
308
/// Eigen-decomposes the square matrix `a` in place.
///
/// The matrix is first reduced to tridiagonal form with
/// `ARMat::tridiagonalize`. The off-diagonal entries are then driven to zero
/// by shifted QR sweeps built from Givens rotations. The shift for each sweep
/// is derived from the trailing 2x2 block `(dv[h-1], dv[h], ev[h])`.
///
/// `tridiagonalize` only reads the upper triangle of `a`, so the input is
/// presumably expected to be symmetric. This holds for the `x_by_xt` /
/// `xt_by_x` products that `pca_internal` passes in.
///
/// On success:
/// * `dv` holds the eigenvalues, sorted in descending order.
/// * the rows of `a` have been rotated and then permuted in step with `dv`,
///   so row `i` pairs with `dv[i]` (this is how `pca_internal` reads them).
///
/// Non-convergence is not reported. If an off-diagonal element is still
/// large after `max_iter` sweeps, the loop moves on to the next one.
///
/// # Errors
/// * `"Invalid dimensions for QRM"` if `a` is not square, is smaller than
///   2x2, or `dv.len()` differs from its dimension.
/// * Any error returned by `tridiagonalize`.
pub fn qrm(a: &mut ARMat, dv: &mut [ARdouble]) -> Result<(), &'static str> {
    let dim = a.row as usize;
    if dim != a.clm as usize || dim < 2 || dv.len() != dim {
        return Err("Invalid dimensions for QRM");
    }

    // ev[k] (k >= 1) is the off-diagonal element coupling dv[k - 1] and dv[k];
    // ev[0] is padding so the indices line up with dv.
    let mut ev = vec![0.0; dim];
    a.tridiagonalize(dv, &mut ev[1..])?;
    ev[0] = 0.0;

    // Relative threshold below which an off-diagonal element counts as zero.
    let eps = 1e-6;
    // Magnitude below which `x` is treated as zero when building a rotation.
    let vzero = 1e-16;
    // Upper bound on QR sweeps spent on a single trailing element.
    let max_iter = 100;

    // Deflate from the bottom of the tridiagonal matrix upwards.
    for h in (1..dim).rev() {
        // Walk up to find j, the first row of the block whose off-diagonals
        // ev[j+1..=h] are all still significant.
        let mut j = h;
        while j > 0 && ev[j].abs() > eps * (dv[j - 1].abs() + dv[j].abs()) { j -= 1; }
        // ev[h] is already negligible: dv[h] has converged.
        if j == h { continue; }

        let mut iter = 0;
        while ev[h].abs() > eps * (dv[h - 1].abs() + dv[h].abs()) {
            iter += 1;
            if iter > max_iter { break; }

            // Shift from the trailing 2x2 block; `s` takes the sign of `w`.
            let mut w = (dv[h - 1] - dv[h]) / 2.0;
            let mut t = ev[h] * ev[h];
            let mut s = (w * w + t).sqrt();
            if w < 0.0 { s = -s; }
            // (x, y) seeds the first rotation of the sweep.
            let mut x = dv[j] - dv[h] + t / (w + s);
            let mut y = ev[j + 1];

            // Sweep rotations from row j down to row h.
            for k in j..h {
                // Choose cosine/sine (c, s) so the rotation annihilates y,
                // dividing by the larger of |x| and |y| for stability.
                let c: ARdouble;
                if x.abs() >= y.abs() {
                    if x.abs() > vzero {
                        t = -y / x;
                        c = 1.0 / (t * t + 1.0).sqrt();
                        s = t * c;
                    } else {
                        c = 1.0;
                        s = 0.0;
                    }
                } else {
                    t = -x / y;
                    s = 1.0 / (t * t + 1.0).sqrt();
                    c = t * s;
                }
                // Apply the rotation to the tridiagonal entries at rows k, k+1.
                w = dv[k] - dv[k + 1];
                t = (w * s + 2.0 * c * ev[k + 1]) * s;
                dv[k] -= t;
                dv[k + 1] += t;
                if k > j { ev[k] = c * ev[k] - s * y; }
                ev[k + 1] += s * (c * w - 2.0 * s * ev[k + 1]);

                // Accumulate the same rotation into rows k and k+1 of `a`.
                for i in 0..dim {
                    let rx = a.m[k * dim + i];
                    let ry = a.m[(k + 1) * dim + i];
                    a.m[k * dim + i] = c * rx - s * ry;
                    a.m[(k + 1) * dim + i] = s * rx + c * ry;
                }
                // Prepare (x, y) for the next rotation; the rotation just
                // applied introduces a fill-in element y below the tridiagonal.
                if k < h - 1 {
                    x = ev[k + 1];
                    y = -s * ev[k + 2];
                    ev[k + 2] *= c;
                }
            }
        }
    }

    // Selection sort of the eigenvalues into descending order, swapping the
    // matching rows of `a` so they stay paired.
    for k in 0..dim - 1 {
        let mut h = k;
        let mut t = dv[h];
        for i in k + 1..dim {
            if dv[i] > t {
                h = i;
                t = dv[h];
            }
        }
        dv[h] = dv[k];
        dv[k] = t;
        for i in 0..dim {
            let w = a.m[h * dim + i];
            a.m[h * dim + i] = a.m[k * dim + i];
            a.m[k * dim + i] = w;
        }
    }
    Ok(())
}
397
398impl ARMat {
399 pub fn ex(&self, mean: &mut [ARdouble]) -> Result<(), &'static str> {
400 let row = self.row as usize;
401 let clm = self.clm as usize;
402 if row == 0 || clm == 0 || mean.len() != clm {
403 return Err("Invalid dimensions for EX");
404 }
405
406 for i in 0..clm { mean[i] = 0.0; }
407
408 for r in 0..row {
409 for c in 0..clm {
410 mean[c] += self.m[r * clm + c];
411 }
412 }
413
414 for i in 0..clm {
415 mean[i] /= row as ARdouble;
416 }
417 Ok(())
418 }
419
420 pub fn center(&mut self, mean: &[ARdouble]) -> Result<(), &'static str> {
421 let row = self.row as usize;
422 let clm = self.clm as usize;
423 if mean.len() != clm { return Err("Invalid dimensions for CENTER"); }
424
425 for r in 0..row {
426 for c in 0..clm {
427 self.m[r * clm + c] -= mean[c];
428 }
429 }
430 Ok(())
431 }
432
433 pub fn x_by_xt(&self, output: &mut ARMat) -> Result<(), &'static str> {
434 let row = self.row as usize;
435 let clm = self.clm as usize;
436 if output.row as usize != row || output.clm as usize != row {
437 return Err("Invalid dimensions for x_by_xt");
438 }
439
440 for i in 0..row {
441 for j in 0..row {
442 if j < i {
443 output.m[i * row + j] = output.m[j * row + i];
444 } else {
445 let mut out = 0.0;
446 for k in 0..clm {
447 out += self.m[i * clm + k] * self.m[j * clm + k];
448 }
449 output.m[i * row + j] = out;
450 }
451 }
452 }
453 Ok(())
454 }
455
456 pub fn xt_by_x(&self, output: &mut ARMat) -> Result<(), &'static str> {
457 let row = self.row as usize;
458 let clm = self.clm as usize;
459 if output.row as usize != clm || output.clm as usize != clm {
460 return Err("Invalid dimensions for xt_by_x");
461 }
462
463 for i in 0..clm {
464 for j in 0..clm {
465 if j < i {
466 output.m[i * clm + j] = output.m[j * clm + i];
467 } else {
468 let mut out = 0.0;
469 for k in 0..row {
470 out += self.m[k * clm + i] * self.m[k * clm + j];
471 }
472 output.m[i * clm + j] = out;
473 }
474 }
475 }
476 Ok(())
477 }
478
479 pub fn ev_create(&self, u: &ARMat, output: &mut ARMat, ev: &mut [ARdouble]) -> Result<(), &'static str> {
480 let row = self.row as usize;
481 let clm = self.clm as usize;
482 if row == 0 || clm == 0 || u.row as usize != row || u.clm as usize != row
483 || output.row as usize != row || output.clm as usize != clm || ev.len() != row {
484 return Err("Invalid dimensions for EV_create");
485 }
486
487 let mut i = 0;
488 while i < row {
489 if ev[i] < 1e-16 { break; }
490 let work = 1.0 / ev[i].abs().sqrt();
491 for j in 0..clm {
492 let mut sum = 0.0;
493 for k in 0..row {
494 sum += u.m[i * row + k] * self.m[k * clm + j];
495 }
496 output.m[i * clm + j] = sum * work;
497 }
498 i += 1;
499 }
500
501 while i < row {
502 ev[i] = 0.0;
503 for j in 0..clm {
504 output.m[i * clm + j] = 0.0;
505 }
506 i += 1;
507 }
508 Ok(())
509 }
510
511 pub fn pca_internal(&self, output: &mut ARMat, ev: &mut [ARdouble]) -> Result<(), &'static str> {
512 let row = self.row as usize;
513 let clm = self.clm as usize;
514 let min = row.min(clm);
515 if row < 2 || clm < 2 || output.clm as usize != clm || output.row as usize != min || ev.len() != min {
516 return Err("Invalid dimensions for PCA internal");
517 }
518
519 let mut u = ARMat::new(min as i32, min as i32);
520 if row < clm {
521 self.x_by_xt(&mut u)?;
522 } else {
523 self.xt_by_x(&mut u)?;
524 }
525
526 qrm(&mut u, ev)?;
527
528 if row < clm {
529 self.ev_create(&u, output, ev)?;
530 } else {
531 let mut i = 0;
532 while i < min {
533 if ev[i] < 1e-16 { break; }
534 for j in 0..min {
535 output.m[i * clm + j] = u.m[i * min + j];
536 }
537 i += 1;
538 }
539 while i < min {
540 ev[i] = 0.0;
541 for j in 0..min {
542 output.m[i * clm + j] = 0.0;
543 }
544 i += 1;
545 }
546 }
547 Ok(())
548 }
549
550 pub fn pca(&self, evec: &mut ARMat, ev: &mut ARVec, mean: &mut ARVec) -> Result<(), &'static str> {
551 let row = self.row as usize;
552 let clm = self.clm as usize;
553 let check = row.min(clm);
554 if row < 2 || clm < 2 || evec.clm as usize != clm || evec.row as usize != check
555 || ev.clm as usize != check || mean.clm as usize != clm {
556 return Err("Invalid dimensions for PCA");
557 }
558
559 let mut work = self.clone();
560 work.ex(&mut mean.v)?;
561 work.center(&mean.v)?;
562
563 let srow = (row as f64).sqrt();
564 for val in work.m.iter_mut() {
565 *val /= srow;
566 }
567
568 work.pca_internal(evec, &mut ev.v)?;
569
570 let sum: f64 = ev.v.iter().sum();
571 if sum != 0.0 {
572 for val in ev.v.iter_mut() {
573 *val /= sum;
574 }
575 }
576 Ok(())
577 }
578}
579
580impl<'a, 'b> Mul<&'b ARMat> for &'a ARMat {
581 type Output = Result<ARMat, &'static str>;
582
583 fn mul(self, rhs: &'b ARMat) -> Self::Output {
586 if self.clm != rhs.row {
587 return Err("Matrix dimensions do not match for multiplication");
588 }
589
590 let mut dest = ARMat::new(self.row, rhs.clm);
591 let dest_clm = dest.clm as usize;
592 let a_clm = self.clm as usize;
593 let b_clm = rhs.clm as usize;
594
595 for r in 0..(dest.row as usize) {
596 for c in 0..(dest.clm as usize) {
597 let mut sum = 0.0;
598 let mut p1_idx = r * a_clm;
599 let mut p2_idx = c;
600 for _ in 0..a_clm {
601 sum += self.m[p1_idx] * rhs.m[p2_idx];
602 p1_idx += 1;
603 p2_idx += b_clm;
604 }
605 dest.m[r * dest_clm + c] = sum;
606 }
607 }
608
609 Ok(dest)
610 }
611}
612
/// Single-precision (`f32`) dense matrix stored in row-major order.
///
/// Element `(r, c)` is at `m[r * clm + c]`. `ARMatf::new` sizes `m` to hold
/// `row * clm` elements, but the public fields do not enforce that invariant.
// NOTE(review): `#[repr(C)]` only fixes the field order; `Vec<f32>` has no
// C-compatible layout, so this struct cannot be passed across FFI as-is.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct ARMatf {
    /// Row-major element storage.
    pub m: Vec<f32>,
    /// Number of rows.
    pub row: i32,
    /// Number of columns.
    pub clm: i32,
}
621
622impl Default for ARMatf {
623 fn default() -> Self {
624 Self {
625 m: Vec::new(),
626 row: 0,
627 clm: 0,
628 }
629 }
630}
631
632impl ARMatf {
633 pub fn new(row: i32, clm: i32) -> Self {
635 let size = (row * clm) as usize;
636 Self {
637 m: vec![0.0; size],
638 row,
639 clm,
640 }
641 }
642}
643
644impl<'a, 'b> Mul<&'b ARMatf> for &'a ARMatf {
645 type Output = Result<ARMatf, &'static str>;
646
647 fn mul(self, rhs: &'b ARMatf) -> Self::Output {
650 if self.clm != rhs.row {
651 return Err("Matrix dimensions do not match for multiplication");
652 }
653
654 let mut dest = ARMatf::new(self.row, rhs.clm);
655 let dest_clm = dest.clm as usize;
656 let a_clm = self.clm as usize;
657 let b_clm = rhs.clm as usize;
658
659 for r in 0..(dest.row as usize) {
660 for c in 0..(dest.clm as usize) {
661 let mut sum = 0.0;
662 let mut p1_idx = r * a_clm;
663 let mut p2_idx = c;
664 for _ in 0..a_clm {
665 sum += self.m[p1_idx] * rhs.m[p2_idx];
666 p1_idx += 1;
667 p2_idx += b_clm;
668 }
669 dest.m[r * dest_clm + c] = sum;
670 }
671 }
672
673 Ok(dest)
674 }
675}
676
/// Dense vector of `ARdouble` values.
///
/// `clm` records the element count. `ARVec::new` keeps it equal to
/// `v.len()`, but the public fields allow the two to diverge.
// NOTE(review): `#[repr(C)]` only fixes the field order; `Vec` has no
// C-compatible layout, so this struct cannot be passed across FFI as-is.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct ARVec {
    /// Element storage.
    pub v: Vec<ARdouble>,
    /// Number of elements.
    pub clm: i32,
}
684
685impl Default for ARVec {
686 fn default() -> Self {
687 Self {
688 v: Vec::new(),
689 clm: 0,
690 }
691 }
692}
693
694impl ARVec {
695 pub fn new(clm: i32) -> Self {
697 Self {
698 v: vec![0.0; clm as usize],
699 clm,
700 }
701 }
702}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `row` x `clm` matrix whose storage is replaced by `data`.
    fn matrix(row: i32, clm: i32, data: Vec<ARdouble>) -> ARMat {
        let mut built = ARMat::new(row, clm);
        built.m = data;
        built
    }

    #[test]
    fn test_armat_default_initialization() {
        let ARMat { m, row, clm } = ARMat::default();
        assert!(m.is_empty());
        assert_eq!((row, clm), (0, 0));
    }

    #[test]
    fn test_armat_multiplication() {
        let lhs = matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let rhs = matrix(3, 2, vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let product = (&lhs * &rhs).unwrap();
        assert_eq!((product.row, product.clm), (2, 2));
        assert_eq!(product.m, vec![58.0, 64.0, 139.0, 154.0]);
    }

    #[test]
    fn test_armat_transpose() {
        let t = matrix(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).transpose();
        assert_eq!((t.row, t.clm), (3, 2));
        assert_eq!(t.m, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_armat_det() {
        let a = matrix(3, 3, vec![6.0, 1.0, 1.0, 4.0, -2.0, 5.0, 2.0, 8.0, 7.0]);
        assert_eq!(a.det().round(), -306.0);
    }

    #[test]
    fn test_armat_inv() {
        let a = matrix(2, 2, vec![4.0, 7.0, 2.0, 6.0]);
        let inv_a = a.inv().expect("Failed to invert matrix");
        assert_eq!((inv_a.row, inv_a.clm), (2, 2));

        let expected = [0.6, -0.7, -0.2, 0.4];
        assert_eq!(inv_a.m.len(), expected.len());
        for (got, want) in inv_a.m.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_armatf_default_initialization() {
        let ARMatf { m, row, clm } = ARMatf::default();
        assert!(m.is_empty());
        assert_eq!((row, clm), (0, 0));
    }

    #[test]
    fn test_arvec_default_initialization() {
        let ARVec { v, clm } = ARVec::default();
        assert!(v.is_empty());
        assert_eq!(clm, 0);
    }
}
812
/// Composes two 3x4 affine transforms `[R | t]` (implicit bottom row
/// `0 0 0 1`): `dest = a · b`.
///
/// `a` is double precision and `b` single precision. Each element is
/// accumulated in `f64` and narrowed to `f32` once at the end.
pub fn mat_mul_dff(a: &[[f64; 4]; 3], b: &[[f32; 4]; 3], dest: &mut [[f32; 4]; 3]) {
    for (lhs, out) in a.iter().zip(dest.iter_mut()) {
        for (col, cell) in out.iter_mut().enumerate() {
            let rotated = lhs[0] * f64::from(b[0][col])
                + lhs[1] * f64::from(b[1][col])
                + lhs[2] * f64::from(b[2][col]);
            // Column 3 is the translation, which also picks up `a`'s own offset.
            let total = if col == 3 { rotated + lhs[3] } else { rotated };
            *cell = total as f32;
        }
    }
}
823
/// Composes two single-precision 3x4 affine transforms `[R | t]` (implicit
/// bottom row `0 0 0 1`): `dest = a · b`.
pub fn mat_mul_fff(a: &[[f32; 4]; 3], b: &[[f32; 4]; 3], dest: &mut [[f32; 4]; 3]) {
    for (lhs, out) in a.iter().zip(dest.iter_mut()) {
        for (col, cell) in out.iter_mut().enumerate() {
            let rotated = lhs[0] * b[0][col] + lhs[1] * b[1][col] + lhs[2] * b[2][col];
            // Column 3 is the translation, which also picks up `a`'s own offset.
            *cell = if col == 3 { rotated + lhs[3] } else { rotated };
        }
    }
}
832}